循环神经网络变形之 (Long Short Term Memory,LSTM)

1、长短期记忆网络LSTM简介

在RNN 计算中,讲到对于传统RNN水平方向进行长时刻序列依赖时可能会出现梯度消失或者梯度爆炸的问题。LSTM 特别适合解决这种需要长时间依赖的问题。

LSTM(Long Short Term Memory,长短期记忆网络)是RNN的一种,大体结构一直,区别在于:

  • LSTM 的‘记忆cell’ 是被改造过的,水平方向减少梯度消失与梯度爆炸
  • 该记录的信息会一直传递,不该记录的信息会被截断掉,部分输出和输入被从网络中删除

RNN 在语音识别,语言建模,翻译,图片描述等问题的应用的成功,都是通过 LSTM 达到的。

在这里插入图片描述

2、LSTM工作原理

2.1、传统的RNN“细胞”结构

所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。


2.2、LSTM结构

如下图展示了LSTM的一个神经元内部的结构。单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。

图中使用的各种元素的图标:

  • 每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。合在一起的线表示向量的连接,比如一个十维向量和一个二十维向量合并后形成一个三十维向量;分开的线表示内容被复制,然后分发到不同的位置。
  • 粉色的圈代表 pointwise 的操作,诸如向量的加法,减法,乘法,除法,都是矩阵的。
  • 黄色的矩阵就是神经网络层。

2.3、细胞状态

LSTM关键:“细胞状态” 。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变很容易。

LSTM怎么控制“细胞状态”?

  • LSTM可以通过gates(“门”)结构来去除或者增加“细胞状态”的信息
  • 包含一个sigmoid神经网络层次和一个pointwist乘法操作
  • Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允 许任务变量通过”,1表示“运行所有变量通过”
  • LSTM中主要有三个“门”结构来控制“细胞状态”
  • C_{t-1} 和 C_t 矩阵的shape 是一样的

忘记门

决定从“细胞状态”中丢弃什么信息;比如在语言模型中,细胞状态可能包含了性别信息(“他”或者“她”),当我们看到新的代名词的时候,可以考虑忘记旧的数据

信息增加门

  • 决定放什么新信息到“细胞状态”中;
  • Sigmoid层决定什么值需要更新; 
  • Tanh层创建一个新的候选向量C_t
  • 主要是为了状态更新做准备

在这里插入图片描述

经过第一个和第二个“门”后,可以确定传递信息的删除和增加,即可以进行 “细胞状态”的更新

  • 更新C_(t-1)C_ t;
  • 将旧状态与f_t相乘,丢失掉确定不要的信息;
  • 加上新的候选值i_t*C_t得到最终更新后的“细胞状态”

输出门

  • 首先运行一个sigmoid网络层来确定细胞状态的那个部分将输出
  • 使用tanh处理细胞状态得到一个-1到1之间的值,再将它和sigmoid门的输出相乘,输出程序确定输出的部分。

  • 前向传播和反向传播可以参看前面的传播过程写下来,更新LSTM中的参数。具体的公式可以参看:RNN以及LSTM的介绍和公式梳理_Dark_Scope的博客-CSDN博客_lstm公式
  • 作者的论文:https://arxiv.org/pdf/1402.1128v1.pdf

2.4、 前向传播和后向传播

前向传播

现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态h^{(t)},C^{(t)},模型参数几乎是RNN的4倍,因为现在多了W_f,U_f,b_f,W_a,U_a,b_a,W_i,U_i,b_i,W_o,U_o,b_o这些参数。   

前向传播过程在每个序列索引位置的过程为:

1)更新遗忘门输出:

f^{(t)}=\sigma (W_fh^{(t-1)}+U_fx^{(t)}+b_f)

2)更新输入门两部分输出:

i^{(t)}=\sigma (W_ih^{(t-1)}+U_ix^{(t)}+b_i)

a^{(t)}=tanh(W_ah^{(t-1)}+U_ax^{(t)}+b_a)

3)更新细胞状态:

C^{(t)}=C^{(t-1)}\odot f^{(t)}+i^{(t)}\odot a^{(t)}

4)更新输出门输出:

o^{(t)}=\sigma (W_oh^{(t-1)}+U_ox^{(t)}+b_o)

h^{(t))}=o^{(t)}\odot tanh(C^{(t)})

