【深度学习 | LSTM】解开LSTM的秘密:门控机制如何控制信息流

在这里插入图片描述

🤵‍♂️ 个人主页: @AI_magician
📡主页地址: 作者简介:CSDN内容合伙人,全栈领域优质创作者。
👨‍💻景愿:旨在于能和更多的热爱计算机的伙伴一起成长!!🐱‍🏍
🙋‍♂️声明:本人目前大学就读于大二,研究兴趣方向人工智能&硬件(虽然硬件还没开始玩,但一直很感兴趣!希望大佬带带)

在这里插入图片描述

【深度学习 | LSTM】解开LSTM的秘密:门控机制如何控制信息流
作者: 计算机魔术师
版本: 1.0 ( 2023.8.27 )

摘要: 本系列旨在普及那些深度学习路上必经的核心概念,文章内容都是博主用心学习收集所写,欢迎大家三联支持!本系列会一直更新,核心概念系列会一直更新!欢迎大家订阅

该文章收录专栏
[✨— 《深入解析机器学习:从原理到应用的全面指南》 —✨]

LSTM最全详解,

  • LSTM
      • 原理详解(每个神经元)
        • a. 遗忘门:Forget Gate
        • b. 输入门:Input Gate
        • c. Cell State
        • d. 输出门:Output Gate
      • 参数详解
      • 参数计算
      • 实际场景

LSTM

原理详解(每个神经元)

LSTM(Long Short-Term Memory)是一种常用于处理序列数据的循环神经网络模型。LSTM的核心思想是在传递信息的过程中,通过门的控制来选择性地遗忘或更新信息。LSTM中主要包含三种门:输入门(input gate)、输出门(output gate)和遗忘门(forget gate),以及一个记忆单元(memory cell)。

在LSTM层中,有三个门控单元,即输入门、遗忘门和输出门。这些门控单元在每个时间步上控制着LSTM单元如何处理输入和记忆。在每个时间步上,LSTM单元从输入、前一个时间步的输出和前一个时间步的记忆中计算出当前时间步的输出和记忆。

在LSTM的每个时间步中,输入 x t x_t xt和前一时刻的隐状态 h t − 1 h_{t-1} ht1被馈送给门控制器,然后门控制器根据当前的输入 x t x_t xt和前一时刻的隐状态 h t − 1 h_{t-1} ht1计算出三种门的权重,然后将这些权重作用于前一时刻的记忆单元 c t − 1 c_{t-1} ct1。具体来说,门控制器计算出三个向量:**输入门的开启程度 i t i_t it、遗忘门的开启程度 f t f_t ft和输出门的开启程度 o t o_t ot,这三个向量的元素值均在[0,1]**之间。

然后,使用这些门的权重对前一时刻的记忆单元 c t − 1 c_{t-1} ct1进行更新,计算出当前时刻的记忆单元 c t c_t ct,并将它和当前时刻的输入 x t x_t xt作为LSTM的输出 y t y_t yt。最后,将当前时刻的记忆单元 c t c_t ct和隐状态 h t h_t ht一起作为下一时刻的输入,继续进行LSTM的计算。

在这里插入图片描述

