动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。
在这里插入图片描述
在这里插入图片描述

2.优点
  1. 解决梯度消失问题:通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比:在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):# 输入输出的维度大小num_inputs = num_outputs = vocab_size# 正态分布初始化权重def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)def three():return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重torch.zeros(num_hiddens, device=device))  # 偏置W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重b_q = torch.zeros(num_outputs, device=device)  # 输出偏置# 将所有参数附加到参数列表中params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)  # 设置参数需要梯度return params# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元# LSTM前向传播
def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = state  # 隐藏状态和记忆元outputs = []for X in inputs:# 输入门I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)# 遗忘门F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 输出门O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 候选记忆元C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 更新记忆元C = F * C + I * C_tilda# 更新隐藏状态H = O * torch.tanh(C)# 计算输出Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:

在这里插入图片描述
简洁实现:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869750.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【卡尔曼滤波器】DR_CAN 2 学习笔记:_数据融合_协方差矩阵_状态空间方程_观测器问题

【卡尔曼滤波器】2_数学基础_数据融合_协方差矩阵_状态空间方程_观测器问题 非常重要1 数据融合 data fusion 有俩秤,各自有自己的正态分布:俩秤是相互独立的:俩秤都不准,但标准差都符合正态分布 正态分布又叫做高斯分布 向左、向右 都是2, 标准差是2覆盖了68.4 %的可能:…

开启HIVE中分区表支持中文字段

进入hive表: use hive; #修改hive database编码 alter database hive default character set utf8; #修改table编码 alter table PARTITIONS default character set utf8; alter table PARTITION_KEY_VALS default character set utf8; alter table SDS default cha…

机器学习和AI智能写作对未来文案编辑的影响

欢迎关注小知:知孤云出岫 目录 机器学习和AI智能写作对未来文案编辑的影响1. 简介2. AI智能写作工具的现状3. AI智能写作的优势3.1 提高效率3.2 降低成本3.3 数据驱动的个性化 4. AI智能写作的挑战4.1 创造力和独创性4.2 道德和伦理问题4.3 技术限制 5. 行业变化5.…

【排序 - 插入排序 和 希尔排序】

插入排序(Insertion Sort)是一种简单直观的排序算法,它的工作原理是逐步构建有序序列。在排序过程中,它将未排序的元素逐个插入到已排序的部分中,从而在每次插入时扩展已排序序列的长度。 原理介绍 插入排序的基本思…

RAG实践:ES混合搜索BM25+kNN(cosine)

