时间序列预测(九)——门控循环单元网络(GRU)

目录

一、GRU结构

二、GRU核心思想

1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。

2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。

3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值

三、GRU 分步演练

1、输入与初始化:

2、计算重置门:

3、计算候选隐藏状态:

4、计算更新门:

5、计算当前隐藏状态:

四、代码实现

1、任务:

2、做法:

3、主要修改点:

4、具体代码:

5、结果


GRU是一种循环神经网络(RNN)的变体,由Cho等人在2014年提出。相比于传统的RNN,GRU引入了门控机制,可以通过该机制来确定应该何时更新隐状态,以及应该何时重置隐状态,使得网络能够更好地捕捉长期依赖性,同时减少了梯度消失的问题。

一、GRU结构

GRU的结构和基础的RNN相比,并没有特别大的不同,都是一种重复神经网络模块的链式结构,由输入层、隐藏层和输出层组成,其中隐藏层是其核心部分,包含了门控机制相关的计算单元

二、GRU核心思想

与LSTM不同,GRU没有细胞状态,而是直接使用隐藏状态。GRU由两个门控制:更新门(Update Gate)和重置门(Reset Gate)。

1、更新门(Update Gate):决定了当前时刻隐藏状态中旧状态和新候选状态的混合比例。
2、重置门(Reset Gate):用于控制前一时刻隐藏状态对当前候选隐藏状态的影响程度。

补充:

3、候选隐藏状态(Candidate Hidden State):生成当前隐藏状态的候选值

三、GRU 分步演练

1、输入与初始化
  • 假设我们有一个输入序列 X=[x1​,x2​,...,xT​],其中 xt​ 是第 t 个时间步的输入。
  • 初始化隐藏状态 h0​,通常为零向量或随机初始化。
2、计算重置门
  • 重置门 rt​ 决定了前一时间步的隐藏状态 ht−1​ 对当前候选隐藏状态 h~t​ 的影响程度。                                      其中 σ 是sigmoid函数,Wr​ 和 Ur​ 是可训练的权重矩阵。
3、计算候选隐藏状态
  • 使用重置门 rt​ 来控制前一时间步的隐藏状态 ht−1​ 的影响。                       其中 ⊙ 表示元素乘法,tanh 是双曲正切函数,W 和 U 是可训练的权重矩阵。
4、计算更新门
  • 更新门 zt​ 决定了当前隐藏状态 ht​ 应该保留多少前一时间步的隐藏状态 ht−1​ 和多少当前候选隐藏状态 h~t​。            其中 Wz​ 和 Uz​ 是可训练的权重矩阵。
5、计算当前隐藏状态
  • 使用更新门 zt​ 来组合前一时间步的隐藏状态 ht−1​ 和当前候选隐藏状态 h~t​。

四、代码实现

1、任务:

根据一个包含道路曲率(Curvature)、车速(Velocity)、侧向加速度(Ay)和方向盘转角(Steering_Angle)真实的数据集,去预测未来的方向盘转角。

2、做法:

提取前5个历史曲率、速度、方向盘转角作为输入特征,同时添加后5个未来曲率(由于车辆的预瞄距离)。目标输出为未来5个方向盘转角。采用GRU网络训练。

3、主要修改点:
  1. 模型定义:将 LSTM 替换为 GRU,并更新模型类名为 GRUModel
  2. 前向传播forward 方法中相应地使用 GRU 的输出。
