循环神经网络变形之 (Long Short Term Memory,LSTM)

1、长短期记忆网络LSTM简介

在RNN 计算中,讲到对于传统RNN水平方向进行长时刻序列依赖时可能会出现梯度消失或者梯度爆炸的问题。LSTM 特别适合解决这种需要长时间依赖的问题。

LSTM(Long Short Term Memory,长短期记忆网络)是RNN的一种,大体结构一直,区别在于:

  • LSTM 的‘记忆cell’ 是被改造过的,水平方向减少梯度消失与梯度爆炸
  • 该记录的信息会一直传递,不该记录的信息会被截断掉,部分输出和输入被从网络中删除

RNN 在语音识别,语言建模,翻译,图片描述等问题的应用的成功,都是通过 LSTM 达到的。

在这里插入图片描述

2、LSTM工作原理

2.1、传统的RNN“细胞”结构

所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。


2.2、LSTM结构

如下图展示了LSTM的一个神经元内部的结构。单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。

图中使用的各种元素的图标:

  • 每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。合在一起的线表示向量的连接,比如一个十维向量和一个二十维向量合并后形成一个三十维向量;分开的线表示内容被复制,然后分发到不同的位置。
  • 粉色的圈代表 pointwise 的操作,诸如向量的加法,减法,乘法,除法,都是矩阵的。
  • 黄色的矩阵就是神经网络层。

2.3、细胞状态

LSTM关键:“细胞状态” 。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变很容易。

LSTM怎么控制“细胞状态”?

  • LSTM可以通过gates(“门”)结构来去除或者增加“细胞状态”的信息
  • 包含一个sigmoid神经网络层次和一个pointwist乘法操作
  • Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允 许任务变量通过”,1表示“运行所有变量通过”
  • LSTM中主要有三个“门”结构来控制“细胞状态”
  • C_{t-1} 和 C_t 矩阵的shape 是一样的

忘记门

决定从“细胞状态”中丢弃什么信息;比如在语言模型中,细胞状态可能包含了性别信息(“他”或者“她”),当我们看到新的代名词的时候,可以考虑忘记旧的数据

信息增加门

  • 决定放什么新信息到“细胞状态”中;
  • Sigmoid层决定什么值需要更新; 
  • Tanh层创建一个新的候选向量C_t
  • 主要是为了状态更新做准备

在这里插入图片描述

经过第一个和第二个“门”后,可以确定传递信息的删除和增加,即可以进行 “细胞状态”的更新

  • 更新C_(t-1)C_ t;
  • 将旧状态与f_t相乘,丢失掉确定不要的信息;
  • 加上新的候选值i_t*C_t得到最终更新后的“细胞状态”

输出门

  • 首先运行一个sigmoid网络层来确定细胞状态的那个部分将输出
  • 使用tanh处理细胞状态得到一个-1到1之间的值,再将它和sigmoid门的输出相乘,输出程序确定输出的部分。

  • 前向传播和反向传播可以参看前面的传播过程写下来,更新LSTM中的参数。具体的公式可以参看:RNN以及LSTM的介绍和公式梳理_Dark_Scope的博客-CSDN博客_lstm公式
  • 作者的论文:https://arxiv.org/pdf/1402.1128v1.pdf

2.4、 前向传播和后向传播

前向传播

现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态h^{(t)},C^{(t)},模型参数几乎是RNN的4倍,因为现在多了W_f,U_f,b_f,W_a,U_a,b_a,W_i,U_i,b_i,W_o,U_o,b_o这些参数。   

前向传播过程在每个序列索引位置的过程为:

1)更新遗忘门输出:

f^{(t)}=\sigma (W_fh^{(t-1)}+U_fx^{(t)}+b_f)

2)更新输入门两部分输出:

i^{(t)}=\sigma (W_ih^{(t-1)}+U_ix^{(t)}+b_i)

a^{(t)}=tanh(W_ah^{(t-1)}+U_ax^{(t)}+b_a)

3)更新细胞状态:

C^{(t)}=C^{(t-1)}\odot f^{(t)}+i^{(t)}\odot a^{(t)}

4)更新输出门输出:

o^{(t)}=\sigma (W_oh^{(t-1)}+U_ox^{(t)}+b_o)

h^{(t))}=o^{(t)}\odot tanh(C^{(t)})

5)更新当前序列索引预测输出:

