动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。
在这里插入图片描述
在这里插入图片描述

2.优点
  1. 解决梯度消失问题:通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比:在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):# 输入输出的维度大小num_inputs = num_outputs = vocab_size# 正态分布初始化权重def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)def three():return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重torch.zeros(num_hiddens, device=device))  # 偏置W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重b_q = torch.zeros(num_outputs, device=device)  # 输出偏置# 将所有参数附加到参数列表中params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)  # 设置参数需要梯度return params# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元# LSTM前向传播
def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = state  # 隐藏状态和记忆元outputs = []for X in inputs:# 输入门I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)# 遗忘门F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 输出门O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 候选记忆元C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 更新记忆元C = F * C + I * C_tilda# 更新隐藏状态H = O * torch.tanh(C)# 计算输出Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:

在这里插入图片描述
简洁实现:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869750.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【卡尔曼滤波器】DR_CAN 2 学习笔记:_数据融合_协方差矩阵_状态空间方程_观测器问题

【卡尔曼滤波器】2_数学基础_数据融合_协方差矩阵_状态空间方程_观测器问题 非常重要1 数据融合 data fusion 有俩秤,各自有自己的正态分布:俩秤是相互独立的:俩秤都不准,但标准差都符合正态分布 正态分布又叫做高斯分布 向左、向右 都是2, 标准差是2覆盖了68.4 %的可能:…

开启HIVE中分区表支持中文字段

进入hive表: use hive; #修改hive database编码 alter database hive default character set utf8; #修改table编码 alter table PARTITIONS default character set utf8; alter table PARTITION_KEY_VALS default character set utf8; alter table SDS default cha…

机器学习和AI智能写作对未来文案编辑的影响

欢迎关注小知:知孤云出岫 目录 机器学习和AI智能写作对未来文案编辑的影响1. 简介2. AI智能写作工具的现状3. AI智能写作的优势3.1 提高效率3.2 降低成本3.3 数据驱动的个性化 4. AI智能写作的挑战4.1 创造力和独创性4.2 道德和伦理问题4.3 技术限制 5. 行业变化5.…

【排序 - 插入排序 和 希尔排序】

插入排序(Insertion Sort)是一种简单直观的排序算法,它的工作原理是逐步构建有序序列。在排序过程中,它将未排序的元素逐个插入到已排序的部分中,从而在每次插入时扩展已排序序列的长度。 原理介绍 插入排序的基本思…

RAG实践:ES混合搜索BM25+kNN(cosine)

