LSTM(长短期记忆网络)详解

1️⃣ LSTM介绍

标准的RNN存在梯度消失梯度爆炸问题,无法捕捉长期依赖关系。那么如何理解这个长期依赖关系呢?

例如,有一个语言模型基于先前的词来预测下一个词,我们有一句话 “the clouds are in the sky”,基于"the clouds are in the",预测"sky",在这样的场景中,预测的词和提供的信息之间位置间隔是非常小的,如下图所示,RNN可以捕捉到先前的信息。
在这里插入图片描述
然而,针对复杂场景,我们有一句话"I grew up in France… I speak fluent French","French"基于"France"推断,但是它们之间的间隔很远很远,RNN 会丧失学习到连接如此远信息的能力。这就是长期依赖关系。

为了解决该问题,LSTM通过引入三种门遗忘门输入门输出门控制信息的流入和流出,有助于保留长期依赖关系,并缓解梯度消失【注意:没有梯度爆炸昂】。LSTM在1997年被提出


2️⃣ 原理

下面这张图是标准的RNN结构:

  • x t x_t xt是t时刻的输入
  • s t s_t st是t时刻的隐层输出, s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(Uxt+Wst1),f表示激活函数, s t − 1 s_{t-1} st1表示t-1时刻的隐层输出
  • h t h_t ht是t时刻的输出, h t = s o f t m a x ( V ⋅ s t ) h_t=softmax(V\cdot s_t) ht=softmax(Vst)
    在这里插入图片描述

LSTM的整体结构如下图所示,第一眼看到,反正我是看不懂。前面讲到LSTM引入三种门遗忘门输入门输出门,现在我们逐一击破,一个个分析一下它们到底是什么。
在这里插入图片描述
这是3D视角的LSTM:
在这里插入图片描述
首先来看遗忘门,也就是下面这张图:

在这里插入图片描述

遗忘门输入包含两部分

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆(即隐层输出),在LSTM中当前时间步的输出 h t − 1 h_{t-1} ht1就是隐层输出 s t − 1 s_{t-1} st1
  • x t x_t xt:表示t时刻的输入

遗忘门输出为 f t f_t ft,公式表示为:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma\left(W_f\cdot[h_{t-1},x_t] + b_f\right) ft=σ(Wf[ht1,xt]+bf)
其中, W f W_f Wf b f b_f bf是遗忘门的参数, [ s t − 1 , x t ] [s_{t-1},x_t] [st1,xt]表示concat操作。 σ ( ) \sigma() σ()表示sigmoid函数。

遗忘门定我们会从长期记忆中丢弃什么信息【理解为:删除什么日记】,输出一个在 0 到 1 之间的数值,1 表示“完全保留”,0 表示“完全舍弃”。

然后来看输入门
在这里插入图片描述

输入门的输入包含两部分:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入

输入门的输出为新添加的内容 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,其具体操作为:
i t = σ ( W i ⋅ [ s t − 1 , x t ] + b i ) C ~ t = tanh ⁡ ( W C ⋅ [ s t − 1 , x t ] + b C ) \begin{aligned}i_{t}&=\sigma\left(W_i\cdot[s_{t-1},x_t] + b_i\right)\\\tilde{C}_{t}&=\tanh(W_C\cdot[s_{t-1},x_t] + b_C)\end{aligned} itC~t=σ(Wi[st1,xt]+bi)=tanh(WC[st1,xt]+bC)

输入门决定什么样的新信息被加入到长期记忆(即细胞状态)中【理解为:添加什么日记】。

然后,我们来更新长期记忆,将 C t − 1 C_{t-1} Ct1更新为 C t C_t Ct。我们把旧状态 C t − 1 C_{t-1} Ct1与遗忘门的输出 f t f_t ft相乘,忘记一些东西。接着加上输入门的输出 i t ∗ C ~ t i_t * \tilde{C}_t itC~t,新加一些东西,最终得到新的长期记忆 C t C_t Ct。具体操作为:
在这里插入图片描述

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t

最后来看输出门
在这里插入图片描述

输出门的输入包含:

  • s t − 1 s_{t-1} st1:表示t-1时刻的短期记忆
  • x t x_t xt:表示t时刻的输入
  • c t c_t ct:更新后的长期记忆

输出门的输出为 h t h_{t} ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1,其具体操作为:
o t = σ ( W o [ s t − 1 , x t ] + b o ) s t = h t = o t ∗ t a n h ( C t ) \begin{aligned}&o_{t}=\sigma\left(W_{o} \left[ s_{t-1},x_{t}\right] + b_{o}\right)\\&s_{t}=h_{t}=o_{t}*\mathrm{tanh}\left(C_{t}\right)\end{aligned} ot=σ(Wo[st1,xt]+bo)st=ht=ottanh(Ct)