\hat{y} ^{(t)}=\sigma (Vh^{(t)}+c)

整体的过程如下图所示

可以看到,在t tt时刻,C^{(t)}用于计算h^{(t)}C^{(t+1)}

反向传播

有了LSTM前向传播算法,推导反向传播算法就很容易了, 思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新我们所有的参数,关键点在于计算所有参数基于损失函数的偏导数。

在RNN中,为了反向传播误差,我们通过隐藏状态h^{(t)}的梯度\delta^{(t)}一步步向前传播。在LSTM这里也类似。只不过我们这里有两个隐藏状态h^{(t)}C^{(t)}。这里我们定义两个\delta,即:

\delta_h^{(t)} = \frac{\partial L}{\partial h^{(t)}}\qquad \text{(8)}

\delta_C^{(t)} = \frac{\partial L}{\partial C^{(t)}}\qquad \text{(9)}

反向传播时只使用了\delta_C^{(t)},变量\delta_h^{(t)}仅为帮助我们在某一层计算用,并没有参与反向传播,这里要注意。如下图所示:

因为,我们在输出层定义的损失函数为对数损失,激活函数为softmax激活函数。因为,与RNN的推导类似,在最后的序列索引位置 \tau\delta_h^{(\tau)}\delta_C^{(\tau)}为:

\delta_h^{(\tau)} =\frac{\partial L}{\partial O^{(\tau)}} \frac{\partial O^{(\tau)}}{\partial h^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})\qquad \text{(10)}
\delta_C^{(\tau)} =\frac{\partial L}{\partial h^{(\tau)}} \frac{\partial h^{(\tau)}}{\partial C^{(\tau)}} = \delta_h^{(\tau)} \odot o^{(\tau)} \odot (1 - tanh^2(C^{(\tau)}))\qquad \text{(11)}

接着我们由\delta_C^{(t+1)}反向推导\delta_C^{(t)}

\delta_h^{(t)}的梯度由本层的输出梯度误差决定,与公式(10)类似,即:

\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)})\qquad \text{(12)}

\delta_C^{(t)}​的反向梯度误差由前一层\delta_C^{(t+1)}​的梯度误差和本层的从h^{(t)}传回来的梯度误差两部分组成,即:

\delta_C^{(t)} =\frac{\partial L}{\partial C^{(t+1)}} \frac{\partial C^{(t+1)}}{\partial C^{(t)}} + \frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial C^{(t)}} = \delta_C^{(t+1)}\odot f^{(t+1)} + \delta_h^{(t)} \odot o^{(t)} \odot (1 - tanh^2(C^{(t)})) \qquad \text{(13)}

公式(13)的前半部分由公式(4)和公式(9)得到,公式(13)的后半部分由公式(6)和公式(8)得到。

有了\delta_h^{(t)}\delta_C^{(t)}, 计算这一大堆参数的梯度就很容易了,这里只给出W_f​的梯度计算过程,其他的U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o,V, c的梯度大家只要照搬就可以了。
\frac{\partial L}{\partial W_f} = \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial C^{(t)}} \frac{\partial C^{(t)}}{\partial f^{(t)}} \frac{\partial f^{(t)}}{\partial W_f} =\sum\limits_{t=1}^{\tau} \delta_C^{(t)} \odot C^{(t-1)} \odot f^{(t)}\odot(1-f^{(t)}) (h^{(t-1)})^T \qquad \text{(14)}

公式(13)的由公式(1)、公式(4)和公式(9)得到。

由上面可以得到,只要我们清晰地搞清楚前向传播过程,且只使用了\delta_C^{(t)}​进行反向传播的话,反向传播的整个过程是比较清晰的。

在这里有必要解释下为什么反向传播不使用\delta_h^{(t)}​,如果与循环神经网络(RNN)模型的前向反向传播算法里一样的话,那么\delta_h^{(t)}的计算方式就不应该是(12)式了

\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)})\qquad \text{(12)}

因为,h^{(t)}参与了L^{(t)}L^{(t+1)}的计算,所以在RNN文章里的求梯度方法,\frac{\partial L}{\partial h^{(t)}}​应该是
\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} + \frac{\partial L}{\partial h^{(t+1)}}\frac{\partial h^{(t+1)}}{\partial h^{(t)}}\qquad \text{(12*)}