4、具体代码:
# GRU 模型
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error as mae, r2_score
import matplotlib.pyplot as plt# 1. 数据预处理
# 读取数据
data = pd.read_excel('input_data_20241010160240.xlsx')  # 替换为你的数据文件路径  # 提取特征和标签
curvature = data['Curvature'].values
velocity = data['Velocity'].values
steering = data['Steering_Angle'].values# 定义历史和未来的窗口大小
history_size = 5
future_size = 5features = []
labels = []
for i in range(history_size, len(data) - future_size):# 提取前5个历史的曲率、速度和方向盘转角history_curvature = curvature[i - history_size:i]history_velocity = velocity[i - history_size:i]history_steering = steering[i - history_size:i]# 提取后5个未来的曲率(用于预测)future_curvature = curvature[i:i + future_size]# 输入特征:历史 + 未来曲率feature = np.hstack((history_curvature, history_velocity, history_steering, future_curvature))features.append(feature)# 输出标签:未来5个方向盘转角label = steering[i:i + future_size]labels.append(label)# 转换为 NumPy 数组
features = np.array(features)
labels = np.array(labels)# 归一化
scaler_x = StandardScaler()
scaler_y = StandardScaler()features = scaler_x.fit_transform(features)
labels = scaler_y.fit_transform(labels)# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.05)# 将特征转换为三维张量,形状为 [样本数, 时间序列长度, 特征数]
input_feature_size = history_size * 3 + future_size  # 历史曲率、速度、方向盘转角 + 未来曲率
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).view(-1, 1, input_feature_size)  # [batch_size, seq_len=1, input_size]
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, future_size)  # 输出未来的5个方向盘转角
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).view(-1, 1, input_feature_size)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, future_size)# 2. 创建GRU模型
class GRUModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, output_size):super(GRUModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)  # 使用GRUself.fc = nn.Linear(hidden_size, output_size)  # 输出层def forward(self, x):# 前向传播out, _ = self.gru(x)  # GRU输出out = self.fc(out[:, -1, :])  # 只取最后一个时间步的输出return out# 实例化模型
input_size = input_feature_size  # 输入特征数
hidden_size = 64  # 隐藏层大小
num_layers = 2  # GRU层数
output_size = future_size  # 输出5个未来方向盘转角
model = GRUModel(input_size, hidden_size, num_layers, output_size)# 3. 设置损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 4. 训练模型
num_epochs = 1000
for epoch in range(num_epochs):model.train()# 前向传播outputs = model(x_train_tensor)loss = criterion(outputs, y_train_tensor)# 后向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 5. 预测
model.eval()
with torch.no_grad():y_pred_tensor = model(x_test_tensor)y_pred = scaler_y.inverse_transform(y_pred_tensor.numpy())  # 将预测值逆归一化
y_test = scaler_y.inverse_transform(y_test_tensor.numpy())  # 逆归一化真实值# 评估指标
r2 = r2_score(y_test, y_pred, multioutput='uniform_average')  # 多维输出下的R^2
mae_score = mae(y_test, y_pred)
print(f"R^2 score: {r2:.4f}")
print(f"MAE: {mae_score:.4f}")# 支持中文
plt.rcParams['font.sans-serif'] = ['SimSun']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号# 绘制未来5个方向盘转角的预测和真实值对比
plt.figure(figsize=(10, 6))
for i in range(future_size):plt.plot(range(len(y_test)), y_test[:, i], label=f'真实值 {i+1} 步', color='blue')plt.plot(range(len(y_pred)), y_pred[:, i], label=f'预测值 {i+1} 步', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来5个方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()# 计算预测和真实方向盘转角的平均值
y_pred_mean = np.mean(y_pred, axis=1)  # 每个样本的5个预测值取平均
y_test_mean = np.mean(y_test, axis=1)  # 每个样本的5个真实值取平均# 绘制平均值的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test_mean)), y_test_mean, label='真实值(平均)', color='blue')
plt.plot(range(len(y_pred_mean)), y_pred_mean, label='预测值(平均)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle (平均)')
plt.title('未来5个方向盘转角的平均值对比图')
plt.legend()
plt.grid(True)
plt.show()# 绘制第1个时间步的实际值与预测值对比图
plt.figure(figsize=(10, 6))
plt.plot(range(len(y_test)), y_test[:, 0], label='真实值 (第1步)', color='blue')
plt.plot(range(len(y_pred)), y_pred[:, 0], label='预测值 (第1步)', color='red')
plt.xlabel('样本索引')
plt.ylabel('Steering Angle')
plt.title('未来第1步方向盘转角的实际值与预测值对比图')
plt.legend()
plt.grid(True)
plt.show()# 计算每个时间步的平均绝对误差
time_steps = y_test.shape[1]
mae_per_step = [mae(y_test[:, i], y_pred[:, i]) for i in range(time_steps)]# 绘制每个时间步的平均绝对误差
plt.figure(figsize=(10, 6))
plt.bar(range(1, time_steps + 1), mae_per_step, color='orange')
plt.xlabel('时间步')
plt.ylabel('MAE')
plt.title('不同时间步的平均绝对误差')
plt.grid(True)
plt.show()
5、结果

五、总结

GRU是LSTM的简化版本,减少了门的数量,使得训练和推理速度更快。它在许多序列建模任务中表现良好,适用于时间序列预测、自然语言处理等领域。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882972.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java项目-基于springboot框架的智慧外贸系统项目实战(附源码+文档)

作者:计算机学长阿伟 开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。 开发运行环境 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/…

小新学习K8s第一天之K8s基础概念

目录 一、Kubernetes(K8s)概述 1.1、什么是K8s 1.2、K8s的作用 1.3、K8s的功能 二、K8s的特性 2.1、弹性伸缩 2.2、自我修复 2.3、服务发现和负载均衡 2.4、自动发布(默认滚动发布模式)和回滚 2.5、集中化配置管理和密钥…

高效改进!防止DataX从HDFS导入关系型数据库丢数据

高效改进!防止DataX从HDFS导入关系型数据库丢数据 针对DataX在从HDFS导入数据到关系型数据库过程中的数据丢失问题,优化了分片处理代码。改动包括将之前单一分片处理逻辑重构为循环处理所有分片,确保了每个分片数据都得到全面读取和传输&…

Python 实现 excel 数据过滤

一、场景分析 假设有如下一份 excel 数据 shop.xlsx, 写一段 python 程序,实现对于车牌的分组数据过滤。 并以车牌为文件名,把店名输出到 车牌.txt 文件中。 比如 闽A.txt 文件内容为: 小林书店福州店1 小林书店福州店2 二、依赖安装 程序依…

TBWeb正式稳定版V3.4.0+AI+MJ绘画+免授权无后门+详细安装教程

TBWeb正式稳定版V3.4.0AIMJ绘画免授权无后门详细安装教程; 运行环境 Nginx1.22 PHP5.7 MySQL7.4 Redis7.0 Node.js(16.19.1) PM2管理器5.6 TBWeb系统是基于 NineAI 二开的可商业化 TB Web 应用(免授权,无后门&a…

【隐私计算】隐语HEU同态加密算法解读

HEU: 一个高性能的同态加密算法库,提供了多种 PHE 算法, 包括ZPaillier、FPaillier、IPCL、Damgard Jurik、DGK、OU、EC ElGamal 以及基于FPGA和GPU硬件加速版本的Paillier版本。 本文我们会基于GPU运行HEU Docker容器,编译打包GPaillier并测…

算法的学习笔记—两个链表的第一个公共结点(牛客JZ52)

😀前言 在链表问题中,寻找两个链表的第一个公共结点是一个经典问题。这个问题的本质是在两个单链表中找到它们的相交点,或者说它们开始共享相同节点的地方。本文将详细讲解这个问题的解题思路,并提供一种高效的解决方法。 &#x…

蓝牙资讯|iOS 18.1 正式版下周推送,AirPods Pro 2耳机将带来助听器功能

苹果公司宣布将在下周发布 iOS 18.1 正式版,同时确认该更新将为 AirPods Pro 2 耳机带来新增“临床级”助听器功能。在启用功能后,用户首先需要使用 AirPods 和 iPhone 进行简短的听力测试,如果检测到听力损失,系统将创建一项“个…

docker run 命令解析

docker run 命令解析 docker run 命令用于从给定的镜像启动一个新的容器。这个命令可以包含许多选项,下面是一些常用的选项: -d:后台运行容器,并返回容器ID;-i:以交互模式运行容器,通常与 -t …

【C++】string类 (模拟实现详解 下)

我们接着上一篇【C】string类 (模拟实现详解 上)-CSDN博客继续对string模拟实现。从这篇内容开始,string相关函数的实现就要声明和定义分离了。 1.reserve、push_back和append 在string.h的string类里进行函数的声明。 void reserve(size_…

JVM(HotSpot):GC之垃圾回收器的分类

文章目录 前言一、串行二、吞吐量优先三、响应时间优先四、常见垃圾回收器使用组合 前言 上一篇,我们学习了分代回收机制 它的主要内容是对JVM内存的一个划分,以及垃圾回收器工作时,区域运作顺序的一个规定。 所以,它是一个规范。…

Spring Boot论坛网站:开发、部署与管理

3系统分析 3.1可行性分析 通过对本论坛网站实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本论坛网站采用SSM框架,JAVA作为开发语言,是…

智慧楼宇平台,构筑未来智慧城市的基石

随着城市化进程的加速,城市面临着前所未有的挑战。人口密度的增加、资源的紧张、环境的恶化以及对高效能源管理的需求,都在推动着我们寻找更加智能、可持续的城市解决方案。智慧楼宇作为智慧城市建设的重要组成部分,正逐渐成为推动城市可持续…

MATLAB电化学特性评估石墨和锂电

🎯要点 模拟对比石墨电池的放电电压曲线与实验数据定性差异。对比双箔、多相多孔电极理论和锂电有限体积模型实现。通过孔隙电极理论模型了解粗粒平均质量和电荷传输以及孔隙率的表征意义。锂电中锂离子正向和逆向反应速率与驱动力的指数以及电解质和电极表面的锂浓…

Docker 部署 EMQX 一分钟极速部署

部署 EMQX ( Docker ) [Step 1] : 拉取 EMQX 镜像 docker pull emqx/emqx:latest[Step 2] : 创建目录 ➡️ 创建容器 ➡️ 拷贝文件 ➡️ 授权文件 ➡️ 删除容器 # 创建目录 mkdir -p /data/emqx/{etc,data,log}# 创建容器 docker run -d --name emqx -p 1883:1883 -p 1808…

「Qt Widget中文示例指南」如何实现半透明背景?

Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写,所有平台无差别运行,更提供了几乎所有开发过程中需要用到的工具。如今,Qt已被运用于超过70个行业、数千家企业,支持数百万设备及应用。 本文将为大家展示如…

《Linux从小白到高手》综合应用篇:深入理解Linux常用关键内核参数及其调优

1. 题记 有关Linux关键内核参数的调整,我前面的调优文章其实就有涉及到,只是比较零散,本篇集中深入介绍Linux常用关键内核参数及其调优,Linux调优80%以上都涉及到内核的这些参数的调整。 2. 文件系统相关参数 fs.file-max 参数…

Excel 对数据进行脱敏

身份证号脱敏:LEFT(A2,6)&REPT("*",6)&RIGHT(A2,6) 手机号脱敏:LEFT(B2,3)&REPT("*",5)&RIGHT(B2,3) 姓名脱敏:LEFT(C2,1)&REPT("*",1)&RIGHT(C2,1) 参考: excel匹配替换…

STM32F103C8T6 IO 操作

1.开启相关时钟 在 STM32 微控制器中,开启 GPIO 端口的时钟是确保 IO 口可以正常工作的第一步。 查找 RCC 寄存器使能时钟 在 STM32 中,时钟控制的寄存器通常位于 RCC (Reset and Clock Control) 模块中。不同的 STM32 系列(如 STM32F1、STM…

【单元测试】深入解剖单元测试的思维逻辑

目录 一、前言二、准备环境三、 常用的mock语句3.1 模拟指定类的对象实例,用于模拟依赖对象(类成员)3.2 定义被测试对象3.3 模拟枚举类型/静态方法3.4 模拟依赖方法3.5 模拟构造方法3.6 验证方法调用次数3.7 验证返回值3.8 验证异常对象 四、…