GRU模块:nn.GRU层

摘要: 

       如果需要深入理解GRU的话,内部实现的详细代码和计算公式就比较重要,中间的一些过程及中间变量的意义需要详细关注。只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等。所以,仔细研究这个模块的实现还是非常有必要的。以此类推,对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

       本文中介绍GRU内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即语句:self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。

       先从输入输出的角度看,即代码中的这一行:output, state = self.rnn(X) 。在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)# 嵌⼊层self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,dropout=dropout)def forward(self, X, *args):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X)# 在循环神经⽹络模型中,第⼀个轴对应于时间步X = X.permute(1, 0, 2)# 如果未提及状态,则默认为0output, state = self.rnn(X)# output的形状:(num_steps,batch_size,num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU模块的框图 

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three() # 更新⻔参数W_xr, W_hr, b_r = three() # 重置⻔参数W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)
# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

小结 

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835231.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[力扣题解]455. 分发饼干

题目&#xff1a;455. 分发饼干 思路 贪心法 代码 class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {int cookie_i, belly_i, result 0;// 满足的小孩数 : result// 胃口: g// 饼干: ssort(g.begin(), g.end());so…

QML及VTK配合构建类MVVM模式DEMO

1 创建QT QUICK项目 这次我们不在主程中加载VTK的几何&#xff1b; 在qml建立的控件&#xff0c;创建MyVtkObject类的单例&#xff0c;main中将指针和单例挂钩&#xff1b; 在MyVtkObject实例中操作 QQuickVTKRenderItem 类即可&#xff1b; 由于VTK的opengl显示是状态机&a…

零成本零门槛蓝海副业:一单利润高,适合新手操作!

零成本零门槛蓝海副业&#xff1a;一单利润高&#xff0c;适合新手操作&#xff01; 在这个信息化时代&#xff0c;随着互联网的发展&#xff0c;越来越多的人开始关注如何通过副业赚取额外的收入。而零成本、零门槛的副业项目更是备受关注&#xff0c;因为它们不仅容易上手&a…

学习笔记:【QC】Android Q telephony-data 模块

一、data init 流程图 主要分为3部分&#xff1a; 1.加载TelephonyProvider&#xff0c;解析apns-config.xml文件&#xff0c;调用loadApns将 xml中定义的数据&#xff0c;插入到TelephonyProvider底层的数据库中 2.初始化phone、DcTracker、TelephonyNetworkFactory、Conne…

Ubuntu意外断电vmdk损坏--打不开磁盘“***.vmdk”或它所依赖的某个快照磁盘。

背景&#xff1a;电脑资源管理器崩溃卡死&#xff0c;强行断电重启&#xff0c;结果虚拟机打不开了&#xff0c;提示打不开磁盘“***.vmdk”或它所依赖的某个快照磁盘。 删除lck文件&#xff1a;失败vmware-vdiskmanager修复 &#xff1a;提示无法修复最终用 VMFS Recovery挂载…

ntfs文件系统的优势 NTFS文件系统的特性有哪些 ntfs和fat32有什么区别 苹果电脑怎么管理硬盘

对于数码科技宅在新购得磁盘之后&#xff0c;出于某种原因会在新的磁盘安装操作系统。在安装操作系统时&#xff0c;首先要对磁盘进行分区和格式化&#xff0c;而在此过程中&#xff0c;操作者们需要选择文件系统。文件系统也决定了之后操作的流程程度&#xff0c;一般文件系统…

CSS-页面导航栏实现-每文一言(过有意义的生活,做最好的自己)

&#x1f390;每文一言 过有意义的生活,做最好的自己 目录 &#x1f390;每文一言 &#x1f6d2;盒子模型 &#x1f453;外间距 (margin) &#x1f97c;边框 &#x1f45c;内边距 切换盒子模型计算方案&#xff1a; &#x1f3a2; 浮动布局 浮动特点 &#x1f3c6;导航…

通俗的理解网关的概念的用途(四):什么是网关设备?(网络层面)

任何一台Windows XP操作系统之后的个人电脑、Linux操作系统电脑都可以简单的设置&#xff0c;就可以成为一台具备“网关”性质的设备&#xff0c;因为它们都直接内置了其中的实现程序。MacOS有没有就不知道&#xff0c;因为没用过。 简单的理解&#xff0c;就是运行了具备第二…

MATLAB雨刮通风空调模糊器和发电厂电力聚变器卷积神经

&#x1f3af;要点 &#x1f3af;状态估计和控制&#xff1a;&#x1f58a;线性卡尔曼滤波弹簧-质量-阻尼器系统状态估计&#xff1a;定义数学差分方程模型、从贝叶斯定律导出卡尔曼滤波器算法 | &#x1f58a;扩展卡尔曼滤波非线性角度测量追踪阻尼振动器运动&#xff1a;定义…

