【23-24 秋学期】NNDL 作业11 LSTM

目录

习题6-4 推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果

习题6-3P 编程实现下图LSTM运行过程

(一)numpy实现 

(二)使用nn.LSTMCell实现

(三) 使用nn.LSTM实现

总结

(一)推荐

 (二)关于LSTM的有关推导

(三)有关LSTM的代码


参考【【23-24 秋学期】NNDL 作业11 LSTM-CSDN博客

习题6-4 推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果

在我的上一篇博客【【23-24 秋学期】NNDL 作业10 BPTT-CSDN博客】中:

可以看到对于梯度爆炸/消失的推导,其中关键部分就在于递归梯度:\frac{\partial h_{t}}{\partial h_{t-1}}这一部分。

对于求解LSTM网络中的梯度消失也同理-----关键部分也就是内部状态递归梯度\frac{\partial C_{t}}{\partial C_{t-1}} 的变化。

先求解一下这个\frac{\partial C_{t}}{\partial C_{t-1}} 如下:

 然后假设损失为E,与上上图求解类似,如下:

参考【LSTM 如何避免梯度消失问题 - 知乎 (zhihu.com)】 这个作者的那个参考网址我点不开,看不了更原始的。

习题6-3P 编程实现下图LSTM运行过程

(一)numpy实现 

简单分析一下,这个LSTM的运行过程:主要包括三个门,加一个输入,中间层,以及输出层。

输入门的作用是:是否要保存输入结果

遗忘门的作用是:是否要把得到的输入结果存到延迟器中去,也就是隐藏层中

输出门的作用是是否把隐藏层的结果输出

需要强调的是,1)输入这一块也是有一个激活的,但是没使用,隐藏层向输出也是,所以我在实现的时候就没有写这一块激活。

2)对于中间隐藏层而言,他就是一个延迟器,所以当遗忘门为1时,它就会累加输入;为0时会清零。

3)对于激活函数sigmoid,它最后的结果类似一个概率,是处于【0,1】之间的一个数,但是由于这道题中最开始的要求,所以我在实现激活函数时,使用np.round(),把结果>0.5作为1输出,结果<0.5的作为0输出,这是特殊情况。

代码如下:

import numpy as np# x=[x1,x2,x3,bias]
x = [[1, 0, 0, 1], [3, 1, 0, 1], [2, 0, 0, 1], [4, 1, 0, 1], [2, 0, 0, 1], [1, 0, 1, 1], [3, -1, 0, 1], [6, 1, 0, 1], [1, 0, 1, 1]]
# input
input_w = [1, 0, 0, 0]
# 输入门
inputGate_w = [0, 100, 0, -10]
# 遗忘门
forgetGate_w = [0, 100, 0, 10]
# 输出门
outputGate_w = [0, 0, 100, -10]# 激活
def sigmoid(x):return np.round(1 / (1 + np.exp(-x)))# 延时器
hidden = []
y = []
temp = 0.0for input in x:hidden.append(temp)# 输入与与对应权值相乘再相加temp_input = np.sum(np.multiply(input, input_w))# 输入门【得 1 or 得 0】temp_inputGate = sigmoid(np.sum(np.multiply(input, inputGate_w)))# 遗忘门temp_forgetGate = sigmoid(np.sum(np.multiply(input, forgetGate_w)))# 延时器temp = temp_input * temp_inputGate + temp * temp_forgetGate# 输出门temp_outputGate = sigmoid(np.sum(np.multiply(input, outputGate_w)))# 输出temp_y = temp * temp_outputGatey.append(temp_y)print("延时器:", hidden)
print("输出y: ", y)

得到的结果为:

我就止步到此了。

参考了我们班学霸博客【指路:DL Homework 11-CSDN博客

 发现他对于输入和输出的激活函数有研究,在他博客里提到:

课程指路【 李宏毅手撕LSTM_哔哩哔哩_bilibili

