【深度学习笔记】6_7 门控循环单元(GRU)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.7 门控循环单元(GRU)

上一节介绍了循环神经网络中的梯度计算方法。我们发现,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。通常由于这个原因,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。

门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系。它通过可以学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络 [1, 2]。另一种常用的门控循环神经网络则将在下一节中介绍。

6.7.1 门控循环单元

下面将介绍门控循环单元的设计。它引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。

6.7.1.1 重置门和更新门

如图6.4所示,门控循环单元中的重置门和更新门的输入均为当前时间步输入 X t \boldsymbol{X}_t Xt与上一时间步隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1,输出由激活函数为sigmoid函数的全连接层计算得到。

在这里插入图片描述

图6.4 门控循环单元中重置门和更新门的计算

具体来说,假设隐藏单元个数为 h h h,给定时间步 t t t的小批量输入 X t ∈ R n × d \boldsymbol{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本数为 n n n,输入个数为 d d d)和上一时间步隐藏状态 H t − 1 ∈ R n × h \boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h。重置门 R t ∈ R n × h \boldsymbol{R}_t \in \mathbb{R}^{n \times h} RtRn×h和更新门 Z t ∈ R n × h \boldsymbol{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h的计算如下:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , \begin{aligned} \boldsymbol{R}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hr} + \boldsymbol{b}_r),\\ \boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hz} + \boldsymbol{b}_z), \end{aligned} Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \boldsymbol{W}_{xr}, \boldsymbol{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \boldsymbol{W}_{hr}, \boldsymbol{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h \boldsymbol{b}_r, \boldsymbol{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h是偏差参数。3.8节(多层感知机)节中介绍过,sigmoid函数可以将元素的值变换到0和1之间。因此,重置门 R t \boldsymbol{R}_t Rt和更新门 Z t \boldsymbol{Z}_t Zt中每个元素的值域都是 [ 0 , 1 ] [0, 1] [0,1]

6.7.1.2 候选隐藏状态

接下来,门控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。如图6.5所示,我们将当前时间步重置门的输出与上一时间步隐藏状态做按元素乘法(符号为 ⊙ \odot )。如果重置门中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上一时间步的隐藏状态。如果元素值接近1,那么表示保留上一时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输入连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为 [ − 1 , 1 ] [-1, 1] [1,1]

在这里插入图片描述

图6.5 门控循环单元中候选隐藏状态的计算

具体来说,时间步 t t t的候选隐藏状态 H ~ t ∈ R n × h \tilde{\boldsymbol{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h的计算为

H ~ t = tanh ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\boldsymbol{H}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xh} + \left(\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1}\right) \boldsymbol{W}_{hh} + \boldsymbol{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),

其中 W x h ∈ R d × h \boldsymbol{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是权重参数, b h ∈ R 1 × h \boldsymbol{b}_h \in \mathbb{R}^{1 \times h} bhR1×h是偏差参数。从上面这个公式可以看出,重置门控制了上一时间步的隐藏状态如何流入当前时间步的候选隐藏状态。而上一时间步的隐藏状态可能包含了时间序列截至上一时间步的全部历史信息。因此,重置门可以用来丢弃与预测无关的历史信息。

6.7.1.3 隐藏状态

最后,时间步 t t t的隐藏状态 H t ∈ R n × h \boldsymbol{H}_t \in \mathbb{R}^{n \times h} HtRn×h的计算使用当前时间步的更新门 Z t \boldsymbol{Z}_t Zt来对上一时间步的隐藏状态 H t − 1 \boldsymbol{H}_{t-1} Ht1和当前时间步的候选隐藏状态 H ~ t \tilde{\boldsymbol{H}}_t H~t做组合:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \boldsymbol{H}_t = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1 - \boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t. Ht=ZtHt1+(1Zt)H~t.

在这里插入图片描述

图6.6 门控循环单元中隐藏状态的计算

值得注意的是,更新门可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如图6.6所示。假设更新门在时间步 t ′ t' t t t t t ′ < t t' < t t<t)之间一直近似1。那么,在时间步 t ′ t' t t t t之间的输入信息几乎没有流入时间步 t t t的隐藏状态 H t \boldsymbol{H}_t Ht。实际上,这可以看作是较早时刻的隐藏状态 H t ′ − 1 \boldsymbol{H}_{t'-1} Ht1一直通过时间保存并传递至当前时间步 t t t。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

我们对门控循环单元的设计稍作总结:

  • 重置门有助于捕捉时间序列里短期的依赖关系;
  • 更新门有助于捕捉时间序列里长期的依赖关系。

6.7.2 读取数据集

为了实现并展示门控循环单元,下面依然使用周杰伦歌词数据集来训练模型作词。这里除门控循环单元以外的实现已在6.2节(循环神经网络)中介绍过。以下为读取数据集部分。

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.7.3 从零开始实现

我们先介绍如何从零开始实现门控循环单元。

6.7.3.1 初始化模型参数

下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数。

num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)def get_params():def _one(shape):ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)return torch.nn.Parameter(ts, requires_grad=True)def _three():return (_one((num_inputs, num_hiddens)),_one((num_hiddens, num_hiddens)),torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))W_xz, W_hz, b_z = _three()  # 更新门参数W_xr, W_hr, b_r = _three()  # 重置门参数W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数# 输出层参数W_hq = _one((num_hiddens, num_outputs))b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])