1 缘起 最近在研究与应用混合搜索, 存储介质为ES,ES作为大佬牌数据库, 非常友好地支持关键词检索和向量检索, 当然,支持混合检索(关键词检索向量检索), 是提升LLM响应质量RAG(Retri…

Ethercat这里的CRC是如何计算和使用的?

这里的CRC是如何计算和使用的? 非常好的问题。让我详细解释一下CRC在这个上下文中是如何计算和使用的: CRC 计算过程: a. 准备输入数据: 收集从开始标签到结束标签的所有字符(不包括带Crc32属性的开始标签)。将这些字…

Java 8革新:现代编程的全新标准与挑战

文章目录 一、方法引用二、接口默认方法三、接口静态方法四、集合遍历forEach()方法 一、方法引用 方法引用是Java 8中一种简化Lambda表达式的方式,通过直接引用现有方法来代替Lambda表达式。 方法引用使得代码更加简洁和易读,特别是在处理函数式接口时&…

Kotlin Class

普通 Class kotlin class TestClz {val name: String "hsw"val age: Int 18 }java public final class TestClz {NotNullprivate final String name "hsw";private final int age 18;NotNullpublic final String getName() {return this.name;}publ…

WPF-控件样式设置

1、控件样式设置 1.1、内嵌式为相同控件设置样式 <Window.Resources><Style TargetType"Button"><Setter Property"Background" Value"Yellow"></Setter><Setter Property"Width" Value"60"&g…

大数据专业创新人才培养体系的探索与实践

一、引言 随着大数据技术的迅猛发展&#xff0c;其在各行各业中的应用日益广泛&#xff0c;对大数据专业人才的需求也日益增长。我国高度重视大数据产业的发展&#xff0c;将大数据作为国家战略资源&#xff0c;推动大数据与各行业的深度融合。教育部也积极响应国家战略&#…

C语言 将密码译回原文

有一行电文,已按下面规律译成密码: A→Z a→z B→Y b→y C→X c→x … … 即第1个字母变成第26个字母,第i个字母变成第(26-i1)个字母,非字母字符不变。要求编程序将密码译回原文,并输出密码和原文。 #include <stdio.h> #include <ctype.h>void decrypt(c…

JVM:字节码文件

文章目录 一、Java虚拟机的组成二、字节码文件的组成1、基本信息2、常量池3、字段4、方法5、属性 三、常用的字节码工具1、javap -v 命令2、jclasslib插件3、阿里arthas 一、Java虚拟机的组成 二、字节码文件的组成 1、基本信息 魔数、字节码文件对应的Java版本号访问标识&am…

MySQL 日期和时间函数

NOW(): 返回当前的日期和时间。 SELECT NOW() AS current_datetime; -- 结果: 当前的日期和时间 CURDATE(): 返回当前日期。 SELECT CURDATE() AS current_date; -- 结果: 当前的日期 CURTIME(): 返回当前时间。 SELECT CURTIME() AS current_time; -- 结果: 当前的…

Docker 使用基础(2)—镜像

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;秒針を噛む—ずっと真夜中でいいのに。 0:34━━━━━━️&#x1f49f;──────── 4:20 &#x1f504; ◀️ ⏸ …

Vue组件通信props和$emit用法

父传子&#xff0c;通过props 子传父&#xff0c;通过$emit App.vue <template><div class"app" style"border: 3px solid #000; margin: 10px">我是APP组件<!-- 1.给组件标签&#xff0c;添加属性方式 赋值 --><!-- 添加属性传值 …

Oracle字符串类型常涉及的问题

lpad(字段,位数,‘0’) 输出指定位数字符串不足左补0 rpad(字段,位数,‘0’) 输出指定位数字符串不足右补0 to_char(字段,‘fm000000’) 输出5位字符串不足左补0&#xff0c;fm去除前面的空格 to_number(decode(trim(字段),null,‘0’,‘’,‘0’,trim(字段))) 字符串转为数值…

【java算法专场】双指针(下)

611. 有效三角形的个数 目录 611. 有效三角形的个数 算法思路 算法代码 LCR 179. 查找总价格为目标值的两个商品 算法思路 算法代码 HashSet 双指针 15. 三数之和 算法思路 算法代码 18. 四数之和 ​编辑算法思路 算法代码 611. 有效三角形的个数 算法思路 算法…

前端面试题(CSS篇六)

一、浏览器如何判断是否支持 webp 格式图片 &#xff08;1&#xff09;宽高判断法。通过创建image对象&#xff0c;将其src属性设置为webp格式的图片&#xff0c;然后在onload事件中获取图片的宽高&#xff0c;如果能够获取&#xff0c;则说明浏览器支持webp格式图片。如果不能…

leetcode-动态规划-01背包

一、二维数组 1、状态转移方程&#xff1a; 不放物品i&#xff1a;由dp[i - 1][j]推出&#xff0c;即背包容量为j&#xff0c;里面不放物品i的最大价值&#xff0c;此时dp[i][j]就是dp[i - 1][j]。(其实就是当物品i的重量大于背包j的重量时&#xff0c;物品i无法放进背包中&a…

IAR 编译优化等级详解

目录 1.编译时优化器何时介入 2.编译优化等级汇总 3.优化项解读 3.1 代码移动 3.2 函数内联 3.3 循环交换 3.4 循环展开 3.5 公用表达式消除 3.6 链接阶段的优化 4 小结 大家好&#xff0c;这里是快乐的肌肉。 最近在迁移工程到IAR编译器上&#xff0c;发现编译优化…