但是,这里是一个比较复杂的时序模型,如果使用RNN的思路,将h^{(t+1)}的部分也一起反向传播回来的话,这里的反向梯度根本无法得到闭式解。而只考虑一个的话,也可以做反向梯度优化,进度下降,但是优化起来容易的多,可以理解为这里做了一个近似。

3、另一种理解方式

图中方框我们称为记忆单元,其中实线箭头代表当前时刻的信息传递,虚线箭头表示上一时刻的信息传递。从结构图中我们看出,LSTM模型共增加了三个门: 输入门、遗忘门和输出门。进入block的箭头代表输入,而出去的箭头代表输出。

前向传播公式

fig2

上图中所有带h的权重矩阵均代表一种泛指,为LSTM的各种变种做准备,表示任意一条从上一时刻指向当前时刻的边,本文暂不考虑。与上篇公式类似,a代表汇集计算结果,b代表激活计算结果, Wil代表输入数据与输入门之间的权重矩阵, Wcl代表上一时刻Cell状态与输入门之间的权重矩阵, WiΦ代表输入数据与遗忘门之间的权重矩阵, WcΦ代表上一时刻Cell状态与遗忘门之间的权重矩阵, Wiω代表输入数据与输出门之间的权重矩阵, Wcω代表Cell状态与输出门之间的权重矩阵, Wic代表输入层原有的权重矩阵。 需要注意的是,图中Cell一栏描述的是从下方输入到中间Cell输出的整个传播过程。

反向传播

和朴素RNN的推导一样,有了前向传播公式,我们就能逐个写出LSTM网络中各个参数矩阵的梯度计算公式。首先,由于输出门不牵扯时间维度,我们可以直接写出输出门WiωWcω的迭代公式,如下图:

fig3

遗忘门的权重矩阵 WiΦ也可以直接给出,如下图:

fig4

而对于遗忘门的权重矩阵 WcΦ,由于是和上一时刻Cell状态做汇集计算,残差除了来自当前Cell,还来自下一时刻的Cell,因此需要写出下一时刻Cell传播至本时刻遗忘门的时间维度前向传播公式,如下图:

fig5

有了上面的公式,我们就能完整写出 WcΦ的梯度公式了。如下图所示(如果对这个时间维度前向公式不理解,可以参考上一篇我对朴素RNN的公式推导过程):

fig6

请注意,上图中L”和前面的L’不一样,这里只是为了式子简洁。

推完遗忘门公式,就可以此类推输入门与Cell的公式。其中输入门基本与遗忘门的推法一样,残差都是来自本时刻和下一时刻Cell。而Cell的残差则来自三个地方:输出层、输出门和下一时刻Cell。其中输出层和输出门残差可直接写出;而下一时刻Cell的残差,我们只要写出对应的时间维度前向传播公式便可写出。由于时间关系,这里就不详细推导遗忘门和Cell的梯度公式了,各位若有兴趣可自行继续推导。

相比于朴素RNN模型,LSTM模型更为复杂,且可调整和变化的地方也更多。比如:增加peephole将Cell状态连接到每个门,变体模型Gated Recurrent Unit (GRU),以及后面出现的Attention模型等。LSTM模型在语音识别、图像识别、手写识别、以及预测疾病、点击率和股票等众多领域中都发挥着惊人的效果,是目前最火的神经网络模型之一。敬请期待下节。

4、LSTM 的变体

我们到目前为止都还在介绍正常的 LSTM。但是不是所有的 LSTM 都长成一个样子的。实际上,几乎所有包含 LSTM 的论文都采用了微小的变体。差异非常小,目前为止有上百种,常用的也就几种。 其中一个流形的 LSTM 变体,就是由 Gers & Schmidhuber (2000) 提出的,

4.1、变种1

  • 增加了 “peephole connection”层。
  • 让门层也会接受细胞状态的输入

4.2、变种2

通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增 加什么信息,而是一起进行考虑。

4.3、Gated Recurrent Unit(GRU),2014年提出

  • 将忘记门和输入门合并成为一个单一的更新门
  • 同时合并了数据单元状态和隐藏状态
  • 结构比LSTM的结构更加简单

4.4、https://arxiv.org/pdf/1402.1128v1.pdf 论文

论文中定义的 LTSM cell 如下图所示:

lstm model