6.7.3.2 定义模型

下面的代码定义隐藏状态初始化函数init_gru_state。同6.4节(循环神经网络的从零开始实现)中定义的init_rnn_state函数一样,它返回由一个形状为(批量大小, 隐藏单元个数)的值为0的Tensor组成的元组。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

下面根据门控循环单元的计算表达式定义模型。

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)H_tilda = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = torch.matmul(H, W_hq) + b_qoutputs.append(Y)return outputs, (H,)

6.7.3.3 训练模型并创作歌词

我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,vocab_size, device, corpus_indices, idx_to_char,char_to_idx, False, num_epochs, num_steps, lr,clipping_theta, batch_size, pred_period, pred_len,prefixes)

输出:

epoch 40, perplexity 149.477598, time 1.08 sec- 分开 我不不你 我想你你的爱我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你- 不分开 我想你你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我 你不你的让我
epoch 80, perplexity 31.689210, time 1.10 sec- 分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不- 不分开 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
epoch 120, perplexity 4.866115, time 1.08 sec- 分开 我想要这样牵着你的手不放开 爱过 让我来的肩膀 一起好酒 你来了这节秋 后知后觉 我该好好生活 我- 不分开 你已经不了我不要 我不要再想你 我不要再想你 我不要再想你 不知不觉 我跟了这节奏 后知后觉 又过
epoch 160, perplexity 1.442282, time 1.51 sec- 分开 我一定好生忧 唱着歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的- 不分开 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活

6.7.4 简洁实现

在PyTorch中我们直接调用nn模块中的GRU类即可。

lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

输出:

epoch 40, perplexity 1.022157, time 1.02 sec- 分开手牵手 一步两步三步四步望著天 看星星 一颗两颗三颗四颗 连成线背著背默默许下心愿 看远方的星是否听- 不分开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情走的太快就像龙卷风 不能承受我已无处
epoch 80, perplexity 1.014535, time 1.04 sec- 分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路- 不分开始爱像  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好
epoch 120, perplexity 1.147843, time 1.04 sec- 分开都靠我 你拿着球不投 又不会掩护我 选你这种队友 瞎透了我 说你说 分数怎么停留 所有回忆对着我进攻- 不分开球我有多烦恼多 牧草有没有危险 一场梦 我面对我 甩开球我满腔的怒火 我想揍你已经很久 别想躲 说你
epoch 160, perplexity 1.018370, time 1.05 sec- 分开爱上你 那场悲剧 是你完美演出的一场戏 宁愿心碎哭泣 再狠狠忘记 你爱过我的证据 让晶莹的泪滴 闪烁- 不分开始 担心今天的你过得好不好 整个画面是你 想你想的睡不著 嘴嘟嘟那可爱的模样 还有在你身上香香的味道

小结

  • 门控循环神经网络可以更好地捕捉时间序列中时间步距离较大的依赖关系。
  • 门控循环单元引入了门的概念,从而修改了循环神经网络中隐藏状态的计算方式。它包括重置门、更新门、候选隐藏状态和隐藏状态。
  • 重置门有助于捕捉时间序列里短期的依赖关系。
  • 更新门有助于捕捉时间序列里长期的依赖关系。