首先,我们运行一个 sigmoid 层来确定长期记忆的哪个部分将输出出去。接着,我们把长期记忆通过 tanh 进行处理(得到一个在-1到1之间的值)并将它和 o t o_{t} ot相乘,最终将输出copy成两份 h t h_t ht s t s_{t} st h t h_t ht作为当前时间步的输出, s t s_{t} st当做短期记忆输入到t+1。

LSTM的结构分析完了,那为什么LSTM能够缓解梯度消失呢?

我前面写的这篇文章中介绍了为什么RNN会有梯度消失和爆炸:点这里查看

主要原因是反向传播时,梯度中有这一部分:
∏ j = k + 1 3 ∂ s j ∂ s j − 1 = ∏ j = k + 1 3 t a n h ′ W \prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3tanh^{'}W j=k+13sj1sj=j=k+13tanhW
LSTM的作用就是让 ∂ s j ∂ s j − 1 \frac{\partial s_j}{\partial s_{j-1}} sj1sj≈1

在LSTM里,隐藏层的输出换了个符号,从 s s s变成 C C C了,即 C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\tilde{C}_t Ct=ftCt1+itC~t。注意, f t f_t ft , i t 和 C ~ t i_{t\text{ 和}}\tilde{C}_t it C~t 都是 C t − 1 C_{t-1} Ct1的复合函数(因为它们都和 h t − 1 h_{t-1} ht1有关,而 h t − 1 h_{t-1} ht1又和 C t − 1 C_{t-1} Ct1有关)。因此我们来求一下 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct
∂ C t ∂ C t − 1 = f t + ∂ f t ∂ C t − 1 ⋅ C t − 1 + … \frac{\partial C_t}{\partial C_{t-1}}=f_t+\frac{\partial f_t}{\partial C_{t-1}}\cdot C_{t-1}+\ldots Ct1Ct=ft+Ct1ftCt1+

后面的我们就不管了,展开求导太麻烦了。这里面 f t f_t ft是遗忘门的输出,1表示完全保留旧状态,0表示完全舍弃旧状态,如果我们把 f t f_t ft设置成1或者是接近于1,那 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct就有梯度了。因此LSTM可以一定程度上缓解梯度消失,然而如果时间步很长的话,依然会存在梯度消失问题,所以只是缓解

注意:LSTM可以缓解梯度消失,但是梯度爆炸并不能解决,因为LSTM不影响参数W


3️⃣ 代码

# 创建一个LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LSTM(nn.Module):def __init__(self,input_size,hidden_size,num_layers,output_size):super().__init__()self.num_layers=num_layersself.hidden_size=hidden_size# 定义LSTM层# batch_first=True则输入形状为(batch, seq_len, input_size)self.lstm=nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)# 定义全连接层,用于输出self.fc=nn.Linear(hidden_size,output_size)def forward(self, x):# self.lstm(x)会返回两个值# out:形状为 (batch,seq_len,hidden_size)# 隐层状态和细胞状态:形状为 (batch, num_layers, hidden_size);在这里,我们忽略隐层状态和细胞状态的输出,因此使用了占位符out, _ = self.lstm(x)out = self.fc(out)return outif __name__=='__main__':input_size=10hidden_size=64num_layers=1output_size=1net=LSTM(input_size,hidden_size,num_layers,output_size)# x的形状为(batch_size, seq_len, input_size)x=torch.randn(16,8,input_size)out=net(x)print(out.shape)

输出结果为:

torch.Size([16, 8, 1]),表示有16个batch,对于每个batch,有8个时间步,每个时间步的output大小为1

4️⃣ 总结

  • 思考一个问题,对于多层LSTM,如何理解呢?
    在这里插入图片描述
    注意:图中颜色相同的其实表达的值一样, h = s h=s h=s

    1. 第一层 LSTM 首先初始隐层状态 s 0 l a y e r 1 s^{layer1}_0 s0layer1和细胞状态 c 0 l a y e r 1 c^{layer1}_0 c0layer1,然后输入 x t − 1 x_{t-1} xt1 生成隐层状态和输出 s t − 1 l a y e r 1 = h t − 1 l a r y e r 1 s^{layer1}_{t-1}=h_{t-1}^{laryer1} st1layer1=ht1laryer1和细胞状态 c t − 1 l a y e r 1 c^{layer1}_{t-1} ct1layer1
    2. 第二层 LSTM首先初始隐层状态 s 0 l a y e r 2 s^{layer2}_0 s0layer2和细胞状态 c 0 l a y e r 2 c^{layer2}_0 c0layer2,然后接收第一层的输出 h t − 1 l a r y e r 1 h_{t-1}^{laryer1} ht1laryer1作为输入,生成 s t − 1 l a y e r 2 = h t − 1 l a r y e r 2 s^{layer2}_{t-1}=h_{t-1}^{laryer2} st1layer2=ht1laryer2 c t − 1 l a y e r 2 c^{layer2}_{t-1} ct1layer2
    3. 第N层 LSTM首先初始隐层状态 s 0 l a y e r N s^{layerN}_0 s0layerN和细胞状态 c 0 l a y e r N c^{layerN}_0 c0layerN,然后接收第N-1层的输出 h t − 1 l a r y e r N − 1 h_{t-1}^{laryer N-1} ht1laryerN1作为输入,生成最终的 s t − 1 l a y e r N = h t − 1 l a r y e r 2 s^{layerN}_{t-1}=h_{t-1}^{laryer2} st1layerN=ht1laryer2 c t − 1 l a y e r N c^{layerN}_{t-1} ct1layerN
  • 为什么需要多层LSTM?
    多层 LSTM 通过增加深度来增强模型的表示能力和复杂度,能够学习到更高阶、更抽象的特征

  • 通过控制遗忘门的输出 f t f_t ft来控制梯度,以缓解梯度消失问题,但不能缓解梯度爆炸

5️⃣ 参考

  • 理解 LSTM 网络

  • 【LSTM长短期记忆网络】3D模型一目了然,带你领略算法背后的逻辑

  • 关于RNN的梯度消失&爆炸问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/886370.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vulnhub靶场 Billu_b0x 练习

目录 0x00 准备0x01 主机信息收集0x02 站点信息收集0x03 漏洞查找与利用1. 文件包含2. SQL注入3. 文件上传4. 反弹shell5. 提权(思路1:ssh)6. 提权(思路2:内核)7. 补充 0x04 总结 0x00 准备 下载链接&#…

重拾CSS,前端样式精读-媒体查询

前言 本文收录于CSS系列文章中,欢迎阅读指正 说到媒体查询,大家首先想到的可能是有关响应式的知识点,除此之外,它还可以用于条件加载资源,字体大小,图像和视频的优化,用户界面调整等等方面&am…

普通用户切换到 root 用户不需要输入密码配置(Ubuntu20)

在 Ubuntu 系统中,允许一个普通用户切换到 root 用户而不需要输入密码,可以通过以下步骤配置 sudo 设置来实现。 步骤: 打开 sudoers 文件进行编辑: 在终端中,输入以下命令来编辑 sudoers 文件: sudo visu…

MySQL系统优化

文章目录 MySQL系统优化第一章:引言第二章:MySQL服务架构优化1. 读写分离2. 水平分区与垂直分区3. 缓存策略 第三章:MySQL配置优化1. 内存分配优化Buffer Pool 的优化查询缓存与表缓存Key Buffer 2. 连接优化最大连接数会话超时连接池 3. 日志…

菲涅耳全息图

菲涅耳全息图:记录介质在物光波场的菲涅耳衍射区(物体到记录介质表面的距离在菲涅耳衍射区内)。 一、点源全息图的记录和再现 1.1 记录 设物光波和参考光波是从点源O(xo, yo, zo)和点源 R(xr, yr, zr)发出的球面波, 波长为λ1, 全息底片位于z0 的平面上, 与两个点源…

多线程-阻塞队列

目录 阻塞队列 消息队列 阻塞队列用于生产者消费者模型 概念 实现原理 生产者消费者主要优势 缺陷 阻塞队列的实现 1.写一个普通队列 2.加上线程安全和阻塞等待 3.解决代码中的问题 阻塞队列 阻塞队列,是带有线程安全功能的队列,拥有队列先进…

css样式:flex布局

文章目录 简介简单使用直接使用一行放不下的换行水平方向上对齐方式竖直方向上对齐方式布局中排列顺序放大比例缩小比例单个元素与其他元素不同的对齐 文章目录 简介简单使用直接使用一行放不下的换行水平方向上对齐方式竖直方向上对齐方式布局中排列顺序放大比例缩小比例单个元…

MySQL LOAD DATA INFILE导入数据报错

1.导入命令 LOAD DATA INFILE "merge.csv" INTO TABLE 报名数据 FIELDS TERMINATED BY , ENCLOSED BY " LINES TERMINATED BY \n IGNORE 1 LINES; 2.表结构 CREATE TABLE IF NOT EXISTS 报名数据 ( pid VARCHAR(100) NOT NULL, 查询日期 VARCHAR(25) NO…

详解模版类pair

目录 一、pair简介 二、 pair的创建 三、pair的赋值 四、pair的排序 (1)用sort默认排序 (2)用sort中的自定义排序进行排序 五、pair的交换操作 一、pair简介 pair是一个模版类,可以存储两个值的键值对.first以…

C#从入门到放弃

C#和.NET的区别 C# C#是一个编程语言 .NET .NET是一个在window下创建程序的框架 .NET框架不仅局限于C#,它还可以支持很多语言 .NET包括了2个组件,一个叫CLR(通用语言运行时),另一个是用来构建程序的类库 CLR 用C写一个程序,在一台8688的机器…

算法复杂度详解

目录 算法定义 复杂度概念 时间复杂度 大O的渐近表示法 空间复杂度 常见复杂度对比 算法定义 算法(Algorithm):就是定义良好的计算过程,他取一个或一组的值为输入,并产生出一个或一组值作为 输出。简单来说算法就是一系列的计算步骤,用来…

AI写作(十)发展趋势与展望(10/10)

一、AI 写作的崛起之势 在当今科技飞速发展的时代,AI 写作如同一颗耀眼的新星,迅速崛起并在多个领域展现出强大的力量。 随着人工智能技术的不断进步,AI 写作在内容创作领域发挥着越来越重要的作用。据统计,目前已有众多企业开始…

电子应用设计方案-12:智能窗帘系统方案设计

一、系统概述 本设计方案旨在打造便捷、高效的全自动智能窗帘系统。 二、硬件选择 1. 电机:选用低噪音、扭矩合适的智能电机,根据窗帘尺寸和重量确定电机功率,确保能平稳拉动窗帘。 2. 轨道:选择坚固、顺滑的铝合金轨道&…

数据结构《栈和队列》

文章目录 一、什么是栈?1.1 栈的模拟实现1.2 关于栈的例题 二、什么是队列?2.2 队列的模拟实现2.2 关于队列的例题 总结 提示:关于栈和队列的实现其实很简单,基本上是对之前的顺序表和链表的一种应用,代码部分也不难。…

从0-1训练自己的数据集实现火焰检测

随着工业、建筑、交通等领域的快速发展,火灾作为一种常见的灾难性事件,对生命财产安全造成了严重威胁。为了提高火灾的预警能力,减少火灾损失,火焰检测技术应运而生,成为火灾监控和预防的有效手段之一。 传统的火灾检测方法,如烟雾探测器、温度传感器等,存在响应时间慢…

WSL--无需安装虚拟机和docker可以直接在Windows操作系统上使用Linux操作系统

安装WSL命令 管理员打开PowerShell或Windows命令提示符,输入wsl --install,然后回车 注意:此命令将启用运行 WSL 和安装 Linux 的 Ubuntu 发行版所需的功能。 注意:默认安装最新的Ubuntu发行版。 注意:默认安装路径是…

云原生-docker安装与基础操作

一、云原生 Docker 介绍 Docker 在云原生中的优势 二、docker的安装 三、docker的基础命令 1. docker pull(拉取镜像) 2. docker images(查看本地镜像) 3. docker run(创建并启动容器) 4. docker ps…

@Autowired 和 @Resource思考(注入redisTemplate时发现一些奇怪的现象)

1. 前置知识 Configuration public class RedisConfig {Beanpublic RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {RedisTemplate<String, Object> template new RedisTemplate<>();template.setConnectionFactory(facto…

HarmonyOS ArkUI(基于ArkTS) 常用组件

一 Button 按钮 Button是按钮组件&#xff0c;通常用于响应用户的点击操作,可以加子组件 Button(我是button)Button(){Text(我是button)}type 按钮类型 Button有三种可选类型&#xff0c;分别为胶囊类型&#xff08;Capsule&#xff09;、圆形按钮&#xff08;Circle&#xf…

Jenkins + gitee 自动触发项目拉取部署(Webhook配置)

目录 前言 Generic Webhook Trigger 插件 下载插件 ​编辑 配置WebHook 生成tocken 总结 前言 前文简单介绍了Jenkins环境搭建&#xff0c;本文主要来介绍一下如何使用 WebHook 触发自动拉取构建项目&#xff1b; Generic Webhook Trigger 插件 实现代码推送后&#xff0c;触…