机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/690654.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯嵌入式STM32G431RBT6知识点(主观题部分)

目录 1 前置准备 1.1 Keil 1.1.1 编译器版本及微库 1.1.2 添加官方提供的LCD及I2C文件 1.2 CubeMX 1.2.1 时钟树 1.2.2 其他 1.2.3 明确CubeMX路径,放置芯片包 2 GPIO 2.1 实验1:LED1-LED8循环亮灭 ​编辑 2.2 实验2&#xff1a…

Gitlab CI/CD docker命令报错:/usr/bin/bash: line 136: docker:command not found

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

深入实战:ElasticSearch的Rest API与迭代器模式在高效查询中的应用

在我们公司,大多数Java开发工程师在项目中都有使用Elasticsearch的经验。通常,他们会通过引入第三方工具包或使用Elasticsearch Client等方式来进行数据查询。然而,当涉及到基于Elasticsearch Rest API的/_sql?formatjson接口时,…

2 物理层(三):数据传输的方式,同步传输和异步传输

目录 1 数据的传输方式1.1 并行传输1.2 串行传输 2 同步传输和异步传输2.1 同步传输2.2 异步传输2.3 同步和异步传输对比 1 数据的传输方式 在数据通信中,数据传输方式有并行传输和串行传输两种 1.1 并行传输 定义:并行传输是指数据以成组的方式在多个…

NC 输出模板自定义变量使用加减乘除余等公式计算时无法显示结果的问题处理办法

NC 输出模板自定义变量使用加减乘除余等公式计算时无法显示结果的问题处理办法 比如,求两个字段的差,如果这样写,模板打印输出的时候,是不会显示有值的: sub(vouchercreditamount, voucherdebitamount) 或者 voucherc…

picker选择器-年月日选择

从底部弹起的滚动选择器。支持五种选择器,通过mode来区分,分别是普通选择器,多列选择器,时间选择器,日期选择器,省市区选择器,默认是普通选择器。 学习一下日期选择器 平台差异说明 日期选择默…

K8s进阶之路-控制器无状态服务:

RC/RS/Deployment 控制器 deployment无状态(最常用): nginx和Apache statefulset有状态: mysql和redis damonset初始化 job一次性任务 cronjob任务计划 1无状态:不会对本地环境产生依赖如:nginx和Apache …

Kubernetes基础(二十二)-k8s持久化存储详解

1 volume 1.1 介绍 在容器中的磁盘文件是短暂的,当容器崩溃时,Kubelet会重新启动容器,但容器运行时产生的数据文件都将会丢失,之后容器会以最干净的状态启动。另外,当一个Pod运行多个容器时,各个容器可能…

新版Java面试专题视频教程——框架篇

新版Java面试专题视频教程——框架篇 框架篇 01-框架篇介绍02-Spring-单例bean是线程安全的吗03-Spring-AOP相关面试题04-Spring-事务失效的场景05-Spring-bean的生命周期5.1 BeanDefinition 06-Spring-bean的循环依赖(循环引用)6.1 一般对象的循环依…

【C++】类与对象的项目实践 — 日期管理工具

类与对象的实践 项目背景项目需求项目实现1 日期结构设计2 构造函数2.1 全缺省构造函数2.2 拷贝构造函数2.3 析构函数 3 赋值运算符重载3.1 重载3.2 重载重载前置 和 后置 4 关系操作符重载5 工具方法5.1 计算日期差5.2 日期转换为字符串5.3 通过字符串构建对象 完整源代码Dat…

云数贸云生活中心:用云生活理念引领社会和谐发展

在数字经济的浪潮下,云数贸云生活中心不仅在科技进步与文明程度上作出了积极贡献,更在推动社会和谐、承担企业社会责任方面展现出了模范作用。通过与“草根互助爱心社区”的紧密合作,云数贸云生活中心正致力于构建一个更加和谐、互助的社会环…

socket通信 smallchat简介

文章目录 前言一、socket的基本操作(1) socket()函数(2) bind()函数(3) listen()、connect()函数(4) accept()函数(5) read()、write()等函数(6) close()函数 二、smallchat代码流程smallchat-server.csmallchat-client.cchatlib.c 参考资料 前言 本文介绍了socket通信的相关A…

六、图像的几何变换

文章目录 前言一、镜像变换二、缩放变换 前言 在计算机视觉中,图像几何变换是指对图像进行平移、旋转、缩放、仿射变换和镜像变换等操作,以改变图像的位置、尺寸、形状或视角,而不改变图像的内容。这些变换在图像处理、模式识别、机器人视觉…

更改WordPress作者存档链接author和用户名插件Change Author Link Structure

WordPress作者存档链接默认情况为/author/Administrator(用户名),为了防止用户名泄露,我们可以将其改为/author/1(用户ID),具体操作可参考『如何将WordPress作者存档链接中的用户名改为昵称或ID…

猪圈Pigsty-PG私有RDS集群搭建教程

博客 https://songxwn.com/Pigsty-PG-RDS/ 简介 Pigsty 是一个更好的本地自建且开源 RDS for PostgreSQL 替代,具有以下特点: 开箱即用的 PostgreSQL 发行版,深度整合地理、时序、分布式、图、向量、分词、AI等 150 余个扩展插件&#xff…

文件IO的lseek以及目录IO

文件IO之 lseek: 1. lseek off_t lseek(int fd, off_t offset, int whence); 功能: 重新设定文件描述符的偏移量 参数: fd:文件描述符 offset:偏移量 whence: SEEK_SET 文件开头 …

基于scrapy框架的单机爬虫与分布式爬虫

我们知道,对于scrapy框架来说,不仅可以单机构建复杂的爬虫项目,还可以通过简单的修改,将单机版爬虫改为分布式的,大大提高爬取效率。下面我就以一个简单的爬虫案例,介绍一下如何构建一个单机版的爬虫&#…

更快找到远程/自由工作的网站

不要使用Fiver或Upwork。 它们已经饱和了。 下面是10个更快找到远程/自由工作的网站: 1. Toptal 这个网站专门为熟练的自由职业者提供远程工作机会,如Shopify和Priceline等一流公司。 他们只接受软件开发、设计和金融等领域的顶级3%自由职业者。 htt…

2024-02-19(Flume)

1.flume中拦截器的作用:个人认为就是修改或者删除事件中的信息(处理一下事件)。 2.一些拦截器 Host Interceptor,Timestamp Interceptor,Static Interceptor,UUID Interceptor,Search and Rep…

C++集群聊天服务器 nginx+redis安装 笔记 (中)

一、nginx安装 nginx: download 下载nginx安装包 hehedalinux:~/package$ tar -zvxf nginx-1.24.0.tar.gz nginx-1.24.0/ nginx-1.24.0/auto/ nginx-1.24.0/conf/ nginx-1.24.0/contrib/ nginx-1.24.0/src/ nginx-1.24.0/configure nginx-1.24.0/LICENSE nginx-1.24.0/README…