5)更新当前序列索引预测输出:

\hat{y} ^{(t)}=\sigma (Vh^{(t)}+c)

整体的过程如下图所示

可以看到,在t tt时刻,C^{(t)}用于计算h^{(t)}C^{(t+1)}

反向传播

有了LSTM前向传播算法,推导反向传播算法就很容易了, 思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新我们所有的参数,关键点在于计算所有参数基于损失函数的偏导数。

在RNN中,为了反向传播误差,我们通过隐藏状态h^{(t)}的梯度\delta^{(t)}一步步向前传播。在LSTM这里也类似。只不过我们这里有两个隐藏状态h^{(t)}C^{(t)}。这里我们定义两个\delta,即:

\delta_h^{(t)} = \frac{\partial L}{\partial h^{(t)}}\qquad \text{(8)}

\delta_C^{(t)} = \frac{\partial L}{\partial C^{(t)}}\qquad \text{(9)}

反向传播时只使用了\delta_C^{(t)},变量\delta_h^{(t)}仅为帮助我们在某一层计算用,并没有参与反向传播,这里要注意。如下图所示:

因为,我们在输出层定义的损失函数为对数损失,激活函数为softmax激活函数。因为,与RNN的推导类似,在最后的序列索引位置 \tau\delta_h^{(\tau)}\delta_C^{(\tau)}为:

\delta_h^{(\tau)} =\frac{\partial L}{\partial O^{(\tau)}} \frac{\partial O^{(\tau)}}{\partial h^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})\qquad \text{(10)}
\delta_C^{(\tau)} =\frac{\partial L}{\partial h^{(\tau)}} \frac{\partial h^{(\tau)}}{\partial C^{(\tau)}} = \delta_h^{(\tau)} \odot o^{(\tau)} \odot (1 - tanh^2(C^{(\tau)}))\qquad \text{(11)}

接着我们由\delta_C^{(t+1)}反向推导\delta_C^{(t)}

\delta_h^{(t)}的梯度由本层的输出梯度误差决定,与公式(10)类似,即:

\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)})\qquad \text{(12)}

\delta_C^{(t)}​的反向梯度误差由前一层\delta_C^{(t+1)}​的梯度误差和本层的从h^{(t)}传回来的梯度误差两部分组成,即:

\delta_C^{(t)} =\frac{\partial L}{\partial C^{(t+1)}} \frac{\partial C^{(t+1)}}{\partial C^{(t)}} + \frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial C^{(t)}} = \delta_C^{(t+1)}\odot f^{(t+1)} + \delta_h^{(t)} \odot o^{(t)} \odot (1 - tanh^2(C^{(t)})) \qquad \text{(13)}

公式(13)的前半部分由公式(4)和公式(9)得到,公式(13)的后半部分由公式(6)和公式(8)得到。

有了\delta_h^{(t)}\delta_C^{(t)}, 计算这一大堆参数的梯度就很容易了,这里只给出W_f​的梯度计算过程,其他的U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o,V, c的梯度大家只要照搬就可以了。
\frac{\partial L}{\partial W_f} = \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial C^{(t)}} \frac{\partial C^{(t)}}{\partial f^{(t)}} \frac{\partial f^{(t)}}{\partial W_f} =\sum\limits_{t=1}^{\tau} \delta_C^{(t)} \odot C^{(t-1)} \odot f^{(t)}\odot(1-f^{(t)}) (h^{(t-1)})^T \qquad \text{(14)}

公式(13)的由公式(1)、公式(4)和公式(9)得到。

由上面可以得到,只要我们清晰地搞清楚前向传播过程,且只使用了\delta_C^{(t)}​进行反向传播的话,反向传播的整个过程是比较清晰的。

在这里有必要解释下为什么反向传播不使用\delta_h^{(t)}​,如果与循环神经网络(RNN)模型的前向反向传播算法里一样的话,那么\delta_h^{(t)}的计算方式就不应该是(12)式了

\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)})\qquad \text{(12)}

因为,h^{(t)}参与了L^{(t)}L^{(t+1)}的计算,所以在RNN文章里的求梯度方法,\frac{\partial L}{\partial h^{(t)}}​应该是
\delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} + \frac{\partial L}{\partial h^{(t+1)}}\frac{\partial h^{(t+1)}}{\partial h^{(t)}}\qquad \text{(12*)}

