[PyTorch][chapter 46][LSTM -1]

前言:

           长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的。

目录:

  1.      背景简介
  2.      LSTM Cell
  3.      LSTM 反向传播算法
  4.      为什么能解决梯度消失
  5.       LSTM 模型的搭建


一  背景简介:

       1.1  RNN

         RNN 忽略o_t,L_t,y_t 模型可以简化成如下

      

       

          图中Rnn Cell 可以很清晰看出在隐藏状态h_t=f(x_t,h_{t-1})

            得到 h_t后:

              一方面用于当前层的模型损失计算,另一方面用于计算下一层的h_{t+1}

    由于RNN梯度消失的问题,后来通过LSTM 解决 

       1.2 LSTM 结构

        


二  LSTM  Cell

   LSTMCell(RNNCell) 结构

          

          前向传播算法 Forward

         2.1   更新: forget gate 忘记门

             f_t=\sigma(W_fh_{t-1}+U_{t}x_t+b_f)

             将值朝0 减少, 激活函数一般用sigmoid

             输出值[0,1]

         2.2 更新: Input gate 输入门

                i_t=\sigma(W_ih_{t-1}+U_ix_t+b_i)

                决定是不是忽略输入值

    

           2.3 更新: 候选记忆单元

                    a_t=\widetilde{c_t}=tanh(W_a h_{t-1}+U_ax_t+b_a)

           2.4 更新: 记忆单元

               c_t=f_t \odot c_{t-1}+i_t \odot a_t

             2.5  更新: 输出门

                决定是否使用隐藏值

                 o_t=\sigma(W_oh_{t-1}+U_ox_t+b_0)  

           2.6. 隐藏状态

                h_t=o_t \odot tanh(c_t)

           2.7  模型输出

                  \hat{y_t}=\sigma(Vh_t+b)

LSTM 门设计的解释一:

 输入门 ,遗忘门,输出门 不同取值组合的时候,记忆单元的输出情况