我自己的理解就是:在LSTM的模型中,规定了输入输出状态的激活函数 是tanh函数,与我们实现的代码有一些区别。并且nn.LSTM的内部激活函数,没办法修改,所以仿照学霸的,写了一个检验版【也就是加上tanh版本的】

代码如下:

import numpy as np# x=[x1,x2,x3,bias]
x = [[1, 0, 0, 1], [3, 1, 0, 1], [2, 0, 0, 1], [4, 1, 0, 1], [2, 0, 0, 1], [1, 0, 1, 1], [3, -1, 0, 1], [6, 1, 0, 1], [1, 0, 1, 1]]
# input
input_w = [1, 0, 0, 0]
# 输入门
inputGate_w = [0, 100, 0, -10]
# 遗忘门
forgetGate_w = [0, 100, 0, 10]
# 输出门
outputGate_w = [0, 0, 100, -10]# 激活
def sigmoid(x):return 1 / (1 + np.exp(-x))# 延时器
hidden = []
y = []
temp = 0for input in x:hidden.append(temp)# 输入与与对应权值相乘再相加#加tanh激活temp_input = np.tanh(np.sum(np.multiply(input, input_w)))# 输入门【得 1 or 得 0】temp_inputGate = sigmoid(np.sum(np.multiply(input, inputGate_w)))# 遗忘门temp_forgetGate = sigmoid(np.sum(np.multiply(input, forgetGate_w)))# 延时器#加tanh激活temp = np.tanh(temp_input * temp_inputGate + temp * temp_forgetGate)# 输出门temp_outputGate = sigmoid(np.sum(np.multiply(input, outputGate_w)))# 输出temp_y = temp * temp_outputGatey.append(temp_y)rounded_hidden = [round(x) for x in hidden]
print("检验版延时器:", rounded_hidden)rounded_y = [round(x) for x in y]
print("检验版输出y: ", rounded_y)

其中按照下图:

将tanh函数加到输入以及输出之前。 

得到输出为:

 

(二)使用nn.LSTMCell实现

参考【PyTorch - torch.nn.LSTMCell (runebook.dev)】【汉语版,感觉比网页英转汉好用多了】

 可知我们实现的各个值为:

input_size=4

hidden_size=1

偏差bias=False

细胞状态cx=(1,hidden_size)#因为每次运算都是输入的一个批次

同理隐藏状态=(1,hidden_size)

import torch
import torch.nn as nn
#x 维度需要变换,因为LSTMcell接收的是(time_steps,batch_size,input_size)
x = torch.tensor([[1, 0, 0, 1],[3, 1, 0, 1],[2, 0, 0, 1], [4, 1, 0, 1],[2, 0, 0, 1],[1, 0, 1, 1], [3, -1, 0, 1],[6, 1, 0, 1],[1, 0, 1, 1]], dtype=torch.float)
x = x.unsqueeze(1)
#权重
input_w = [1, 0, 0, 0]# input
inputGate_w = [0, 100, 0, -10]# 输入门
forgetGate_w = [0, 100, 0, 10]# 遗忘门
outputGate_w = [0, 0, 100, -10]# 输出门
#输入形状每次一组,一组是【x1,x2,x3,bias】
input_size=4
hidden_size=1
#定义LSTM
cell=nn.LSTMCell(input_size=input_size,hidden_size=hidden_size,bias=False)
#输入隐藏权重,形状为 (4*hidden_size, input_size)
cell.weight_ih.data = torch.tensor([forgetGate_w, inputGate_w, input_w, outputGate_w], dtype=torch.float)
#隐藏权重,形状为 (4*hidden_size, hidden_size)
cell.weight_hh.data=torch.zeros([4*hidden_size,hidden_size])#hx 和 cx 的初始值都需要初始化为全零张量,表示没有历史信息。
hx=torch.zeros(1,hidden_size)#hx 表示隐藏状态
cx=torch.zeros(1,hidden_size)#cx 表示细胞状态
outputs=[]
for i in range(len(x)):#这里没有区分c0,n0和c1,h1:因为使用了递归,也就是这次的输出是下次的输入hx, cx = cell(x[i], (hx, cx))outputs.append(hx.detach().numpy()[0][0])
#约数
outputs_rounded = [round(x) for x in outputs]
print("使用nn.LSTMCell的输出为:",outputs_rounded)

