【LSTM】LSTM cell的门结构学习笔记

文章目录

      • 1. LSTM cell
      • 2. 门结构
      • 3. 门的公式
      • 4. 门的参数
      • 5. 重点关系厘清

1. LSTM cell

  • 如文章 LSTM网络与参数学习笔记 中介绍, LSTM cell指的是一个包含隐藏层所有神经元的结构.
  • 但是LSTM门控单元的公式如何理解、门和LSTM cell神经元如何对应、门函数的参数维度、不同时间步不同隐藏层之间数据如何传递等, 这些问题将在本文厘清

2. 门结构

  • RNN存在长期记忆逐渐消失以及梯度消失/爆炸的问题

  • LSTM通过引入cell state保存长期记忆, 通过设置精妙的门控机制很大程度缓解梯度消失/爆炸问题.

  • LSTM中有三个门, 分别是遗忘门, 输入门和输出门

    • 遗忘门: f门

      • 决定我们会从细胞状态中丢弃什么信息
      • 它接收 h t − 1 h_{t-1} ht1 x t x_{t} xt作为输入参数,通过 s i g m o i d sigmoid sigmoid层得到对应的遗忘门的参数
      • 弄清楚今天发生的事情(输入 x x x)和最近发生的事情(隐藏状态 h h h),二者会影响你对情况的长期判断(细胞状态 C C C
    • 输入门: i门

      • 确定什么样的新信息被存放在细胞状态中
      • s i g m o i d sigmoid sigmoid层得到输入门参数 i t i_t it, 确定要更新的信息, t a n h tanh tanh层产生新的候选值 C ~ t \widetilde{C}_t C t. 最后将 i t i_t it C ~ t \widetilde{C}_t C t相乘得到更新的信息;同时将上面得到的遗忘门 f t f_t ft和旧元胞状态 C t − 1 C_{t-1} Ct1相乘,以忘掉其中的一些信息。二者相结合,便得到更新后的状态 C t C_t Ct
      • 最近发生的事情(隐藏状态 h h h)和今天发生的事情(输入 x x x)中的哪些信息需要记录到你对所处情况的长远判断中(细胞状态 C C C)
    • 输出门:o门

      • 计算最后的输出信息
      • 通过 t a n h tanh tanh层将细胞状态的值规范到 − 1 ∼ 1 -1\sim 1 11之间, 然后由 s i g m o i d sigmoid sigmoid层得到输出门参数 o t o_t ot, 最后将 o t o_t ot与规范化后的细胞状态点乘, 得到最终过滤后的结果 h t h_t ht
      • 得到所处情况的短期判断, 比如近期跟老板提加薪会不会答应

3. 门的公式

  • 首先回顾输入和输出的维度

    • batch_first = true

      input(batch_size, seq_len, input_size)
      output(batch_size, seq_len, hidden_size * num_directions)
      
    • batch_first = false

      input(seq_len, batch_size, input_size)
      output(seq_len, batch_size, hidden_size * num_directions)
      
  • 遗忘门

    • 公式:

    • 输入: h t − 1 h_{t-1} ht1, x t x_{t} xt的联合,即 [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]

    • 输出:由于使用了 s i g m o i d sigmoid sigmoid函数,输出值在 0 ∼ 1 0\sim 1 01之间。0表示完全丢弃,1表示完全保留

    • 维度:如下。 f t f_t ft的维度是hidden_size,也就是 s i g m o i d sigmoid sigmoid层有hidden_size个神经元

      变量维度
      h t − 1 h_{t-1} ht1hidden_size
      x t x_{t} xtfeature_size
      [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]hidden_size + feature_size
      W f W_{f} Wf[hidden_size, hidden_size + feature_size]
      b f b_{f} bfhidden_size
      f t f_{t} fthidden_size
    • 公式合并

  • 输入门

    • 公式:

    • 输入: i t i_t it C ~ t \widetilde{C}_t C t都以 [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]为输入; i t i_t it通过 s i g m o i d sigmoid sigmoid层来实现; C ~ t \widetilde{C}_t C t通过 t a n h tanh tanh层来实现

    • 输出:

      • 同样的 i t i_t it 0 ∼ 1 0\sim 1 01之间, C ~ t \widetilde{C} _{t} C t^在 − 1 ∼ 1 -1\sim 1 11之间; *不是矩阵乘法, 是对应元素点乘
      • 将要更新的信息 i t ∗ C ~ t i_t*\widetilde{C} _{t} itC t要忘记的信息 f t ∗ C t − 1 f_t*C_{t-1} ftCt1相结合得到更新后的状态 C t C_t Ct
    • 含义:

      • f ∗ C t − 1 f*C_{t-1} fCt1的点乘(按元素相乘)实际上是在决定哪些信息从上一时刻的cell state中保留下来,哪些被遗忘。保留和遗忘的比例就是 f f f的值
      • i t ∗ C ~ t i_t*\widetilde{C} _{t} itC t表示有多少cell state候选值的新信息要更新到cell state中,更新的比例就是输入门 i t i_t it的值
      • C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_t*C_{t-1}+i_t*\widetilde{C} _{t} Ct=ftCt1+itC t将旧的信息(可能已被遗忘部分)与新的信息结合,形成当前时刻的cell state
    • 维度:如下。 s i g m o i d sigmoid sigmoid层有hidden_size个神经元, t a n h tanh tanh层有hidden_size个神经元

      变量维度
      h t − 1 h_{t-1} ht1hidden_size
      x t x_{t} xtfeature_size
      [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]hidden_size + feature_size
      i t i_{t} ithidden_size
      C ~ t \widetilde{C} _{t} C thidden_size
      C t − 1 C_{t-1} Ct1hidden_size
      C t C_{t} Cthidden_size
      W i W_{i} Wi[hidden_size, hidden_size + feature_size]
      W C W_{C} WC[hidden_size, hidden_size + feature_size]
    • 公式合并:

  • 输出门

    • 公式

    • 输入: o t o_t ot [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]为输入。 C t C_t Ct需要经过 t a n h tanh tanh层进行值缩放

    • 输出:这里的输出不是LSTM网络的输出,LSTM网络输出包括网络的 o u t p u t output output h n , c n h_n, c_n hn,cn

    • 含义:

      • t a n h ( C t ) tanh(C_t) tanh(Ct)将cell state的值压缩到-1和1之间,使得信息的表示更加集中
      • o t ∗ t a n h ( C t ) o_t*tanh(C_t) ottanh(Ct)的点乘决定cell state的哪些信息将被传到隐藏状态中,压缩后的cell state的传入比例通过 o t o_t ot的值来控制
    • 维度:如下。 s i g m o i d sigmoid sigmoid层有hidden_size个神经元, t a n h tanh tanh层也有hidden_size哥神经元;*是点乘,按元素相乘

      变量维度
      o t o_{t} othidden_size
      h t h_{t} htfeature_size
      W o W_o Wo[hidden_size, hidden_size + feature_size]
    • 公式合并:

4. 门的参数

  • 参数矩阵

  • 参数维度

    变量维度
    x t , h t − 1 , f t , i t , C ~ t , C t − 1 , C t , o t x_t, h_{t-1},f_t, i_t, \widetilde{C} _{t}, C_{t-1},C_t,o_t xt,ht1,ft,it,C t,Ct1,Ct,othidden_size
    x t x_{t} xtfeature_size
    [ h t − 1 , x t ] [h_{t-1}, x_{t}] [ht1,xt]hidden_size + feature_size
    W f , W i , W C , W o W_{f},W_{i},W_{C},W_{o} Wf,Wi,WC,Wo[hidden_size, hidden_size + feature_size]
    b f , b i , b C , b o b_f, b_i, b_C, b_o bf,bi,bC,bohidden_size
    s i g m o i d sigmoid sigmoid网络, t a n h tanh tanh网络hidden_size 个神经元
    ∗ * 对应元素点乘,维度不变

5. 重点关系厘清

图1
  • (1) 图中黄色框都是前馈神经网络, 神经元个数都是hidden_size个, 激活函数就是sigmoid和tanh
  • (2) 主要涉及到的参数就是 W f , W i , W C , W o W_{f},W_{i},W_{C},W_{o} Wf,Wi,WC,Wo权重参数和 b f , b i , b C , b o b_f, b_i, b_C, b_o bf,bi,bC,bo偏置参数
  • (3) 这些参数不是存在于神经元的数据结构中,而是存在于神经元之间的连接。由此理解,网络的要点就是"连接",神经元只是对应的数学运算
图2
  • (4) LSTM cell的整个隐藏层实现了这三个门,这三个门是LSTM cell的一部分
  • (5) LSTM cell的hidden_size是一个超参,是遗忘门/输入门/输出门的神经元个数
  • (6) 遗忘门/输入门/输出门各自维护自己的神经网络,不是共用神经网络。如图1中每个黄色框都是一个神经网络
  • (7) 每个门都有一组自己的权重和参数,也就是章节4中剃刀的W和b参数,这些参数在所有时刻是权值共享的,权值随着时间步在不断地更新
  • (8) b的参数个数 4 ∗ h i d d e n _ s i z e 4*hidden\_size 4hidden_size个,W的参数个数 4 ∗ ( h i d d e n _ s i z e ∗ ( h i d d e n _ s i z e + f e a t u r e _ s i z e ) ) 4*(hidden\_size * (hidden\_size + feature\_size)) 4(hidden_size(hidden_size+feature_size)),即 4 ∗ n u m _ u n i t s ∗ ( h t − 1 + x t ) 4*num\_units*(h_{t-1}+x_t) 4num_units(ht1+xt)
  • (9) 每个门的神经元个数也决定了其输出的维度
图3
  • (10) 如上图,输出门和神经网络的输出不是同一个东西。输出门得到的 h t h_t ht只是短期状态信息,输入到下一时刻或下一层使用;而LSTM网络的最终输出是输出层接 s i g m o i d / s o f t m a x sigmoid/softmax sigmoid/softmax等全连接层后的结果
图4
  • (11) 如上图,多层LSTM的情况, h t h_t ht横向传递给下一时刻作为 h t − 1 h_{t-1} ht1,纵向传递给下一层作为 x t x_t xt

 


 
创作不易,如有帮助,请 点赞 收藏 支持
 


 

[参考文章]
[1].通俗理解门的原理, 推荐
[2].门的公式的衔接和多层LSTM输出的关系
[3].同样,门的公式
[4].cell的内部参数图和公式推导
[5].参考逻辑结构:门和神经元之间的关系, 推荐
[6].对hidden_size的理解,门的计算过程
[7].反向传播的推导
[8].反向传播算法推导过程

created by shuaixio, 2024.05.21

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/15619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙 DevEco Studio 3.1 Release 下载sdk报错的解决办法

鸿蒙 解决下载SDK报错的解决方法 最近在学习鸿蒙开发,以后也会记录一些关于鸿蒙相关的问题和解决方法,希望能帮助到大家。 总的来说一般有下面这样的报错 报错一: Components to install: - ArkTS 3.2.12.5 - System-image-phone 3.1.0.3…

leecode 1206|跳表的设计

跳表 跳表,一种链表数据结构,其增删改茶的效率能和平衡树相媲美 leecode1206 可以看上面的那个动画,动画效果很贴切。 我简单讲讲它的机制吧,每个节点不单单是一个,测试好几层,然后同一层的节点和统一节点…

Tomcat部署项目的方式

目录 1、Tomcat发布项目的方式 方式1: 直接把项目发布到webapps目录下 方式2:项目发布到ROOT目录 方式3:虚拟路径方式发布项目 方式4:(推荐)虚拟路径,另外的方式! 方式5:发布多个网站 1、…

掩码生成蒸馏——知识蒸馏

摘要 https://arxiv.org/pdf/2205.01529 知识蒸馏已成功应用于各种任务。当前的蒸馏算法通常通过模仿教师的输出来提高学生的性能。本文表明,教师还可以通过指导学生的特征恢复来提高学生的表示能力。从这一观点出发,我们提出了掩码生成蒸馏&#xff08…

【字典树(前缀树) 异或 离线查询】1707. 与数组中元素的最大异或值

本文涉及知识点 字典树(前缀树) 位运算 异或 离线查询 LeetCode1707. 与数组中元素的最大异或值 给你一个由非负整数组成的数组 nums 。另有一个查询数组 queries ,其中 queries[i] [xi, mi] 。 第 i 个查询的答案是 xi 和任何 nums 数组…

C++ | Leetcode C++题解之第97题交错字符串

题目&#xff1a; 题解&#xff1a; class Solution { public:bool isInterleave(string s1, string s2, string s3) {auto f vector <int> (s2.size() 1, false);int n s1.size(), m s2.size(), t s3.size();if (n m ! t) {return false;}f[0] true;for (int i …

264 基于matlab的自适应语音盲分离

基于matlab的自适应语音盲分离&#xff0c;当a和b同时对着传声器A,B说话且传声器靠得很近时&#xff0c;传声器A,B会同时接受到a和b的声音&#xff0c;即a和b产生了混叠干扰&#xff0c;此时通过自适应语音盲分离系统可以将a,b的声音分离开&#xff0c;使得一个信道只有一个人的…

2024.05.25学习记录

1、面经复习&#xff1a; JS异步进阶、vue-react-diff、vue-router模式、requestldleCallback、React Fiber 2、代码随想录刷题、动态规划 3、组件库使用storybook

python抽取pdf中的参考文献

想将一份 pdf 论文中的所有参考文献都提取出来&#xff0c;去掉不必要的换行&#xff0c;放入一个 text 文件&#xff0c;方便复制。其引用是 ieee 格式的&#xff0c;形如&#xff1a; 想要只在引用序号&#xff08;如 [3]&#xff09;前换行&#xff0c;其它换行都去掉&…

VTK 数据处理:特征边提取

VTK 数据处理&#xff1a;特征边提取 VTK 数据处理&#xff1a;特征边提取原理实例 1&#xff1a;边界边提取实例 2&#xff1a;模型特征边提取实例 3&#xff1a;利用 vtkFeatureEdges 提取的边界补洞实例 4&#xff1a;利用 vtkFillHolesFilter 补洞 VTK 数据处理&#xff1a…

OC属性关键字和单例模式

OC的属性关键字和单例模式 文章目录 OC的属性关键字和单例模式单例模式基本创建重写allocWithZone方法的同时使用dispatch_once 属性和属性关键字property和synthesize&#xff0c;dynamic属性关键字atomic和nonatomicstrong和weakreadonly和readwritestrong和copy 单例模式 单…

MySQL--存储引擎

一、存储引擎介绍 1.介绍 存储引擎相当于Linux的文件系统&#xff0c;以插件的模式存在&#xff0c;是作用在表的一种属性 2.MySQL中的存储引擎类型 InnoDB、MyISAM、CSV、Memory 3.InnoDB核心特性的介绍 聚簇索引、事务、MVCC多版本并发控制、行级锁、外键、AHI、主从复制特…

Python | Leetcode Python题解之第112题路径总和

题目&#xff1a; 题解&#xff1a; class Solution:def hasPathSum(self, root: TreeNode, sum: int) -> bool:if not root:return Falseif not root.left and not root.right:return sum root.valreturn self.hasPathSum(root.left, sum - root.val) or self.hasPathSum…

关于在子线程中获取不到HttpServletRequest对象的问题

这篇文章主要分享一下项目里遇到的获取request对象为null的问题&#xff0c;具体是在登录的时候触发的邮箱提醒&#xff0c;获取客户端ip地址&#xff0c;然后通过ip地址定位获取定位信息&#xff0c;从而提示账号在哪里登录。 但是登录却发现获取request对象的时候报错了。 具…

Docker提示某网络不存在如何解决,添加完网络之后如何删除?

Docker提示某网络不存在如何解决&#xff1f; 创建 Docker 网络 假设现在需要创建一个名为my-mysql-network的网络 docker network create my-mysql-network运行容器 创建网络之后&#xff0c;再运行 mysqld_exporter 容器。完整命令如下&#xff1a; docker run -d -p 9104…

认识K8s集群的声明式资源管理方法

前言 Kubernetes 集群的声明式资源管理方法是当今云原生领域中的核心概念之一&#xff0c;使得容器化应用程序的部署和管理变得更加高效和可靠。本文将认识了解 Kubernetes 中声明式管理的相关理念、实际应用以及优势。 目录 一、管理方法介绍 1. 概述 2. 语法格式 2.1 管…

Spring Boot Interceptor(拦截器使用及原理)

之前的博客中讲解了关于 Spring AOP的思想和原理&#xff0c;而实际开发中Spring Boot对于AOP的思想的具体实现就是Spring Boot Interceptor。在 Spring Boot 应用程序开发中&#xff0c;拦截器&#xff08;Interceptor&#xff09;是一个非常有用的工具。它允许我们在 HTTP 请…

Redis可视化工具:Another Redis Desktop Manager下载安装使用

1.Github下载 github下载地址&#xff1a; Releases qishibo/AnotherRedisDesktopManager GitHub 2. 安装 直接双击exe文件进行安装 3. 连接Redis服务 先启动Redis服务&#xff0c;具体启动过程可参考&#xff1a; Windows安装并启动Redis服务端&#xff08;zip包&#xff09…

Golang | Leetcode Golang题解之第111题二叉树的最小深度

题目&#xff1a; 题解&#xff1a; func minDepth(root *TreeNode) int {if root nil {return 0}queue : []*TreeNode{}count : []int{}queue append(queue, root)count append(count, 1)for i : 0; i < len(queue); i {node : queue[i]depth : count[i]if node.Left …