但是,这里是一个比较复杂的时序模型,如果使用RNN的思路,将h^{(t+1)}的部分也一起反向传播回来的话,这里的反向梯度根本无法得到闭式解。而只考虑一个的话,也可以做反向梯度优化,进度下降,但是优化起来容易的多,可以理解为这里做了一个近似。

3、另一种理解方式

图中方框我们称为记忆单元,其中实线箭头代表当前时刻的信息传递,虚线箭头表示上一时刻的信息传递。从结构图中我们看出,LSTM模型共增加了三个门: 输入门、遗忘门和输出门。进入block的箭头代表输入,而出去的箭头代表输出。

前向传播公式

fig2

上图中所有带h的权重矩阵均代表一种泛指,为LSTM的各种变种做准备,表示任意一条从上一时刻指向当前时刻的边,本文暂不考虑。与上篇公式类似,a代表汇集计算结果,b代表激活计算结果, Wil代表输入数据与输入门之间的权重矩阵, Wcl代表上一时刻Cell状态与输入门之间的权重矩阵, WiΦ代表输入数据与遗忘门之间的权重矩阵, WcΦ代表上一时刻Cell状态与遗忘门之间的权重矩阵, Wiω代表输入数据与输出门之间的权重矩阵, Wcω代表Cell状态与输出门之间的权重矩阵, Wic代表输入层原有的权重矩阵。 需要注意的是,图中Cell一栏描述的是从下方输入到中间Cell输出的整个传播过程。

反向传播

和朴素RNN的推导一样,有了前向传播公式,我们就能逐个写出LSTM网络中各个参数矩阵的梯度计算公式。首先,由于输出门不牵扯时间维度,我们可以直接写出输出门WiωWcω的迭代公式,如下图:

fig3

遗忘门的权重矩阵 WiΦ也可以直接给出,如下图:

fig4

而对于遗忘门的权重矩阵 WcΦ,由于是和上一时刻Cell状态做汇集计算,残差除了来自当前Cell,还来自下一时刻的Cell,因此需要写出下一时刻Cell传播至本时刻遗忘门的时间维度前向传播公式,如下图:

fig5

有了上面的公式,我们就能完整写出 WcΦ的梯度公式了。如下图所示(如果对这个时间维度前向公式不理解,可以参考上一篇我对朴素RNN的公式推导过程):

fig6

请注意,上图中L”和前面的L’不一样,这里只是为了式子简洁。

推完遗忘门公式,就可以此类推输入门与Cell的公式。其中输入门基本与遗忘门的推法一样,残差都是来自本时刻和下一时刻Cell。而Cell的残差则来自三个地方:输出层、输出门和下一时刻Cell。其中输出层和输出门残差可直接写出;而下一时刻Cell的残差,我们只要写出对应的时间维度前向传播公式便可写出。由于时间关系,这里就不详细推导遗忘门和Cell的梯度公式了,各位若有兴趣可自行继续推导。

相比于朴素RNN模型,LSTM模型更为复杂,且可调整和变化的地方也更多。比如:增加peephole将Cell状态连接到每个门,变体模型Gated Recurrent Unit (GRU),以及后面出现的Attention模型等。LSTM模型在语音识别、图像识别、手写识别、以及预测疾病、点击率和股票等众多领域中都发挥着惊人的效果,是目前最火的神经网络模型之一。敬请期待下节。

4、LSTM 的变体

我们到目前为止都还在介绍正常的 LSTM。但是不是所有的 LSTM 都长成一个样子的。实际上,几乎所有包含 LSTM 的论文都采用了微小的变体。差异非常小,目前为止有上百种,常用的也就几种。 其中一个流形的 LSTM 变体,就是由 Gers & Schmidhuber (2000) 提出的,

4.1、变种1

  • 增加了 “peephole connection”层。
  • 让门层也会接受细胞状态的输入

4.2、变种2

通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增 加什么信息,而是一起进行考虑。

4.3、Gated Recurrent Unit(GRU),2014年提出

  • 将忘记门和输入门合并成为一个单一的更新门
  • 同时合并了数据单元状态和隐藏状态
  • 结构比LSTM的结构更加简单

4.4、https://arxiv.org/pdf/1402.1128v1.pdf 论文

论文中定义的 LTSM cell 如下图所示:

lstm model