1 缘起 最近在研究与应用混合搜索, 存储介质为ES,ES作为大佬牌数据库, 非常友好地支持关键词检索和向量检索, 当然,支持混合检索(关键词检索向量检索), 是提升LLM响应质量RAG(Retri…

Java 8革新:现代编程的全新标准与挑战

文章目录 一、方法引用二、接口默认方法三、接口静态方法四、集合遍历forEach()方法 一、方法引用 方法引用是Java 8中一种简化Lambda表达式的方式,通过直接引用现有方法来代替Lambda表达式。 方法引用使得代码更加简洁和易读,特别是在处理函数式接口时&…

大数据专业创新人才培养体系的探索与实践

一、引言 随着大数据技术的迅猛发展,其在各行各业中的应用日益广泛,对大数据专业人才的需求也日益增长。我国高度重视大数据产业的发展,将大数据作为国家战略资源,推动大数据与各行业的深度融合。教育部也积极响应国家战略&#…

JVM:字节码文件

文章目录 一、Java虚拟机的组成二、字节码文件的组成1、基本信息2、常量池3、字段4、方法5、属性 三、常用的字节码工具1、javap -v 命令2、jclasslib插件3、阿里arthas 一、Java虚拟机的组成 二、字节码文件的组成 1、基本信息 魔数、字节码文件对应的Java版本号访问标识&am…

Docker 使用基础(2)—镜像

🎬慕斯主页:修仙—别有洞天 ♈️今日夜电波:秒針を噛む—ずっと真夜中でいいのに。 0:34━━━━━━️💟──────── 4:20 🔄 ◀️ ⏸ …

Vue组件通信props和$emit用法

父传子&#xff0c;通过props 子传父&#xff0c;通过$emit App.vue <template><div class"app" style"border: 3px solid #000; margin: 10px">我是APP组件<!-- 1.给组件标签&#xff0c;添加属性方式 赋值 --><!-- 添加属性传值 …

【java算法专场】双指针(下)

611. 有效三角形的个数 目录 611. 有效三角形的个数 算法思路 算法代码 LCR 179. 查找总价格为目标值的两个商品 算法思路 算法代码 HashSet 双指针 15. 三数之和 算法思路 算法代码 18. 四数之和 ​编辑算法思路 算法代码 611. 有效三角形的个数 算法思路 算法…

前端面试题(CSS篇六)

一、浏览器如何判断是否支持 webp 格式图片 &#xff08;1&#xff09;宽高判断法。通过创建image对象&#xff0c;将其src属性设置为webp格式的图片&#xff0c;然后在onload事件中获取图片的宽高&#xff0c;如果能够获取&#xff0c;则说明浏览器支持webp格式图片。如果不能…

IAR 编译优化等级详解

目录 1.编译时优化器何时介入 2.编译优化等级汇总 3.优化项解读 3.1 代码移动 3.2 函数内联 3.3 循环交换 3.4 循环展开 3.5 公用表达式消除 3.6 链接阶段的优化 4 小结 大家好&#xff0c;这里是快乐的肌肉。 最近在迁移工程到IAR编译器上&#xff0c;发现编译优化…

AI赛道成功的“小”AI平台,都在做什么?

在深入了解30多家跨界拓展AI赛道业务的企业后&#xff0c;我们发现大家对目前的AI市场存在一定程度的误解&#xff1a;即认为在AI领域想要分一杯羹&#xff0c;只需要搞几个API&#xff0c;把大语言模型、绘画、视频、数字人等功能都放上去&#xff0c;可能就有机会占一席之地了…

递归 迷宫问题-java

1&#xff09;findWay方法是为了找出走出迷宫的路径&#xff0c;找到返回true&#xff0c;否则返回false 2&#xff09;&#xff08;i&#xff0c;j&#xff09;是老鼠的位置&#xff0c;初始化的位置为&#xff08;1&#xff0c;1&#xff09; 3&#xff09;因为是递归找路&am…

2024年网络监控软件排名|10大网络监控软件是哪些

网络安全&#xff0c;小到关系到企业的生死存亡&#xff0c;大到关系到国家的生死存亡。 因此网络安全刻不容缓&#xff0c;在这里推荐网络监控软件。 2024年这10款软件火爆监控市场。 1.安企神软件&#xff1a; 7天免费试用https://work.weixin.qq.com/ca/cawcde06a33907e6…

【Linux】一文看懂Linux静态库和动态库

文章目录 一、静态库&#xff08;Static Library&#xff09;二、动态库&#xff08;Dynamic Library&#xff09;三、静态库和动态库的比较四、静态库的制作与使用五、动态库的制作与使用六、如何区分链接的是动态库还是静态库 在Linux系统编程中&#xff0c;库是一组预先编写…

【全面讲解下Foxit Reader】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

3年经验的B端产品经理,应该是什么水平?

问你一个问题&#xff1a;你觉得3年经验的B端产品经理&#xff0c;应该是什么水平&#xff1f;很多朋友可能也没有仔细想过&#xff0c;自己3年后应该达到一个什么水平&#xff1f;能做什么体量的业务&#xff1f;要能拿多少薪资&#xff1f; 前几天和一个B端产品经理聊天&…

SQL之delete、truncate和drop区别

MySQL删除数据的方式都有哪些&#xff1f; 常用的三种删除方式&#xff1a;通过 delete、truncate、drop 关键字进行删除&#xff1b;这三种都可以用来删除数据&#xff0c;但场景不同。 一、从执行速度上来说 drop > truncate >> DELETE;二、从原理上讲 1、DELET…