三  LSTM 反向传播推导

      3.1 定义两个\delta_t

             \delta_h^t=\frac{\partial L}{\partial h_t}

            \delta_c^t=\frac{\partial L}{\partial C_t}

    3.2  定义损失函数

            损失函数L(t)分为两部分: 

             时刻t的损失函数 l(t)

             时刻t后的损失函数L(t+1)

              L(t)=\left\{\begin{matrix} l(t)+L(t+1), if: t<T\\ l(t), if: t=T \end{matrix}\right.

      3.3 最后一个时刻\tau

              

 这里面要注意这里的o^{\tau}= Vh_{\tau}+c

    证明一下第二项,主要应用到微分的两个性质,以及微分和迹的关系:

   

   dl= tr((\frac{\partial L^{\tau}}{\partial h^{\tau}})^Tdh^{\tau})  ... 公式1: 微分和迹的关系

       =tr((\delta_h^{\tau})^Tdh^{\tau})

     因为

    h^{\tau}=o^{\tau} \odot tanh(c^{\tau})

   dh_T=o^{\tau}\odot(d(tanh (c^{\tau})))

           =o^{\tau} \odot (1-tanh^2(c^{\tau})) \odot dc^{\tau}

     带入上面公式1:

      dl= tr((\delta_h^{\tau})^T (o^{\tau}\odot(1-tanh^2(c^{\tau}))\odot dc^{\tau})

           =tr((\delta_h^{\tau} \odot o^{\tau} \odot(1-tanh^2(c^{\tau}))^Tdc^{\tau})

    所以

3.4   链式求导过程

       求导结果:

 

  这里详解一下推导过程:

  这是一个符合函数求导:先把h 写成向量形成

h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

 ------------------------------------------------------------   

 第一项: 

             

         h_{t+1}=o_{t+1}\odot tanh(c_{t+1})

         o_{t+1}=\sigma(W_oh_t+U_ox_{t+1}+b_0)

        设 a_{t+1}=W_oh_t+U_ox_{t+1}+b_0

           则    \frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial o_{t+1}}\frac{\partial o_{t+1}}{\partial a_{t+1}}\frac{\partial a_{t+1}}{\partial h_{t}}

 

            其中:(利用矩阵求导的定义法 分子布局原理)

                    \frac{\partial h_{t+1}}{\partial o_{t+1}}=diag(tanh(c^{t+1})) 是一个对角矩阵

                  o=\begin{bmatrix} \sigma(a_1)\\ \sigma(a_2) \\ .... \\ \sigma(a_n) \end{bmatrix}

                 \frac{\partial o_{t+1}}{\partial a_{t+1}}=diag(o_{t+1}\odot(1-o_{t+1}))

                 \frac{\partial a_{t+1}}{\partial h_{t}}=W_o

                 几个连乘起来就是第一项

               

第二项

    c_{t+1}=f_{t+1}\odot c_t+i_{t+1}\odot a_{t+1}

   f_{t+1}=\sigma(W_fh_t+U_tx_{t+1}+b_f)

   i_{t+1}=\sigma(W_ih_t+U_i x_{t+1}+b_i)

  a_{t+1}=tanh(W_a h_t +U_ax_t +b_a)

参考:

   h=\begin{bmatrix} o_1*tanh(c_1)\\ o_2*tanh(c_2) \\ .... \\ o_n*tanh(c_n) \end{bmatrix}

其中:

\frac{\partial h_{t+1}}{\partial c^{t+1}}=diag(o^{t+1}\odot (1-tanh^2(c^{t+1}))

\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{\partial h_{t+1}}{\partial c_{t+1}}\frac{\partial c_{t+1}}{\partial f_{t+1}}\frac{\partial f_{t+1}}{\partial h_{t}}

 \frac{\partial c_{t+1}}{\partial f_{t+1}}=diag(c^{t})

 \frac{\partial a_{t+1}}{\partial h_{t}}=diag(f_t \odot(1-f_t))W_f

其它也是相似,就有了上面的求导结果


四  为什么能解决梯度消失

    

     4.1 RNN 梯度消失的原理

                ,复旦大学邱锡鹏书里面 有更加详细的解释,通过极大假设:

在梯度计算中存在梯度的k 次方连乘 ,导致 梯度消失原理。

    4.2  LSTM 解决梯度消失 解释1:

            通过上面公式发现梯度计算中是加法运算,不存在连乘计算,

            极大概率降低了梯度消失的现象。

    4.3  LSTM 解决梯度 消失解释2:

              记忆单元c  作用相当于ResNet的残差部分.  

   比如f_{t}=1,\hat{c_t}=0 时候,\frac{\partial c_t}{\partial c_{t-1}}=1,不会存在梯度消失。

       


五 模型的搭建

   

    我们最后发现:

    O_t,C_t,H_t 的维度必须一致,都是hidden_size

    通过C_t,则 I_t,F_t,\tilde{c} 最后一个维度也必须是hidden_size

    

# -*- coding: utf-8 -*-
"""
Created on Thu Aug  3 15:11:19 2023@author: chengxf2
"""# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 15:34:25 2023@author: chengxf2
"""import torch
from torch import nn
from d21 import torch as d21def normal(shape,devices):data = torch.randn(size= shape, device=devices)*0.01return datadef get_lstm_params(input_size, hidden_size,categorize_size,devices):#隐藏门参数W_xf= normal((input_size, hidden_size), devices)W_hf = normal((hidden_size, hidden_size),devices)b_f = torch.zeros(hidden_size,devices)#输入门参数W_xi= normal((input_size, hidden_size), devices)W_hi = normal((hidden_size, hidden_size),devices)b_i = torch.zeros(hidden_size,devices)#输出门参数W_xo= normal((input_size, hidden_size), devices)W_ho = normal((hidden_size, hidden_size),devices)b_o = torch.zeros(hidden_size,devices)#临时记忆单元W_xc= normal((input_size, hidden_size), devices)W_hc = normal((hidden_size, hidden_size),devices)b_c = torch.zeros(hidden_size,devices)#最终分类结果参数W_hq = normal((hidden_size, categorize_size), devices)b_q = torch.zeros(categorize_size,devices)params =[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q]for param in params:param.requires_grad_(True)return paramsdef init_lstm_state(batch_size, hidden_size, devices):cell_init = torch.zeros((batch_size, hidden_size),device=devices)hidden_init = torch.zeros((batch_size, hidden_size),device=devices)return (cell_init, hidden_init)def lstm(inputs, state, params):[W_xf,W_hf,b_f,W_xi,W_hi,b_i,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q] = params    (H,C) = stateoutputs= []for x in inputs:#input gateI = torch.sigmoid((x@W_xi)+(H@W_hi)+b_i)F = torch.sigmoid((x@W_xf)+(H@W_hf)+b_f)O = torch.sigmoid((x@W_xo)+(H@W_ho)+b_o)C_tmp = torch.tanh((x@W_xc)+(H@W_hc)+b_c)C = F*C+I*C_tmpH = O*torch.tanh(C)Y = (H@W_hq)+b_qoutputs.append(Y)return torch.cat(outputs, dim=0),(H,C)def main():batch_size,num_steps =32, 35train_iter, cocab= d21.load_data_time_machine(batch_size, num_steps)if __name__ == "__main__":main()


 参考

 

CSDN

https://www.cnblogs.com/pinard/p/6519110.html

57 长短期记忆网络(LSTM)【动手学深度学习v2】_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/24774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VR全景在建筑工程行业能起到哪些作用?

在建筑工程领域&#xff0c;数字化技术为行业的发展起到巨大的推动作用&#xff0c;虽然建筑施工行业主要是依赖于工人劳动力和施工设备&#xff0c;但是VR全景在该行业中方方面面都能应用&#xff0c;从设计建模到项目交付&#xff0c;帮助建筑师以及项目方更好的理解每个环节…

数字电路的重要概念——静态功耗和动态功耗

静态功耗和动态功耗&#xff1a; CMOS电路功耗是由静态功耗和动态功耗组成的&#xff0c;动态功耗远大于静态功耗 1&#xff1a;静态功耗&#xff1a; 我们从一个简单的反相器角度来理解和说明静态功耗的概念&#xff0c;众所周知&#xff0c;反相器是由PMOS和NMOS互补组成的…

【ES】笔记-let 声明及其特性

let 声明及其特性 声明变量 变量赋值、也可以批量赋值 let a;let b,c,d;let e100;let f521,giloveyou,h[];变量不能重复声明 let star罗志祥;let star小猪;块级作用域&#xff0c;let声明的变量只在块级作用域内有效 {let girl周杨青;}console.log(girl)注意&#xff1a;在 i…

Redis可视化工具

Redis可视化工具 1、RedisInsight 下载地址&#xff1a;https://redis.com/redis-enterprise/redis-insight/ 双击软件进行安装&#xff0c;安装完后弹出如下界面&#xff1a; 安装完成后在主界面选择添加Redis数据库&#xff1b; 选择手动添加数据库&#xff0c;输入Redis…

【统计学精要】:使用 Python 实现的统计检验— 1/10

一、介绍 欢迎来到“掌握 Python 统计测试&#xff1a;综合指南”&#xff0c;它将介绍本手册中您需要熟悉使用 Python 的所有基本统计测试和分析方法。本文将为您提供统计测试及其应用的全面介绍&#xff0c;无论您是新手还是经验丰富的数据科学家。 使用来自现实世界的实际示…

HarmonyOS 开发基础(五)对用户名做点啥

一、实现用户名检验 条件渲染 、生命周期 1.规定用户名长度 2.限定使用的数字及字母&#xff08;涉及正则表达&#xff09; // 导出方式直接从文件夹 import MyInput from "../common/commons/myInput" Entry Component /* 组件可以基于struct实现&#xff0c;组件…

驱动开发(中断)

头文件&#xff1a; #ifndef __LED_H__ #define __LED_H__#define PHY_LED1_MODER 0X50006000 #define PHY_LED1_ODR 0X50006014 #define PHY_LED1_RCC 0X50000A28#define PHY_LED2_MODER 0X50007000 #define PHY_LED2_ODR 0X50007014 #define PHY_LED2_RCC 0X50000A28#def…

在word的文本框内使用Endnote引用文献,如何保证引文编号按照上下文排序

问题 如下图所示&#xff0c;我在word中插入了一个文本框&#xff08;为了插图&#xff09;&#xff0c;然后文本框内有引用&#xff0c;结果endnote自动将文本框内的引用优先排序&#xff0c;变成文献[1]了&#xff0c;而事实上应该是[31]。请问如何能让文本框内的排序也自动…

maven install命令:将包安装在本地仓库,供本地的其它工程或者模块依赖

说明 有时候&#xff0c;自己本地的maven工程依赖于本地的其它工程&#xff0c;或者manven工程中的一个模块依赖于另外的模块&#xff0c;可以执行maven的install命令&#xff0c;将被依赖的包安装在maven本地仓库。 示例 一个工程包含几个模块&#xff0c;模块之间存在依赖…

第一个maven项目(IDEA生成)

第一个maven项目&#xff08;IDEA生成&#xff09; 步骤1 配置Project SDK 步骤2 配置maven File->Settings搜索maven

风辞远的科技茶屋:来自未来的信号枪

很久之前&#xff0c;有位朋友问我&#xff0c;现在科技资讯这么发达了&#xff0c;你们还写啊写做什么呢&#xff1f; 我是这么看的。最终能够凝结为资讯的那个新闻点&#xff0c;其实是一系列事情最终得出的结果&#xff0c;而这个结果又会带来更多新的结果。其中这些“得出”…

kagNet:对常识推理的知识感知图网络 2023 AAAI 8.4+8.5

这里写目录标题 摘要介绍概述问题陈述推理流程 模式图基础概念识别模式图构造概念网通过寻找路径来匹配子图基于KG嵌入的路径修剪 知识感知图网络图卷积网络&#xff08;GCN&#xff09;关系路径编码分层注意机制 实验数据集和实验步骤比较方法KAGNET是实施细节性能比较和分析与…

python GUI nicegui初识一(登录界面创建)

最近尝试了python的nicegui库&#xff0c;虽然可能也有一些不足&#xff0c;但个人感觉对于想要开发不过对ui设计感到很麻烦的人来说是很友好的了&#xff0c;毕竟nicegui可以利用TailwindCSS和Quasar进行ui开发&#xff0c;并且也支持定制自己的css样式。 这里记录一下自己利…

【Spring框架】Spring事务

目录 Spring中事务的实现编程式事务声明式事务Transactional 作⽤范围Transactional 参数说明注意事项Transactional ⼯作原理 MySQL 事务隔离级别Spring 事务隔离级别事务传播机制 Spring中事务的实现 Spring中事务操作分为两类&#xff1a; 1.编程式事务 2.声明式事务 编程…

Abaqus 中最常用的子程序有哪些 硕迪科技

在ABAQUS中&#xff0c;用户定义的子程序是一种重要的构件&#xff0c;可以将其插入到Abaqus分析中以增强该软件的功能和灵活性。这些子程序允许用户在分析过程中添加自定义材料模型、边界条件、初始化、加载等特定操作&#xff0c;以便更精准地模拟分析中的现象和现象。ABAQUS…

小白电脑装机(自用)

几个月前买了配件想自己装电脑&#xff0c;结果最后无法成功点亮&#xff0c;出现的问题是主板上的DebugLED黄灯常亮&#xff0c;即DRAM灯亮。对于微星主板的Debug灯&#xff0c;其含义这篇博文中有说明。 根据另一篇博文&#xff0c;有两种可能。 我这边曾将内存条和主板一块…

mongodb-win32-x86_64-2008plus-ssl-3.6.23-signed.msi

Microsoft Windows [版本 6.1.7601] 版权所有 (c) 2009 Microsoft Corporation。保留所有权利。C:\Users\Administrator>cd C:\MongoDB\Server\3.6\binC:\MongoDB\Server\3.6\bin> C:\MongoDB\Server\3.6\bin> C:\MongoDB\Server\3.6\bin>mongod --dbpath C:\Mongo…

c语言每日一练(2)

前言&#xff1a; 每日一练系列&#xff0c;每一期都包含5道选择题&#xff0c;2道编程题&#xff0c;博主会尽可能详细地进行讲解&#xff0c;令初学者也能听的清晰。每日一练系列会持续更新&#xff0c;暑假时三天之内必有一更&#xff0c;到了开学之后&#xff0c;将看学业情…

Nginx启动报错- Failed to start The nginx HTTP and reverse proxy server

根据日志&#xff0c;仍然出现 “bind() to 0.0.0.0:8888 failed (13: Permission denied)” 错误。这意味着 Nginx 仍然无法绑定到 8888 端口&#xff0c;即使使用 root 权限。 请执行以下操作来进一步排查问题&#xff1a; 确保没有其他进程占用 8888 端口&#xff1a;使用以…

python#django数据库一对一/一对多/多对多

一对一OneToOneField 用户和用户信息 搭建 # 一对一 class TestUser(models.Model): usernamemodels.CharField(max_length32) password models.CharField(max_length32) class TestInfo(models.Model): mick_namemodels.CharField(max_length32) usermode…