图示

  1. \bigotimes代表两个数据源乘上参数后相加。代表两个数据源相加。 
  2. \bigotimes外面再加花边的,代表两个数据源相乘后再取 sigmoid 。
  3.  圆圈里是gg的,代表取 tanh 。
  4. state下标-1代表这是上一次迭代时的结果。

所以像论文里指出的,这里实现的 LSTM Cell 含有更多参数,效果更好?

一般的 LSTM 就够用了,GRU 用的也比较多。

参考

  • LSTM Forward and Backward Pass
  • Understanding LSTM Networks
  • https://arxiv.org/pdf/1402.1128v1.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/453873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UE4 ShooterGame Demo的开火的代码

之前一直没搞懂按下鼠标左键开火之后&#xff0c;代码的逻辑是怎么走的&#xff0c;今天看懂了之前没看懂的部分&#xff0c;进了一步 ShooterCharacter.cpp void AShooterCharacter::OnStartFire() {AShooterPlayerController* MyPC Cast<AShooterPlayerController>(Co…

Windows系统使用minGW+msys 编译ffmpeg 0.5的全过程详述

一.环境配置 1.下载并安装 MinGW-5.1.4.exe (http://jaist.dl.sourceforge.net/sourcef … -5.1.4.exe)&#xff0c;安装时选中 g, mingw make。建议安装到c:/mingw. 2.下载并安装 MSYS-1.0.11-rc-1.exe (http://jaist.dl.sourceforge.net/sourcef … 1-rc-1.exe)&#xff0c;安…

程序员 赚钱

业余编程赚钱 程序员的好方法 现在的人生活水平高了&#xff0c;开销也大了&#xff0c;同时对于一些技术性人员来说有很多种&#xff0c;有些程序员自己开公司&#xff0c;开发自己的产品&#xff0c;年赚百万&#xff0c;有些程序员还在给别人打工&#xff0c;每天累死累活的…

代码 优化 指南 实践

C代码优化方案 华中科技大学计算机学院 姓名&#xff1a; 王全明 QQ&#xff1a; 375288012 Email&#xff1a; quanming1119163.com 目录 目录 C代码优化方案 1、选择合适的算法和数据结构 2、使用尽量小的数据类型 3、减少运算的强度 &#xff08;1&…

.12-浅析webpack源码之NodeWatchFileSystem模块总览

剩下一个watch模块&#xff0c;这个模块比较深&#xff0c;先大概过一下整体涉及内容再分部讲解。 流程图如下&#xff1a; NodeWatchFileSystem const Watchpack require("watchpack");class NodeWatchFileSystem {constructor(inputFileSystem) {this.inputFileSy…

Python 第三方模块之 beautifulsoup(bs4)- 解析 HTML

简单来说&#xff0c;Beautiful Soup是python的一个库&#xff0c;最主要的功能是从网页抓取数据。官方解释如下&#xff1a;官网文档 Beautiful Soup提供一些简单的、python式的函数用来处理导航、搜索、修改分析树等功能。 它是一个工具箱&#xff0c;通过解析文档为用户提供…

modal vue 关闭_Vue弹出框的优雅实践

引言页面引用弹出框组件是经常碰见的需求,如果强行将弹出框组件放入到页面中,虽然功能上奏效但没有实现组件与页面间的解耦,非常不利于后期的维护和功能的扩展.下面举个例子来说明一下这种做法的弊端.click"openModal()">点击 :is_open"is_open" close…

开放平台大抉择

开放平台大抉择之新浪SAE&#xff1a;为个人应用开发带来福音 导读&#xff1a;继上期淘宝网副总裁王文彬从平台功能特色、运营状况等多方面分享了淘宝开放平台的历程和挑战之后。国内另一家云平台服务方的典型代表——Sina App Engine(简称SAE)&#xff0c;作为新浪研发中心于…

HTTP POST 发送数据的参数 application/x-www-form-urlencoded、multipart/form-data、text/plain

HTTP 简介 HTTP/1.1 协议规定的 HTTP 请求方法有 OPTIONS、GET、HEAD、POST、PUT、DELETE、TRACE、CONNECT 这几种。 其中 POST 一般用来向服务端提交数据&#xff0c;本文主要讨论 POST 提交数据的几种方式。 我们知道&#xff0c;HTTP 协议是以 ASCII 码传输&#xff0c;建…

vue 二进制文件的下载(解决乱码和解压报错)

问题描述&#xff1a;项目中使用的是vue框架进行开发&#xff0c;因为文件下载存在权限问题&#xff0c;所以并不能通过 a 链接的 href 属性直接赋值 URL进行下载&#xff0c; &#xff08;如果你的文件没有下载权限&#xff0c;可以直接通过href属性赋值URL的方法进行文件下载…

浅谈云计算与数据中心计算

文/林仕鼎 云计算概念发端于Google和Amazon等超大规模的互联网公司&#xff0c;随着这些公司业务的成功&#xff0c;作为其支撑技术的云计算也得到了业界的高度认可和广泛传播。时至今日&#xff0c;云计算已被普遍认为是IT产业发展的新阶段&#xff0c;从而被赋予了很多产业和…

数据挖掘:如何寻找相关项

导读&#xff1a;随着大数据时代浪潮的到来数据科学家这一新兴职业也越来越受到人们的关注。本文作者Alexandru Nedelcu就将数学挖掘算法与大数据有机的结合起来&#xff0c;并无缝的应用在面临大数据浪潮的网站之中。 数据科学家需要具备专业领域知识并研究相应的算法以分析对…

00030_ArrayList集合

1、数组可以保存多个元素&#xff0c;但在某些情况下无法确定到底要保存多少个元素&#xff0c;此时数组将不再适用&#xff0c;因为数组的长度不可变 2、JDK中提供了一系列特殊的类&#xff0c;这些类可以存储任意类型的元素&#xff0c;并且长度可变&#xff0c;统称为集合 3…

1.3tf的varible\labelencoder

1.tf的varible变量 import tensorflow as tf #定义变量--这里是计数的变量 statetf.Variable(0,namecounter) print (state.name) #输出变量值 onetf.constant(1) #常量new_valuetf.add(state,one) updatetf.assign(state,new_value)#初始化所有变量 inittf.initialize_all_var…

多线程编程指南

1. 多线程编程指南1--线程基础 线程编程指南1--线程基础 Wednesday, 29. March 2006, 11:48:45 多线程 本文出自:BBS水木清华站 作者:Mccartney (coolcat) (2002-01-29 20:25:25) multithreading可以被翻译成多线程控制。与传统的UNIX不同&#xff0c;一个传统 的UNIX进…

路由器和猫的区别

路由器和猫的区别 网络在我们现在生活中必不可少,我们链接互联网经常需要用到猫和路由器,但是依然有很多菜鸟根本不知道什么是猫什么是路由器,至于猫和路由器怎么使用就更不知道了,下面给大家详细的讲解下路由器和猫的区别。 路由器和猫的用途和链接位置不一样,如下图: 路由器:…

kafka 命令行命令大全

kafka 脚本 connect-distributed.sh connect-mirror-maker.sh connect-standalone.sh kafka-acls.sh kafka-broker-api-versions.sh kafka-configs.sh kafka-console-consumer.sh kafka-console-producer.sh kafka-consumer-groups.sh kafka-consumer-perf-test.sh kafka-dele…

python 第三方模块之 APScheduler - 定时任务

介绍 APScheduler的全称是Advanced Python Scheduler。它是一个轻量级的 Python 定时任务调度框架。APScheduler 支持三种调度任务&#xff1a;固定时间间隔&#xff0c;固定时间点&#xff08;日期&#xff09;&#xff0c;Linux 下的 Crontab 命令。同时&#xff0c;它还支持…

hadoop分布式搭建

一&#xff0c;前提&#xff1a;下载好虚拟机和安装完毕Ubuntu系统。因为我们配置的是hadoop分布式&#xff0c;所以需要两台虚拟机&#xff0c;一台主机&#xff08;master&#xff09;&#xff0c;一台从机&#xff08;slave&#xff09; 选定一台机器作为 Master 在 Master …

xvid 详解 代码分析 编译等

1. Xvid参数详解 众所周知&#xff0c;Mencoder以其极高的压缩速率和不错的画质赢得了很多朋友的认同&#xff01; 原来用Mencoder压缩Xvid的AVI都是使用Xvid编码器的默认设置&#xff0c;现在我来给大家冲冲电&#xff0c;讲解一下怎样使用Mencoder命令行高级参数制作Xvid编…