图示

  1. \bigotimes代表两个数据源乘上参数后相加。代表两个数据源相加。 
  2. \bigotimes外面再加花边的,代表两个数据源相乘后再取 sigmoid 。
  3.  圆圈里是gg的,代表取 tanh 。
  4. state下标-1代表这是上一次迭代时的结果。

所以像论文里指出的,这里实现的 LSTM Cell 含有更多参数,效果更好?

一般的 LSTM 就够用了,GRU 用的也比较多。

参考

  • LSTM Forward and Backward Pass
  • Understanding LSTM Networks
  • https://arxiv.org/pdf/1402.1128v1.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/453873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows 系统下使用 MinGW + MSYS + GCC 编译 FFMPEG

一定要按照顺序操作,否则你很可能持续遇到很多奇怪的问题(ffmpeg对编译系统版本要求比较高)。 1. www.mingw.org: 下载并安装 MinGW 5.1.4 (http://jaist.dl.sourceforge.net/sourceforge/mingw/MinGW-5.1.4.exe),安装时选中 g, m…

eclipse怎样改编码格式_Eclipse中各种编码格式及设置

操作系统:Windows 10(家庭中文版)Eclipse版本:Version: Oxygen.1a Release (4.7.1a)刚看到一篇文章,里面介绍说Ascii、Unicode是编码,而GBK、UTD-8等是编码格式。Java中的编码问题(by 迷失之路):https://www.cnblogs.c…

UE4 ShooterGame Demo的开火的代码

之前一直没搞懂按下鼠标左键开火之后&#xff0c;代码的逻辑是怎么走的&#xff0c;今天看懂了之前没看懂的部分&#xff0c;进了一步 ShooterCharacter.cpp void AShooterCharacter::OnStartFire() {AShooterPlayerController* MyPC Cast<AShooterPlayerController>(Co…

kafka 异常:return ‘<SimpleProducer batch=%s>‘ % self.async ^ SyntaxError: invalid syntax

Python3.X 执行Python编写的生产者和消费者报错&#xff0c;报错信息如下&#xff1a; Traceback (most recent call last): File "mykit_kafka_producer.py", line 9, in <module> from kafka import KafkaProducer File "/usr/local/lib/python3.7/sit…

python 分布式计算框架_漫谈分布式计算框架

如果问 mapreduce 和 spark 什么关系&#xff0c;或者说有什么共同属性&#xff0c;你可能会回答他们都是大数据处理引擎。如果问 spark 与 tensorflow 呢&#xff0c;就可能有点迷糊&#xff0c;这俩关注的领域不太一样啊。但是再问 spark 与 MPI 呢&#xff1f;这个就更远了。…

Codeforces 899D Shovel Sale

题目大意 给定正整数 $n$&#xff08;$2\le n\le 10^9$&#xff09;。 考虑无序整数对 $(x, y)$&#xff08;$1\le x,y\le n, x\ne y$&#xff09;。 求满足 「$xy$ 结尾连续的 9 最多」的数对 $(x,y)$ 的个数。 例子&#xff1a; $n50$&#xff0c;$(49,50)$ 是一个满足条件的…

Windows系统使用minGW+msys 编译ffmpeg 0.5的全过程详述

一.环境配置 1.下载并安装 MinGW-5.1.4.exe (http://jaist.dl.sourceforge.net/sourcef … -5.1.4.exe)&#xff0c;安装时选中 g, mingw make。建议安装到c:/mingw. 2.下载并安装 MSYS-1.0.11-rc-1.exe (http://jaist.dl.sourceforge.net/sourcef … 1-rc-1.exe)&#xff0c;安…

Liunx安装gogs,mysql,jdk,tomcat等常用软件

Liunx CentOS系统采用yum安装Mysql 一.安装mysql客户端 yum -y install mysql 二.安装mysql服务器端 [注意:由于CentOS7下的不自带mysql-server,所以得先安装资源包,步骤: 1.wget http://repo.mysql.com/mysql-community-release-el7-5.noarch.rpm (采用wget获取必须有wge…

stm32单片机端口映射_STM32单片机的重映射与地址映射的使用方法及步骤

重映射STM32中对于一些端口的外设已经被其他引脚所使用&#xff0c;这是就需要用端口重映射来解决了&#xff0c;很方便。以USART1为例重映射的步骤为&#xff1a;打开重映射时钟和USART重映射后的I/O口引脚时钟&#xff0c;RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOB|RCC_A…

python 第三方模块 yaml - 处理 YAML (专门用来写配置文件的语言)

markdown 的配置使用 Yaml —— Yet Another Markup Language &#xff1a;另一种标记语言。 简介 YAML 是专门用来写配置文件的语言&#xff0c;非常简洁和强大&#xff0c;远比 JSON 格式方便。 YAML在python语言中有PyYAML安装包。 YAML 语言&#xff08;发音 /ˈjməl/ &…

程序员 赚钱

业余编程赚钱 程序员的好方法 现在的人生活水平高了&#xff0c;开销也大了&#xff0c;同时对于一些技术性人员来说有很多种&#xff0c;有些程序员自己开公司&#xff0c;开发自己的产品&#xff0c;年赚百万&#xff0c;有些程序员还在给别人打工&#xff0c;每天累死累活的…

java合并单元格的快捷键_java poi合并单元格问题

使用poi导出的execl合并单元格&#xff0c;会出现下图问题整个单元格看似合并了&#xff0c;但是文字没有垂直居中&#xff0c;而且execl中所有的合并都会在第三行开始出现灰色分层样式合并单元格伪代码String upCompareField ""; //上一行的对比值for(int i 0; i …

webpack自动化构建脚本指令npm run dev/build

指令 为不同环境配置可执行指令&#xff0c;我们使用npm scripts方式&#xff0c;在package.json文件中配置执行指令&#xff1a; {"scripts": {"start": "cross-env NODE_ENVdev webpack-dev-server","build": "cross-env NODE_…

前端之 form 详解

认识表单 在一个页面上可以有多个form表单&#xff0c;但是向web服务器提交表单的时候&#xff0c;一次只可以提交一个表单。要声明一个表单&#xff0c;只需要使用 form 标记来标明表单的开始和结束&#xff0c;若需要向服务器提交数据&#xff0c;则在form标签中需要设置act…

代码 优化 指南 实践

C代码优化方案 华中科技大学计算机学院 姓名&#xff1a; 王全明 QQ&#xff1a; 375288012 Email&#xff1a; quanming1119163.com 目录 目录 C代码优化方案 1、选择合适的算法和数据结构 2、使用尽量小的数据类型 3、减少运算的强度 &#xff08;1&…

.12-浅析webpack源码之NodeWatchFileSystem模块总览

剩下一个watch模块&#xff0c;这个模块比较深&#xff0c;先大概过一下整体涉及内容再分部讲解。 流程图如下&#xff1a; NodeWatchFileSystem const Watchpack require("watchpack");class NodeWatchFileSystem {constructor(inputFileSystem) {this.inputFileSy…

Python 第三方模块之 beautifulsoup(bs4)- 解析 HTML

简单来说&#xff0c;Beautiful Soup是python的一个库&#xff0c;最主要的功能是从网页抓取数据。官方解释如下&#xff1a;官网文档 Beautiful Soup提供一些简单的、python式的函数用来处理导航、搜索、修改分析树等功能。 它是一个工具箱&#xff0c;通过解析文档为用户提供…

modal vue 关闭_Vue弹出框的优雅实践

引言页面引用弹出框组件是经常碰见的需求,如果强行将弹出框组件放入到页面中,虽然功能上奏效但没有实现组件与页面间的解耦,非常不利于后期的维护和功能的扩展.下面举个例子来说明一下这种做法的弊端.click"openModal()">点击 :is_open"is_open" close…

Python 第三方模块之 lxml - 解析 HTML 和 XML 文件

lxml是python的一个解析库&#xff0c;支持HTML和XML的解析&#xff0c;支持XPath解析方式&#xff0c;而且解析效率非常高 XPath&#xff0c;全称XML Path Language&#xff0c;即XML路径语言&#xff0c;它是一门在XML文档中查找信息的语言&#xff0c;它最初是用来搜寻XML文…

(转)Linux下PS1、PS2、PS3、PS4使用详解

Linux下PS1、PS2、PS3、PS4使用详解 原文&#xff1a;http://www.linuxidc.com/Linux/2016-10/136597.htm 1、PS1——默认提示符 如下所示&#xff0c;可以通过修改Linux下的默认提示符&#xff0c;使其更加实用。在下面的例子中&#xff0c;默认的PS1的值是“\s-\v\$”,显示出…