动手学深度学习(Pytorch版)代码实践 -循环神经网络- 56门控循环单元(`GRU`)

56门控循环单元(GRU

我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。
  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。
  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示。

门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。

1.重置门和更新门
  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。

在这里插入图片描述

2.候选隐状态

在这里插入图片描述

3.隐状态

在这里插入图片描述

4.从零开始实现
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 定义批量大小和时间步数
batch_size, num_steps = 32, 35# 使用d2l库的load_data_time_machine函数加载数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):"""初始化GRU模型的参数。参数:vocab_size (int): 词汇表的大小。num_hiddens (int): 隐藏单元的数量。device (torch.device): 张量所在的设备。返回:list of torch.Tensor: 包含所有参数的列表。"""num_inputs = num_outputs = vocab_size  # 输入和输出的数量都等于词汇表大小def normal(shape):"""使用均值为0,标准差为0.01的正态分布初始化张量。参数: shape (tuple): 张量的形状。返回:torch.Tensor: 初始化后的张量。"""return torch.randn(size=shape, device=device) * 0.01def three():"""初始化GRU门的参数。返回:tuple of torch.Tensor: 包含门的权重和偏置的元组。"""return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()   # 更新门参数W_xr, W_hr, b_r = three()   # 重置门参数W_xh, W_hh, b_h = three()   # 候选隐藏状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 将所有参数收集到一个列表中params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params: # 启用所有参数的梯度计算param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):"""初始化GRU的隐藏状态。参数:batch_size (int): 批量大小。num_hiddens (int): 隐藏单元的数量。device (torch.device): 张量所在的设备。返回:tuple of torch.Tensor: 初始隐藏状态。"""return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):"""定义GRU的前向传播。参数:inputs (torch.Tensor): 输入数据。state (tuple of torch.Tensor): 隐藏状态。params (list of torch.Tensor): GRU的参数。返回:torch.Tensor: GRU的输出。tuple of torch.Tensor: 更新后的隐藏状态。"""W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state  # 获取隐藏状态outputs = []  # 存储输出的列表for X in inputs:  # 遍历每一个输入时间步# 计算更新门ZZ = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)# 计算重置门RR = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)# 计算候选隐藏状态H_tildaH_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)# 更新隐藏状态HH = Z * H + (1 - Z) * H_tilda# 计算输出YY = H @ W_hq + b_qoutputs.append(Y)  # 将输出添加到列表中return torch.cat(outputs, dim=0), (H,)  # 返回连接后的输出和更新后的隐藏状态# 获取词汇表大小、隐藏单元数量和设备
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
# 定义训练的轮数和学习率
num_epochs, lr = 500, 1
# 初始化GRU模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
# 使用d2l库的train_ch8函数训练模型
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.1, 38557.3 tokens/sec on cuda:0
# time traveller for so it will be convenient to speak of himwas e

在这里插入图片描述

5.简洁实现
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 定义批量大小和时间步数
batch_size, num_steps = 32, 35
# 使用d2l库的load_data_time_machine函数加载数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)num_epochs, lr = 500, 1
# # 获取词汇表大小、隐藏单元数量和设备
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens) # 定义一个GRU层,输入大小为num_inputs,隐藏单元数量为num_hiddens
model = d2l.RNNModel(gru_layer, len(vocab)) # 使用GRU层和词汇表大小创建一个RNN模型
model = model.to(device)
# 该函数需要模型、训练数据迭代器、词汇表、学习率、训练轮数和设备作为参数
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 248342.8 tokens/sec on cuda:0
# time travelleryou can show black is white by argument said filby

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/42595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器人动力学模型及其线性化阻抗控制模型

机器人动力学模型 机器人动力学模型描述了机器人的运动与所受力和力矩之间的关系。这个模型考虑了机器人的质量、惯性、关节摩擦、重力等多种因素,用于预测和解释机器人在给定输入下的动态行为。动力学模型是设计机器人控制器的基础,它可以帮助我们理解…

html的浮动作用详解

HTML中的“浮动”(Float)是一个CSS布局技术,它原本设计用于文本环绕图像或实现简单的布局效果,比如并排排列元素。然而,随着Web开发的演进,浮动也被广泛用于更复杂的页面布局设计中,尽管现代CSS…

2024/7/7周报

文章目录 摘要Abstract文献阅读题目问题本文贡献问题描述图神经网络Framework实验数据集实验结果 深度学习MAGNN模型相关代码GNN为什么要用GNN?GNN面临挑战 总结 摘要 本周阅读了一篇用于多变量时间序列预测的多尺度自适应图神经网络的文章,多变量时间序…

SAP已下发EWM的交货单修改下发状态

此种情况针对EWM未接收到ERP交货单时,可以使用此程序将ERP交货单调整为未分配状态,在进行调整数据后,然后使用VL06I(启用自动下发EWM配置,则在交货单修改保存后会立即下发EWM)重新下发EWM系统。 操作步骤如…

3ds Max渲染曝光过度怎么办?

3dmax效果图云渲染平台——渲染100 以3ds Max 2025、VR 6.2、CR 11.2等最新版本为基础,兼容fp、acescg等常用插件,同时LUT滤镜等参数也得到了同步支持。 注册填邀请码【7788】可领30元礼包和免费渲染券哦~ 遇到3ds Max渲染过程中曝光过度的问题&#xf…

SLF4J的介绍与使用(有logback和log4j2的具体实现案例)