参考文献

[1] Cho, K., Van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.

[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/732450.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Objective -- C】—— 自引用计数

【Objective -- C】—— 自引用计数 一. 内存管理/自引用计数1.自引用计数2.内存管理的思考方式自己生成的对象&#xff0c;自己持有非自己生成的对象&#xff0c;自己也能持有不再需要自己持有的对象时释放无法释放非自己持有的对象 3.alloc/retain/release/dealloc实现4. aut…

全链路Python环境迁移

全链路Python环境迁移 在当前的Python环境中&#xff0c;安装一些库以后&#xff0c;如果换了一套Python环境&#xff0c;难道再来一次不停的pip install&#xff1f;当然不是。 第一步&#xff0c;使用pip freeze&#xff08;冻结&#xff09;备份当前Python库的环境 pip f…

智慧公厕系统的运作过程

智慧公厕是一种新型的未来城市公共厕所&#xff0c;通过物联网、互联网、大数据、云计算、自动化控制等技术&#xff0c;实现公共厕所使用、运营、管理、养护的全过程全方位信息化。 那么&#xff0c;智慧公厕是如何运作的&#xff1f;智慧公厕的运作过程包括什么技术&#xf…

【Pytorch、torchvision、CUDA 各个版本对应关系以及安装指令】

Pytorch、torchvision、CUDA 各个版本对应关系以及安装指令 1、名词解释 1.1 CUDA CUDA&#xff08;Compute Unified Device Architecture&#xff09;是由NVIDIA开发的用于并行计算的平台和编程模型。CUDA旨在利用NVIDIA GPU&#xff08;图形处理单元&#xff09;的强大计算…

使用R语言进行聚类分析

一、样本数据描述 城镇居民人均消费支出水平包括食品、衣着、居住、生活用品及服务、通信、文教娱乐、医疗保健和其他用品及服务支出这八项指标来描述。表中列出了2016年我国分地区的城镇居民的人均消费支出的原始数据&#xff0c;数据来源于2017年的《中国统计年鉴》&#xf…

Publii和GitHub:搭建个人网站的完美组合

在数字时代&#xff0c;拥有一个个人网站已经非常普遍了&#xff0c;但是&#xff0c;很多人因为技术难题而望而却步。现在&#xff0c;有了Publii&#xff0c;这一切都将变得简单。Publii是一个静态网站生成器&#xff0c;它允许你在本地计算机上创建和管理内容&#xff0c;然…

ARM中汇编语言的学习(加法、乘法、除法、左移、右移、按位与等多种命令操作实例以及ARM的 N、Z、C、V 标志位的解释)

汇编概述 汇编需要学习的大致框架如下&#xff1a; 汇编中的符号 1.指令&#xff1b;能够北嘁肷梢惶?2bit机器码&#xff0c;并且能够被cpui识别和执行 2.伪指令&#xff1a;本身不是指令&#xff0c;编译器可以将其替换成若干条指令 3.伪操作&#xff1a;不会生成指令…

如何修复advapi32.dll丢失无法启动程序的问题

如果你在运行Windows程序时遇到了“advapi32.dll丢失无法启动程序”的错误消息&#xff0c;那么这意味着你的计算机缺少这个DLL文件。在本文中&#xff0c;我们将提供一些解决方案&#xff0c;帮助你解决这个问题并恢复计算机的正常运行。 一.advapi32.dll丢失电脑的提示 关于…

软件项目试运行方案

一、 试运行目的 &#xff08;一&#xff09; 系统功能、性能与稳定性考核 &#xff08;二&#xff09; 系统在各种环境和工况条件下的工作稳定性和可靠性 &#xff08;三&#xff09; 检验系统实际应用效果和应用功能的完善 &#xff08;四&#xff09; 健全系统运行管理体制&…

宏碁掠夺者:4K144Hz显示器,让你爽翻天

大家好&#xff0c;我又来了。 买了PS5后&#xff0c;我发现这样的主机放在客厅里可玩性不太高&#xff08;我没机会玩&#xff09;。 毕竟家里还有今年要上小学的孩子。 每天回家打卡交作业都让我发疯。 客厅里放一台PS5无疑是每天对孩子最大的影响&#xff08;也划破了我的心…

vue 下载的插件从哪里上传?npm发布插件详细记录

文章参考&#xff1a; 参考文章一&#xff1a; 封装vue插件并发布到npm详细步骤_vue-cli 封装插件-CSDN博客 参考文章二&#xff1a; npm发布vue插件步骤、组件、package、adduser、publish、getElementsByClassName、important、export、default、target、dest_export default…

智能驾驶规划控制理论学习08-自动驾驶控制模块(轨迹跟踪)

目录 一、基于几何的轨迹跟踪方法 1、基本思想 2、纯追踪 3、Stanly Method 二、PID控制器 三、LQR&#xff08;Linear Quadratic Regulator&#xff09; 1、基本思想 2、LQR解法 3、案例学习 基于LQR的路径跟踪 基于LQR的速度跟踪 4、MPC&#xff08;Mode…

day59 线程

创建线程的第二种方式 实现接口Runnable 重写run方法 创建线程的第三种方式 java.util.concurrent下的Callable重写call()方法 java.util.concurrent.FutureTask 创建线程类对象 获取返回值 线程的四种生命周期 线程的优先级1-10 default为5&#xff0c;优先级越高&#xff0c…

基于梯度统计学的渐变型亮缝识别算法

作者&#xff1a;翟天保Steven 版权声明&#xff1a;著作权归作者所有&#xff0c;商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处 一、场景痛点 在图像处理相关的实际工程中&#xff0c;会出现各式各样的现实复杂问题&#xff0c;有的是因为机械设计导致&#x…

【洛谷 P8668】[蓝桥杯 2018 省 B] 螺旋折线 题解(数学+平面几何)

[蓝桥杯 2018 省 B] 螺旋折线 题目描述 如图所示的螺旋折线经过平面上所有整点恰好一次。 对于整点 ( X , Y ) (X, Y) (X,Y)&#xff0c;我们定义它到原点的距离 dis ( X , Y ) \text{dis}(X, Y) dis(X,Y) 是从原点到 ( X , Y ) (X, Y) (X,Y) 的螺旋折线段的长度。 例如 …

蓝桥杯练习系统(算法训练)ALGO-981 过河马

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 问题描述 在那个过河卒逃过了马的控制以超级超级多的走法走到了终点之后&#xff0c;这匹马表示它不开心了……   于是&#xff0c…

Java教程:RabbitMq讲解与SpringBoot项目如何对接RabbitMq实现生产者与消费者

在往期文章中&#xff0c;我们讲了如何在Windows与Linux环境下安装RabbitMq服务&#xff0c;并访问Web管理端。 有很多同学其实并不知道RabbitMq是用来干嘛的&#xff0c;它起到一个什么作用&#xff0c;并且如何在常见的SpringBoot项目中集成mq并实现消息收发&#xff0c;本章…

Nginx实现高并发

注&#xff1a;文章是4年前在自己网站上写的&#xff0c;迁移过来了。现在看我之前写的这篇文章&#xff0c;描述得不是特别详细&#xff0c;但描述了Nginx的整体架构思想。如果对Nginx玩得透得或者想了解深入的&#xff0c;可以在网上找找其他的文章。 ......................…

day17_订单(结算,提交订单,支付页,立即购买,我的订单)

文章目录 订单模块1 结算1.1 需求说明1.2 获取用户地址1.2.1 UserAddress1.2.2 UserAddressController1.2.3 UserAddressService1.2.4 UserAddressMapper1.2.5 UserAddressMapper.xml 1.3 获取购物项数据1.3.1 CartController1.3.2 CartService1.3.3 openFeign接口定义 1.4 环境…

NIFI从Oracle11G同步数据到Mysql_亲测可用_解决数据重复_数据跟源表不一致的问题---大数据之Nifi工作笔记0065

首先来看一下整体的流程: 可以看到了用到了上面的这些处理器,然后我们主要看看,这里之前 同步的时候,总是出现重复的数据,奇怪. 比如源表中只有166条数据,但是同步过去以后变成了11万条数据了. ${db.table.name:equals(table1):or(${db.table.name:equals(table2)})} 可以看…