得到输出为:

(三) 使用nn.LSTM实现

参考【PyTorch - torch.nn.LSTM (runebook.dev)

 与Cell不同的是隐藏状态和细胞状态,其中多加了一个有关序列长度的维度。

代码如下:

import torch
import torch.nn as nn#x 维度需要变换,因为LSTMcell接收的是(time_steps,batch_size,input_size)
x = torch.tensor([[1, 0, 0, 1],[3, 1, 0, 1],[2, 0, 0, 1], [4, 1, 0, 1],[2, 0, 0, 1],[1, 0, 1, 1], [3, -1, 0, 1],[6, 1, 0, 1],[1, 0, 1, 1]], dtype=torch.float)
x = x.unsqueeze(1)
#权重
input_w = [1, 0, 0, 0]# input
inputGate_w = [0, 100, 0, -10]# 输入门
forgetGate_w = [0, 100, 0, 10]# 遗忘门
outputGate_w = [0, 0, 100, -10]# 输出门
#输入形状每次一组,一组是【x1,x2,x3,bias】
input_size=4
hidden_size=1
#定义LSTM模型
lstm=nn.LSTM(input_size=input_size,hidden_size=hidden_size,bias=False)
#设置LSTM的权重矩阵
#输入隐藏权重,形状为 (4*hidden_size, input_size)
lstm.weight_ih_l0.data = torch.tensor([forgetGate_w, inputGate_w, input_w, outputGate_w], dtype=torch.float)
#隐藏权重,形状为 (4*hidden_size, hidden_size)
lstm.weight_hh_l0.data=torch.zeros([4*hidden_size,hidden_size])
# 初始化隐藏状态和记忆状态
hx = torch.zeros(1, 1, hidden_size)
cx = torch.zeros(1, 1, hidden_size)# 前向传播
outputs, (hx, cx) = lstm(x, (hx, cx))
#所有维度值为 1 的维度都删除
outputs = outputs.squeeze().tolist()#约数
outputs_rounded = [round(x) for x in outputs]
print("使用nn.LSTM计算的结果为:",outputs_rounded)

结果为:

总结

(一)推荐

首先是给大家推荐一下课程,有老师上课讲的下边这种动图:

然后我在推导有关Ct的递归梯度时,也是参考了这个视频,博主【苏坡爱豆的笑容】讲的很清楚!!!

 【清晰图解LSTM、BPTT、RNN的梯度消失问题】

 还有上个博客听的那个视频【也有关于LSTM推导的内容】:【循环神经网络讲解|随时间反向传播推导(BPTT)|RNN梯度爆炸和梯度消失的原因|LSTM及GRU(解决RNN中的梯度爆炸和梯度消失)-跟李沐老师动手学深度学习】

都非常nice!

还有一篇知乎的文章解释有关避免梯度消失的也特别好!地址如下:LSTM 如何避免梯度消失问题 - 知乎 (zhihu.com)

然后在看了博客:DL Homework 11-CSDN博客 后,感觉可以推荐一下关于pytorch的汉化版【我读英语头昏脑胀,适合不喜欢学英语的同学,指路:PyTorch 1.8 简体中文 (runebook.dev)】

 (二)关于LSTM的有关推导

这个与上一个作业关联挺大的,很类似,都是在递归推导。可以发现,这个模型中减少梯度消失现象的很重要的一个点就是对于门控单元【遗忘门】的把控。可以及时的缓解梯度消失。

(三)有关LSTM的代码

第一个numpy能写出来,但是关于输出输出位置的激活函数没有多想,参考了学霸的代码后发现他研究了关于这一部分,学到了很多---也就是有激活tanh函数只是写的例子里边把激活函数换成了一个输入几输出几的这样一个激活而已【我自己觉得跟去掉了激活没什么区别】

在pytorch默认的模型里是tanh激活函数。