目录 1.日志门面的介绍 常见的日志门面 : 常见的日志实现: 日志门面和日志实现的关系: 2.SLF4J 的介绍 业务场景(问题): SLF4J的作用 SLF4J 的基本介绍 日志框架的绑定(重点&#xff09…

Influxdb中,Flux常用的函数

目录 一、Flux常用的函数及其简要描述 1. 数据源和筛选函数 2. 聚合函数 3. 时间序列操作函数 4. 转换和映射函数 5. 窗口函数 6. 其他常用函数 注意事项 二、使用方法举例 1. 数据源和筛选 2. 聚合 3. 时间序列操作 4. 窗口函数 5. 转换和映射 注意事项 三、…

跨越界限的温柔坚守

跨越界限的温柔坚守 —— 郑乃馨与男友的甜蜜抉择在这个光怪陆离、瞬息万变的娱乐圈里,每一段恋情像是夜空中划过的流星,璀璨短暂。然而,当“郑乃馨与男友甜蜜约会”的消息再次跃入公众视野,它不仅仅是一段简单的爱情故事&#xf…

iOS中多个tableView 嵌套滚动特性探索

嵌套滚动的机制 目前的结构是这样的,整个页面是一个大的tableView, Cell 是整个页面的大小,cell 中嵌套了一个tableView 通过测试我们发现滚动的时候,系统的机制是这样的, 我们滑动内部小的tableView, 开始滑动的时候&#xff0c…

C/C++ 代码注释规范及 doxygen 工具

参考 谷歌项目风格指南——注释 C doxygen 风格注释示例 ubuntu20 中 doxygen 文档生成 doxygen 官方文档 在 /Doxygen/Special Command/ 章节介绍 doxygen 的关键字 注释说明 注释的目的是提高代码的可读性与可维护性。 C 风格注释 // 单行注释/* 多行注释 */ C 风格注…

设置某些路由为公开访问,不需要登录状态即可访问

在单页面应用(SPA)框架中,如Vue.js,路由守卫是一种非常有用的功能,它允许你控制访问路由的权限。Vue.js 使用 Vue Router 作为其官方路由管理器。路由守卫主要分为全局守卫和组件内守卫。 以下是如何设置路由守卫以允…

k8s 部署RuoYi-Vue-Plus之mysql搭建

1.直接部署一个pod 需要挂载存储款, 可参考 之前文章设置 https://blog.csdn.net/weimeibuqieryu/article/details/140183843 2.部署yaml 先创建命名空间ruoyi kubectl create namespace ruoyi创建部署文件 mysql-deploy.yaml --- apiVersion: v1 kind: PersistentVolume …

【论文阅读笔记】Meta 3D AssetGen

【论文阅读笔记】Meta 3D AssetGen: Text-to-Mesh Generation with High-Quality Geometry, Texture, and PBR Materials Info摘要引言创新点 相关工作T23D基于图片的3d 重建使用 PBR 材料的 3D 建模。 方法文本到图像:从文本中生成阴影和反照率图像Image-to-3D:基于pbr的大型重…

搭建NEMU与QEMU的DiffTest环境(动态库方式)

搭建NEMU与QEMU的DiffTest环境(动态库方式) 1 DiffTest原理简述2 编译NEMU3 编译qemu-dl-difftest3.1 修改NEMU/scripts/isa.mk3.2 修改NEMU/tools/qemu-dl-diff/src/diff-test.c3.3 修改NEMU/scripts/build.mk3.4 让qemu-dl-difftest带调试信息3.5 编译…

C语言实现字符串排序

如果只有英文字符且不区分大小写的话按照字典序排序可以用strcmp函数&#xff0c;两个字符串自左向右逐个字符相比&#xff08;按ASCII值大小相比较&#xff09; strcmp(s1,s2) 当s1<s2时&#xff0c;返回为负数&#xff1b; 当s1s2时&#xff0c;返回值 0&#xff1b; …

安卓的组件

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

【Linux】打包命令——tar

打包和压缩 虽然打包和压缩都涉及将多个文件组合成单个实体&#xff0c;但它们之间存在重要差异。 打包和压缩的区别&#xff1a; 打包是将多个文件或目录组合在一起&#xff0c;但不对其进行压缩。这意味着打包后的文件大小可能与原始文件相同或更大。此外&#xff0c;打包…

Win10精英控制器2代青春版 设备删除失败,蓝牙连接断断续续

前提 更新了主板rog z790带WiFi、蓝牙&#xff0c;但是精英控制器连上老师断断续续。 过程 在设备管理中尝试了卸载、重装主板对应的蓝牙驱动&#xff0c;怎么都不行&#xff0c;都已经想放弃了。 但是想起来之前主板没有蓝牙&#xff0c;用的是绿联的USB蓝牙接收器&#xf…

Ubuntu24.04修改系统的环境变量

apache/tomcat配置要用到JDK&#xff0c;使用torch有时也会用到系统库&#xff0c;涉及到环境变量 1. 查看环境变量 cat /etc/environment2. 新建环境变量 sudo nano /etc/environment在文件底部添加新的环境变量 MY_VARIABLE"your_value"3. 修改环境变量 临时—…

数字化精益生产系统--APS 排程管理系统

APS&#xff08;Advanced Planning and Scheduling&#xff09;排程管理系统&#xff0c;即高级生产计划与排程系统&#xff0c;是一种高度智能化的计划和排程系统。它通过整合各种生产和供应链数据&#xff0c;运用先进的算法和数据模型&#xff0c;根据各种约束条件&#xff…