如果你对LSTM以及其与反向传播算法之间的详细联系感兴趣,我建议你参考以下资源:

  1. “Understanding LSTM Networks” by Christopher Olah: https://colah.github.io/posts/2015-08-Understanding-LSTMs/ 强烈推荐!!!
  2. TensorFlow官方教程:Sequence models and long-short term memory network (https://www.tensorflow.org/tutorials/text/text_classification_rnn)
  3. PyTorch官方文档:nn.LSTM (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)
  4. 详细讲解RNN,LSTM,GRU https://towardsdatascience.com/a-brief-introduction-to-recurrent-neural-networks-638f64a61ff4

以上资源将为你提供更多关于LSTM及其与反向传播算法结合使用的详细解释、示例代码和进一步阅读材料。

LSTM 的核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”,记忆门一个控制信号控制门是否应该保留该信息,在实现上通常是乘1或乘0来选择保留或忘记。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。

LSTM的参数包括输入到状态的权重 W x i , W h i , b i W_{xi},W_{hi},b_i Wxi,Whi,bi,输入到遗忘门的权重 W x f , W h f , b f W_{xf},W_{hf},b_f Wxf,Whf,bf,输入到输出门的权重 W x o , W h o , b o W_{xo},W_{ho},b_o Wxo,Who,bo,以及输入到记忆单元的权重 W x c , W h c , b c W_{xc},W_{hc},b_c Wxc,Whc,bc,其中 W W W表示权重矩阵, b b b表示偏置向量。在实际应用中,LSTM模型的参数通常需要通过训练来获得,以最小化预测误差或最大化目标函数。

a. 遗忘门:Forget Gate

遗忘门的功能是决定应丢弃或保留哪些信息。来自前一个隐藏状态的信息和当前输入的信息同时传递到 sigmoid 函数中去,输出值介于 0 和 1 之间,越接近 0 意味着越应该丢弃,越接近 1 意味着越应该保留。

遗忘门的计算公式

b. 输入门:Input Gate

输入门用于更新细胞状态。首先将前一层隐藏状态的信息和当前输入的信息传递到 sigmoid 函数中去。将值调整到 0~1 之间来决定要更新哪些信息。0 表示不重要,1 表示重要。其次还要将前一层隐藏状态的信息和当前输入的信息传递到 tanh 函数中去,创造一个新的侯选值向量。最后将 sigmoid 的输出值与 tanh 的输出值相乘,sigmoid 的输出值将决定 tanh 的输出值中哪些信息是重要且需要保留下来

使用tanh作为LSTM输入层的激活函数,一定程度上可以避免梯度消失和梯度爆炸的问题。

在LSTM中,如果权重值较大或者较小,那么在反向传播时,梯度值会非常大或者非常小,导致梯度爆炸或者消失的情况。而tanh函数的导数范围在[-1, 1]之间,可以抑制梯度的放大和缩小,从而避免了梯度爆炸和消失的问题(RNN遇到的问题)。此外,tanh函数在输入为0附近的时候输出接近于线性,使得网络更容易学习到线性相关的特征。另外,tanh 函数具有对称性,在处理序列数据时能够更好地捕捉序列中的长期依赖关系。

因此,使用tanh作为LSTM输入层的激活函数是比较常见的做法。

c. Cell State

首先前一层的细胞状态与遗忘向量逐点相乘。如果它乘以接近 0 的值,意味着在新的细胞状态中,这些信息是需要丢弃掉的。然后再将该值与输入门的输出值逐点相加,将神经网络发现的新信息更新到细胞状态中去。至此,就得到了更新后的细胞状态。

d. 输出门:Output Gate

输出门用来确定下一个隐藏状态的值,隐藏状态包含了先前输入的信息。首先,我们将前一个隐藏状态和当前输入传递到 sigmoid 函数中,然后将新得到的细胞状态传递给 tanh 函数。最后将 tanh 的输出与 sigmoid 的输出相乘,以确定隐藏状态应携带的信息。再将隐藏状态作为当前细胞的输出,把新的细胞状态和新的隐藏状态传递到下一个时间步长中去。

在LSTM层中,每个时间步上的计算涉及到许多参数,包括输入、遗忘和输出门的权重,以及当前时间步和前一个时间步的输出和记忆之间的权重。这些参数在模型训练过程中通过反向传播进行学习,以最小化模型在训练数据上的损失函数。总之,LSTM通过门的控制,使得信息在传递过程中可以有选择地被遗忘或更新,从而更好地捕捉长序列之间的依赖关系,广泛应用于语音识别、自然语言处理等领域。

LSTM的输出可以是它的最终状态(最后一个时间步的隐藏状态)或者是所有时间步的隐藏状态序列。通常,LSTM的最终状态可以被看作是输入序列的一种编码,可以被送入其他层进行下一步处理。如果需要使用LSTM的中间状态,可以将return_sequences参数设置为True,这样LSTM层将返回所有时间步的隐藏状态序列,而不是仅仅最终状态。

需要注意的是,LSTM层在处理长序列时容易出现梯度消失或爆炸的问题。为了解决这个问题,通常会使用一些技巧,比如截断反向传播、梯度裁剪、残差连接等

参数详解

layers.LSTM 是一个带有内部状态的循环神经网络层,其中包含了多个可训练的参数。具体地,LSTM层的输入是一个形状为(batch_size, timesteps, input_dim)的三维张量,其中batch_size表示输入数据的批次大小,timesteps表示序列数据的时间步数,input_dim表示每个时间步的输入特征数。LSTM层的输出是一个形状为**(batch_size, timesteps, units)的三维张量,其中units表示LSTM层的输出特征数**。以下是各个参数的详细说明:

  • units:LSTM 层中的单元数,即 LSTM 层输出的维度。
  • activation:激活函数,用于计算 LSTM 层的输出和激活门。
  • recurrent_activation:循环激活函数,用于计算 LSTM 层的循环状态。
  • use_bias:是否使用偏置向量。
  • kernel_initializer:用于初始化 LSTM 层的权重矩阵的初始化器。
  • recurrent_initializer:用于初始化 LSTM 层的循环权重矩阵的初始化器。
  • bias_initializer:用于初始化 LSTM 层的偏置向量的初始化器。
  • unit_forget_bias:控制 LSTM 单元的偏置初始化,如果为 True,则将遗忘门的偏置设置为 1,否则设置为 0。
  • kernel_regularizer:LSTM 层权重的正则化方法。
  • recurrent_regularizer:LSTM 层循环权重的正则化方法。
  • bias_regularizer:LSTM 层偏置的正则化方法。
  • activity_regularizer:LSTM 层输出的正则化方法。
  • dropout:LSTM 层输出上的 Dropout 比率。
  • recurrent_dropout:LSTM 层循环状态上的 Dropout 比率。
  • return_sequences: 可以控制LSTM的输出形式。如果设置为True,则输出每个时间步的LSTM的输出,如果设置为False,则只输出最后一个时间步的LSTM的输出。因此,return_sequences的默认值为False,如果需要输出每个时间步的LSTM的输出,则需要将其设置为True。

这些参数的不同设置将直接影响到 LSTM 层的输出和学习能力。需要根据具体的应用场景和数据特点进行选择和调整。

tf.keras.layers.LSTM(
units,
activation=“tanh”,
recurrent_activation=“sigmoid”, #用于重复步骤的激活功能
use_bias=True, #是否图层使用偏置向量
kernel_initializer=“glorot_uniform”, #kernel权重矩阵的 初始化程序,用于输入的线性转换
recurrent_initializer=“orthogonal”, #权重矩阵的 初始化程序,用于递归状态的线性转换
bias_initializer=“zeros”, #偏差向量的初始化程序
unit_forget_bias=True, #则在初始化时将1加到遗忘门的偏置上
kernel_regularizer=None, #正则化函数应用于kernel权重矩阵
recurrent_regularizer=None, #正则化函数应用于 权重矩阵
bias_regularizer=None, #正则化函数应用于偏差向量
activity_regularizer=None, #正则化函数应用于图层的输出(其“激活”)
kernel_constraint=None,#约束函数应用于kernel权重矩阵
recurrent_constraint=None,#约束函数应用于 权重矩阵
bias_constraint=None,#约束函数应用于偏差向量
dropout=0.0,#要进行线性转换的输入单位的分数
recurrent_dropout=0.0,#为递归状态的线性转换而下降的单位小数
return_sequences=False,#是否返回最后一个输出。在输出序列或完整序列中
return_state=False,#除输出外,是否返回最后一个状态
go_backwards=False,#如果为True,则向后处理输入序列并返回反向的序列
stateful=False,#如果为True,则批次中索引i的每个样本的最后状态将用作下一个批次中索引i的样本的初始状态。
time_major=False,
unroll=False,#如果为True,则将展开网络,否则将使用符号循环。展开可以加快RNN的速度,尽管它通常会占用更多的内存。展开仅适用于短序列。
)

参数计算

对于一个LSTM(长短期记忆)模型,参数的计算涉及输入维度、隐藏神经元数量和输出维度。在给定输入维度为(64,32)和LSTM神经元数量为32的情况下,我们可以计算出以下参数:

  1. 输入维度:(64,32)
    这表示每个时间步长(sequence step)的输入特征维度为32,序列长度为64。

  2. 隐藏神经元数量:32
    这是指LSTM层中的隐藏神经元数量。每个时间步长都有32个隐藏神经元。

  3. 输入门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。其中32是上一时间步的隐藏状态大小,另外32是当前时间步的输入维度。
    • 偏置向量:形状为(32,)的向量。
  4. 遗忘门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  5. 输出门参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  6. 单元状态参数:

    • 权重矩阵:形状为(32,32 + 32)的矩阵。
    • 偏置向量:形状为(32,)的向量。
  7. 输出参数:

    • 权重矩阵:形状为(32,32)的矩阵。将隐藏状态映射到最终的输出维度。
    • 偏置向量:形状为(32,)的向量。

因此,总共的参数数量可以通过计算上述所有矩阵和向量中的元素总数来确定。

实际场景

当使用LSTM(长短期记忆)神经网络进行时间序列预测时,可以根据输入和输出的方式将其分为四种类型:单变量单步预测、单变量多步预测、多变量单步预测和多变量多步预测。

  1. 单变量单步预测:

    • 输入:只包含单个时间序列特征的历史数据。
    • 输出:预测下一个时间步的单个时间序列值。
    • 例如,给定过去几天的某股票的收盘价,使用LSTM进行单变量单步预测将预测未来一天的收盘价。
  2. 单变量多步预测:

    • 输入:只包含单个时间序列特征的历史数据。
    • 输出:预测接下来的多个时间步的单个时间序列值。
    • 例如,给定过去几天的某股票的收盘价,使用LSTM进行单变量多步预测将预测未来三天的收盘价。
  3. 多变量单步预测:

    • 输入:包含多个时间序列特征的历史数据。
    • 输出:预测下一个时间步的一个或多个时间序列值。
    • 例如,给定过去几天的某股票的收盘价、交易量和市值等特征,使用LSTM进行多变量单步预测可以预测未来一天的收盘价。
  4. 多变量多步预测:

    • 输入:包含多个时间序列特征的历史数据。
    • 输出:预测接下来的多个时间步的一个或多个时间序列值。
    • 例如,给定过去几天的某股票的收盘价、交易量和市值等特征,使用LSTM进行多变量多步预测将预测未来三天的收盘价。

这些不同类型的时间序列预测任务在输入和输出的维度上略有差异,但都可以通过适当配置LSTM模型来实现。具体的模型架构和训练方法可能会因任务类型和数据特点而有所不同。

在这里插入图片描述

						  🤞到这里,如果还有什么疑问🤞🎩欢迎私信博主问题哦,博主会尽自己能力为你解答疑惑的!🎩🥳如果对你有帮助,你的赞是对博主最大的支持!!🥳

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/82390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(超详解)堆排序+(图解)

目录: 1:如何建堆(两种方法) 2:两种方法建堆的时间复杂度分析与计算 3:不同类型的排序方式我们应该如何建堆 文章正式开始: 1:如何建堆 在实现堆排序之前我们必须得建堆,才能够实现堆排序 首先在讲解如何建堆之前让我们先来回顾一…

JDK8新特性

Lembda表达式 lembda表达式是一个简洁、可传递的匿名函数,实现了把代码块赋值给一个变量的功能 是我认为jdk1.8中最让人眼前一亮的特性(我没用过其他函数式的语言) 在了解表达式之前,我们先看两个概念 函数式接口 含有且仅含有一个抽象方法&…

CSS核心使用

CSS核心使用 box-sizingbox-shdowtext-shadowpositionwriting-mode box-sizing 定义计算一个元素的总高度和总宽度. 属性值 content-box 默认值,width 内容宽度,height内容的高度border-box 宽度和高度包含内容,内边距和边框 widthborderpadding内容宽度, heightborderpaddi…

测试进阶知识之零日攻击的发现和防御

零日攻击是指针对软件或系统中未公开(或未被开发者知晓)的漏洞进行的攻击。这些漏洞被称为零日漏洞,因为在被公开之前,它们对开发者或安全研究人员来说是未知的,所以没有足够的时间进行防御或修复。 发现零日漏洞 发…

启动YOLO进行图片物体识别

查看官方文档YOLO: Real-Time Object Detection 这些是一些模型的对比,显示了YOLO的优势,继续往下面看 CoCoData set 是一个数据库,用来训练模型,这里面有丰富的物体检测,分割数据集,图像经过了精确的segm…

Pikachu Burte Force(暴力破解)

一、Burte Force(暴力破解)概述 ​ “暴力破解”是一攻击具手段,在web攻击中,一般会使用这种手段对应用系统的认证信息进行获取。 其过程就是使用大量的认证信息在认证接口进行尝试登录,直到得到正确的结果。 为了提高…

Jenkins List Git Branches插件 构建选择指定git分支

List Git Branches Parameter | Jenkins pluginAdds ability to choose from git repository revisions or tagshttps://plugins.jenkins.io/list-git-branches-parameter/ 1、安装组件 List Git Branches 2、验证功能 1)新建任务 2)新增构建参数 3&…

JavaSE List

目录 1 预备知识-泛型(Generic)1.1 泛型的引入1.2 泛型类的定义的简单演示 1.3 泛型背后作用时期和背后的简单原理1.4 泛型类的使用1.5 泛型总结 2 预备知识-包装类(Wrapper Class)2.1 基本数据类型和包装类直接的对应关系2.2 包装类的使用,装…

【教程】微信小程序导入外部字体详细流程

前言 在微信小程序中,我们在wxss文件中通过font-family这一CSS属性来设置文本的字体,并且微信小程序有自身支持的内置字体,可以通过代码提示查看微信小程序支持字体: 这些字体具体是什么样式可以参考: 微信小程序--字…

ATF(TF-A) SPMC威胁模型-安全检测与评估

安全之安全(security)博客目录导读 ATF(TF-A) 威胁模型汇总 目录 一、简介 二、评估目标 1、数据流图 三、威胁分析 1、信任边界 2、资产 3、威胁代理 4、威胁类型 5、威胁评估 5.1 端点在直接请求/响应调用中模拟发送方或接收方FF-A ID 5.2 篡改端点和SPMC之间的…

基于element-ui的年份范围选择器

基于element-ui的年份范围选择器 element-ui官方只有日期范围和月份范围选择器,根据需求场景需要,支持年份选择器,原本使用两个分开的年份选择器实现的,但是往往有些是不能接受的。在网上找了很多都没有合适的,所以打…

【内网穿透】公网远程访问本地硬盘文件

公网远程访问本地硬盘文件【内网穿透】 文章目录 公网远程访问本地硬盘文件【内网穿透】前言1. 下载cpolar和Everything软件3. 设定http服务器端口4. 进入cpolar的设置5. 生成公网连到本地内网穿透数据隧道 总结 前言 随着云概念的流行,不少企业采用云存储技术来保…

Linux 信号相关

int kill(pid_t pid, int sig); -功能:给某个进程pid 发送某个信号 参数sig可以使用宏值或者和它对应的编号 参数pid: >0 ;将信号发给指定的进程 0;将信号发送给当前的进程组 -1;发送给每一个有权限接受这个信号的…

【面试必刷TOP101】删除链表的倒数第n个节点 两个链表的第一个公共结点

目录 题目:删除链表的倒数第n个节点_牛客题霸_牛客网 (nowcoder.com) 题目的接口: 解题思路: 代码: 过啦!!! 题目:两个链表的第一个公共结点_牛客题霸_牛客网 (nowcoder.com) …

【C++ 学习 ㉑】- 详解 map 和 set(上)

目录 一、C STL 关联式容器 二、pair 类模板 三、set 3.1 - set 的基本介绍 3.2 - set 的成员函数 3.1.1 - 构造函数 3.1.2 - 迭代器 3.1.3 - 修改操作 3.1.4 - 其他操作 四、map 4.1 - map 的基本介绍 4.2 - map 的成员函数 4.2.1 - 迭代器 4.2.2 - operator[] …

go语言---锁

什么是锁呢?就是某个协程(线程)在访问某个资源时先锁住,防止其它协程的访问,等访问完毕解锁后其他协程再来加锁进行访问。这和我们生活中加锁使用公共资源相似,例如:公共卫生间。 死锁 死锁是…

Ubuntu安装中文拼音输入法

ubuntu安装中文拼音输入法 ubuntu版本为23.04 1、安装中文语言包 首先安装中文输入法必须要让系统支持中文语言,可以在 Language Support 中安装中文语言包。 添加或删除语音选项,添加中文简体,然后会有Applying changes的对话框&#x…

vue 把echarts封装成一个方法 并且从后端读取数据 +转换数据格式 =动态echarts 联动echarts表

1.把echarts 在 methods 封装成一个方法mounted 在中调用 折线图 和柱状图 mounted调用下边两个方法 mounted(){//最早获取DOM元素的生命周期函数 挂载完毕console.log(mounted-id , document.getElementById(charts))this.line()this.pie()},methods里边的方法 line() {// …

在Android studio 创建Flutter项目运行出现问题总结

在Android studio 中配置Flutter出现的问题 A problem occurred configuring root project ‘android’出现这个问题。解决办法 首先找到flutter配置的位置 在D:\xxx\flutter\packages\flutter_tools\gradle位置中的flutter.gradle buildscript { repositories { googl…

3D目标检测框架 MMDetection3D环境搭建 docker篇

本文介绍如何搭建3D目标检测框架,使用docker快速搭建MMDetection3D的开发环境,实现视觉3D目标检测、点云3D目标检测、多模态3D目标检测等等。 需要大家提前安装好docker,并且docker版本> 19.03。 1、下载MMDetection3D源码 https://gith…