关于JAVA-JSP电子政务网实现参考论文(论文 + 源码)

【免费】关于JAVA-JSP电子政务网.zip资源-CSDN文库https://download.csdn.net/download/JW_559/89292355关于JAVA-JSP电子政务网 摘 要 当前阶段&#xff0c;伴随着社会信息技术的快速发展&#xff0c;使得电子政务能够成为我国政府职能部门进行办公管理的一个重要内容&#x…

【PyTorch单点知识】深入理解与应用转置卷积ConvTranspose2d模块

文章目录 0. 前言1. 转置卷积概述2. nn.ConvTranspose2d 模块详解2.1 主要参数2.2 属性与方法 3. 计算过程&#xff08;重点&#xff09;3.1 基本过程3.2 调整stride3.3 调整dilation3.4 调整padding3.5 调整output_padding 4. 应用实例5. 总结 0. 前言 按照国际惯例&#xff0…

Docker快速搭建NAS服务——FileBrowser

Docker快速搭建NAS服务——FileBrowser 文章目录 前言FileBrowser的搭建docker-compose文件编写运行及访问 总结 前言 本文主要讲解如何使用docker在本地快速搭建NAS服务&#xff0c;这里主要写如下两种&#xff1a; FileBrowser1&#xff1a;是一个开源的Web文件管理器&…

运行SpringBoot项目失败?异常显示Can‘t load IA 32-bit .dll on a AMD 64-bit platform,让我来看看~

原因是&#xff0c;我放入jdk的bin文件夹下的tcnative-1.dll文件是32位的&#xff0c;那么肯定是无法在AMD 64位平台上加载IA 32位.dll。但是网站上给出的都是32位呀&#xff0c;没有64位怎么办&#xff1a; 其实当我们把“tomcat-native-1.2.34-openssl-1.1.1o-win32-bin.zip”…

机器学习-如何为模型选择评估指标?

为机器学习模型选择评估指标是一个关键步骤&#xff0c;因为它直接关联到如何衡量模型的性能。以下是选择评估指标的一些建议&#xff1a; 1、理解问题类型&#xff1a; 分类问题&#xff1a;对于二分类问题&#xff0c;常见的评估指标包括准确率、精确率、召回率、F1分数、R…

对多重继承关系的父子抽象类中子类的方法进行测试时如何回避Mock父类中的Protected方法

标题的说法就比较绕口&#xff0c;但是这个具体的问题大家看了下面内容就明白了。 如果在自己工作中遇到类似问题时可以试试这个解决办法。如果您技术好的话&#xff0c;其实不仔细看也行的&#xff0c;哈哈。 假设你有以下的类结构&#xff0c;该如何使用junit5,cdi-unit,moc…

以无侵方式实现Deployment原地升级

如何以无侵方式实现Deployment原地升级&#xff1f; 本文将展示如何以无侵、原生的方式实现Deployment原地升级。 在文章末尾会提供shell脚本供大家参考。 本文的原地升级仅指镜像更新 本篇kubernetes版本为v1.27.3。 原地升级的概念以及OpenKruise的实现方式可以参考文章&a…

Linux下GraspNet复现流程

Linux&#xff0c;Ubuntu中GraspNet复现流程 文章目录 Linux&#xff0c;Ubuntu中GraspNet复现流程1.安装cuda和cudnn2.安装pytorch3.编译graspnetAPIReference &#x1f680;非常重要的环境配置&#x1f680; ubuntu 20.04cuda 11.0.1cudnn v8.9.7python 3.8.19pytorch 1.7.0…

十大排序算法(java实现)

注&#xff1a;本篇仅用来自己学习&#xff0c;大量内容来自菜鸟教程&#xff08;地址&#xff1a;1.0 十大经典排序算法 | 菜鸟教程&#xff09; 排序算法可以分为内部排序和外部排序&#xff0c;内部排序是数据记录在内存中进行排序&#xff0c;而外部排序是因排序的数据很大…

Microsoft Edge浏览器,便携增强版 v118.0.5993.69

01 软件介绍 Microsoft Edge浏览器&#xff0c;便携增强版&#xff0c;旨在无需更新组件的情况下提供额外的功能强化。这一增强版专注于优化用户体验和系统兼容性&#xff0c;具体包含以下核心功能的提升&#xff1a; 数据保存&#xff1a;通过优化算法增强了其数据保存能力&…

1707jsp电影视频网站系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 校园商城派送系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统采用web模式&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数…