自己在网上搜关于使用LSTM相关模型时写不出来,学会使用现成的工具是一种能力,然后这个LSTM模型好像没有手写过,我还是不太会内部结构【下几个实验有这个内容,我好好+认真写】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/228909.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

dysmsapi

dysmsapi DY - SMS - API 短信服务接口 短信服务_SDK中心-阿里云OpenAPI开发者门户 <!-- 阿里dayu sms api短信群发接口 --><!-- https://mvnrepository.com/artifact/com.aliyun/dysmsapi20170525/2.0.24 --><dependency><groupId>com.aliyun&l…

Python学习笔记第七十四天(OpenCV安装)

Python学习笔记第七十四天 OpenCV安装安装Python安装OpenCV简单使用OpenCV 后记 OpenCV安装 在Windows系统下&#xff0c;安装Python和OpenCV可以按照以下步骤进行&#xff1a; 安装Python 下载Python&#xff1a;在Python官网下载最新的Python安装包&#xff0c;建议选择与…

JS匿名函数之函数表达式与立即执行函数

匿名函数是什么&#xff1f;和具名函数有什么区别&#xff1f;让我为大家介绍一下吧&#xff01; 没有名字的函数&#xff0c;无法直接使用 一.函数表达式 将匿名函数赋值给一个变量&#xff0c;并且通过变量名去调用&#xff0c;我们将这个称为函数表达式 语法&#xff1a; …

Java EE 网络之网络初识

文章目录 1. 网络发展史1.1 独立模式1.2 网络互连1.3 局域网 LAN1.4 广域网 WAN 2. 网络通信基础2.1 IP 地址2.2 端口号2.3 认识协议2.4 五元组2.5 协议分层2.5.1 什么是协议分层2.5.2 分层的作用2.5.3 OSI七层协议2.5.4 TCP/IP五层协议2.5.5 网络设备所在分层 2.6 分装和分用 …

Leetcode的AC指南 —— 链表:24. 两两交换链表中的节点

摘要&#xff1a; Leetcode的AC指南 —— 链表&#xff1a;24. 两两交换链表中的节点。题目介绍&#xff1a;给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能…

【OpenHarmony 北向应用开发】ArkTS语言入门(构建应用页面)

ArkTS语言入门 在学习ArkTS语言之前&#xff0c;我们首先需要一个能够编译并运行该语言的工具 DevEco Studio。 了解ArkTS ArkTS是OpenHarmony优选的主力应用开发语言。ArkTS围绕应用开发在TypeScript&#xff08;简称TS&#xff09;生态基础上做了进一步扩展&#xff0c;继…

有两个循环单链表,链表头指针分别为 h1 和 h2,编写一个函数将 h2 链接到 链表h1 之后,要求处理完仍是一个循环单链表。

题目描述 &#xff1a;有两个循环单链表&#xff0c;链表头指针分别为 h1 和 h2&#xff0c;编写一个函数将 h2 链接到 链表h1 之后&#xff0c;要求处理完仍是一个循环单链表。 分析&#xff1a; 注意题目说的是头指针 h1 和 h2&#xff0c;所以这两个循环单链表并没有头结点…

【MyBatis-Plus】常用的插件介绍(乐观锁、逻辑删除、分页)

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于MyBatis-Plus的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.为什么要使用MyBatis-Plus中的插…

c++程序设计定义一个CForm窗体类,在该类中包括:

定义一个CForm窗体类&#xff0c;在该类中包括&#xff1a; &#xff08;1&#xff09;成员变量&#xff1a;title&#xff08;窗体标题&#xff09;&#xff0c;width&#xff08;窗体宽度&#xff09;&#xff0c;height&#xff08;窗体高度&#xff09;&#xff1b; &…

MySQL中union和union all的区别

一、区别1&#xff1a;取结果的并集 1、union: 对两个结果集进行并集操作, 不包括重复行,相当于distinct, 同时进行默认规则的排序; 2、union all: 对两个结果集进行并集操作, 包括重复行, 即所有的结果全部显示, 不管是不是重复; 二、区别2&#xff1a;获取结果后的操作 1…

【C语言】随机数生成详解,手把手教你,保姆级!!!

目录 rand函数 srand函数 time函数 设置随机数范围 拓展--猜数字游戏 总结 rand函数 C语⾔提供了⼀个函数叫 rand&#xff0c;这函数是可以⽣成随机数的&#xff0c;函数原型如下所⽰ int rand (void); rand函数会返回⼀个伪随机数&#xff0c;这个随机数的范围是在0~RAN…

M个苹果放入N个盘子(递归)

题目&#xff1a; 把M个同样的苹果放在N个同样的盘子里&#xff0c;允许有的盘子空着不放&#xff0c;问共有多少种不同的分法&#xff1f;&#xff08;5&#xff0c;1&#xff0c;1和1&#xff0c;5&#xff0c;1 是同一种分法&#xff09; 输入 每个用例包含二个整数M和N。…

【golang】go执行shell命令行的方法( exec.Command )

所需包: import "os/exec" cmd的用法: cmd : exec.Command("ls", "-lah") //ls是命令,后面是参数 e : cmd.Run() 多个参数的要分开传入: 如:ip link show bond0 cmd :exec.Command("ip","link","show","…

Linux(操作系统)面经——part 1(持续更新中......)

1、说一说常用的 Linux 命令 mkdir创建文件夹&#xff0c;touch创建文件&#xff0c;mv移动文件内容或改名 rm-r 文件名&#xff1a;删除文件 cp拷贝&#xff1a;cp 文件1 文件2&#xff0c;cp-r跨目录拷贝 cp-r 路径1 路径2 vi 插入 &#xff1a;wqb保存退出 :q!强制退出…

【Axure教程】区间评分条

区间评分条是一种图形化的表示工具&#xff0c;用于展示某一范围内的数值或分数&#xff0c;并将其划分成不同的区间。这种评分条通常用于直观地显示数据的分布或某个指标的表现。常用于产品评价、调查和反馈、学术评价、健康评估、绩效评估、满意度调查等场景。 所以今天作者…

DOM树和DOM对象与JS关系的深入研究

const和let使用说明 var不好用&#xff0c;我们如果用变量都是用let&#xff0c;如果用常量乃是不变的量&#xff0c;我们用const&#xff0c;见let const知变量是否可变。比如一个常量在整个程序不会变&#xff0c;但是你用let&#xff0c;是可以的。但是let最好与内部变量改…

SSH连接服务器后执行多条命令

SSH连接服务器后执行多条命令 大家平时有没有遇到自己连接云服务器&#xff0c;ssh 连接上去之后&#xff0c;发现自己的一些小工具用不了 例如go build无法使用 &#xff0c;由于我们安装配置golang 环境的时候&#xff0c;是在文件/etc/profile中写了配置&#xff0c;因此需…

《洛谷深入浅出进阶篇》简单数据结构

本篇文章内容如下&#xff0c;请耐心观看&#xff0c;将持续更新。 简单数组 简单栈 简单队列 简单链表 简单二叉树 简单集合 图的基本概念 二叉堆 线段树 树状数组与字典树 线段树进阶 简单数组&#xff1a; STL可变数组 vector " 我们首先要知道这个容器有什…

Java多线程编程学习

1 线程的概念 多线程是指同一个程序同时存在多个“执行体”&#xff0c;它们可以同时工作 1.1 进程的概念 一次程序的每一次运行都叫做进程&#xff08;一个进程可以包含多个线程 1.2 线程的概念 多线程是指一个程序中多段代码同时并发进行 1.3 主线程的概念 JavaMain中的线程就…

Python语言学习笔记之十(字符串处理)

本课程对于有其它语言基础的开发人员可以参考和学习&#xff0c;同时也是记录下来&#xff0c;为个人学习使用&#xff0c;文档中有此不当之处&#xff0c;请谅解。 字符串处理&#xff1a;以实现字符串的分割、替换、格式化、大小写转换&#xff0c;Python字符串处理是指对Py…