【论文复现】LSTM长短记忆网络

LSTM

  • 前言
  • 网络架构
    • 总线
    • 遗忘门
    • 记忆门
    • 记忆细胞
    • 输出门
  • 模型定义
    • 单个LSTM神经元的定义
    • LSTM层内结构的定义
  • 模型训练
  • 模型评估
  • 代码细节
    • LSTM层单元的首尾的处理
    • 配置Tensorflow的GPU版本

前言

LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能,同时还经常用于做时间序列。由于LSTM的原版论文相关版权问题,这里以colah大佬的博客为基础进行讲解。之前写过一篇Tensorflow中的LSTM详解,但是原理部分跟代码部分的联系并不紧密,实践性较强但是如果想要进行更加深入的调试就会出现原理性上面的问题,因此特此作文解决这个问题,想要用LSTM这个有趣的模型做出更加好的机器学习效果😊。

网络架构

LSTM框架图
这张图展示了LSTM在整体结构,下面就开始分部分介绍中间这个东东。

总线

在这里插入图片描述
这条是总线,可以实现神经元结构的保存或者更改,如果就是像上图一样一条总线贯穿不做任何改变,那么就是不改变细胞状态。那么如果想要改变细胞状态怎么办?可以通过来实现,这里的门跟高中生物中学的神经兴奋阈值比较像,用数学来表示就是sigmoid函数或者其他的激活函数,当门的输入达到要求,门就会打开,允许当前门后面的信息“穿过”门改变主线上面传递的信息,如果把每一个神经元看成一个时间节点,那么从上一个时间节点传到下一个时间节点过程中的门的开启与关闭就实现了时间序列数据的信息传递。
在这里插入图片描述

遗忘门

在这里插入图片描述
首先是遗忘门,这个门的作用是决定从上一个神经元传输到当前神经元的数据丢弃的程度,如果经过sigmoid函数以后输出0表示全部丢弃,输出1表示全部保留,这个层的输入是旧的信息和当前的新信息。

σ \sigma σ:sigmoid函数
W f W_f Wf:权重向量
b f b_f bf:偏置项,决定丢弃上一个时间节点的程度,如果是正数,表示更容易遗忘,如果是负数,表示比较容易记忆
h t − 1 h_{t-1} ht1:上一个时刻的输入
x t x_t xt:当前层的输入

记忆门

在这里插入图片描述
接下来是记忆门,这个门决定要记住什么信息,同时决定按照什么程度记住上一个状态的信息。

i t i_t it:在时间步t时刻的输入门激活值,计算方法跟上面的遗忘门是一样的,只是目的不一样,这里是记忆
C ~ t \tilde{C}_{t} C~t:表示上一个时刻的信息和当前时刻的信息的集合,但是是规则化到[-1,1]这个范围内了的

记忆细胞

在这里插入图片描述
有了上面的要记忆的信息和要丢弃的信息,记忆细胞的功能就可以得到实现,用 f t f_t ft这个标量决定上一个状态要遗忘什么,用 i t i_t it这个标量决定上一个状态要记住什么以及当前状态的信息要记住什么。这样就形成了一个记忆闭环了。

输出门

在这里插入图片描述
最后,在有了记忆细胞以后不仅仅不要将当前细胞状态记住,还要将当前的信息向下一层继续传输,实现公式中的状态转移。

o t o_t ot:跟前面的门公式都一样,但是功能是决定输出的程度
h t h_t ht:将输出规范到[-1,1]的区间,这里有两个输出的原因是在构建LSTM网络的时候需要有纵向向上的那个 h t h_t ht,然而在当前层的LSTM的神经元之间还是首尾相接的😍。

模型定义

单个LSTM神经元的定义


# 定义单个LSTM单元
# 定义单个LSTM单元
class My_LSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTM, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size# 初始化门的权重和偏置,由于每一个神经元都有自己的偏置,所以在定义单元内部定义self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bf = nn.Parameter(torch.Tensor(hidden_size))self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bi = nn.Parameter(torch.Tensor(hidden_size))self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bo = nn.Parameter(torch.Tensor(hidden_size))self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))self.bg = nn.Parameter(torch.Tensor(hidden_size))# 初始化输出层的权重和偏置self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))self.b = nn.Parameter(torch.Tensor(output_size))# 用于计算每一种权重的函数def cal_weight(self, input, weight, bias):return F.linear(input, weight, bias)# x是输入的数据,数据的格式是(batch, seq_len, input_size),包含的是batch个序列,每个序列有seq_len个时间步,每个时间步有input_size个特征def forward(self, x):# 初始化隐藏层和细胞状态h = torch.zeros(1, 1, self.hidden_size).to(x.device)c = torch.zeros(1, 1, self.hidden_size).to(x.device)# 遍历每一个时间步for i in range(x.size(1)):input = x[:, i, :].view(1, 1, -1) # 取出每一个时间步的数据# 计算每一个门的权重f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遗忘门i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 输入门o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 输出门C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候选值# 更新细胞状态c = f * c + i * C_# 更新隐藏层h = o * torch.tanh(c) # 将输出标准化到-1到1之间output = self.cal_weight(h, self.W, self.b) # 计算输出return output

LSTM层内结构的定义

class My_LSTMNetwork(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(My_LSTMNetwork, self).__init__()self.hidden_size = hidden_sizeself.lstm = My_LSTM(input_size, hidden_size)  # 使用自定义的LSTM单元self.fc = nn.Linear(hidden_size, output_size)  # 定义全连接层def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)out, _ = self.lstm(x, (h0, c0))  # LSTM层的前向传播out = self.fc(out[:, -1, :])  # 全连接层的前向传播return out

模型训练

history = model.fit(trainX, trainY, batch_size=64, epochs=50, validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型评估

为了更加直观展示,这里用画图的方法进行结果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代码细节

LSTM层单元的首尾的处理

  • 首部:由于第一个节点不用接受来自上一个节点的输入,不需要有输入,当然也有一些是添加标识。

  • 尾部:由于已经进行到当前层的最后一个节点,因此输出只需要向下一层进行传递而不用向下一个节点传递,添加标识也是可以的。

配置Tensorflow的GPU版本

这一篇写的比较好,我自己的硬件环境如下图所示,需要的可以借鉴一下,当然也可以在我提供的代码链接直接用我给的environment.yml一键构建环境😃。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/16114.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Torch学习笔记】

作者:zjk 和 的区别是逐元素相乘,是矩阵相乘 cat stack 的区别 cat stack 是用于沿新维度将多个张量堆叠在一起的函数。它要求所有输入张量具有相同的形状,并在指定的新维度上进行堆叠。

【NumPy】关于numpy.mean()函数,看这一篇文章就够了

🧑 博主简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向…

Android11热点启动和关闭

Android官方关于Wi-Fi Hotspot (Soft AP) 的文章:https://source.android.com/docs/core/connect/wifi-softap?hlzh-cn 在 Android 11 的WifiManager类中有一套系统 API 可以控制热点的开和关,代码如下: 开启热点: // SoftApC…

国际版Tiktok抖音运营流量实战班:账号定位/作品发布/热门推送/等等-13节

课程目录 1-tiktok账号定位 1.mp4 2-tiktok作品发布技巧 1.mp4 3-tiktok数据功能如何开通 1.mp4 4-tiktok热门视频推送机制 1.mp4 5-如何发现热门视频 1.mp4 6-如何发现热门音乐 1.mp4 7-如何寻找热门标签 1.mp4 8-如何寻找垂直热门视频 1.mp4 9-如何发现热门挑战赛 1…

【Python特征工程系列】一文教你使用PCA进行特征分析与降维(案例+源码)

这是我的第287篇原创文章。 一、引言 主成分分析(Principal Component Analysis, PCA)是一种常用的降维技术,它通过线性变换将原始特征转换为一组线性不相关的新特征,称为主成分,以便更好地表达数据的方差。 在特征重要…

DAMA数据管理知识体系必背18张框图

近期对数据管理知识体系中比较重要的框图进行了梳理总结,总共有18张框图,供大家参考。主要涉及数据管理、数据治理阶段模式、数据安全需求、主数据管理关键步骤,主数据架构、DW架构、数据科学的7个阶段、数据仓库建设活动、信息收敛三角、大数据分析架构图、数据管理成熟度等…

QGIS开发笔记(二):Windows安装版二次开发环境搭建(上):安装OSGeo4W运行依赖其Qt的基础环境Demo

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/139136356 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

Charles抓包App_https_夜神模拟器

Openssl安装 下载安装 下载地址: http://slproweb.com/products/Win32OpenSSL.html 我已经下载好了64位的,也放出来: 链接:https://pan.baidu.com/s/1Nkur475YK48_Ayq_vEm99w?pwdf4d7 提取码:f4d7 --来自百度网…

地下城游戏(leetcode)

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 地下城游戏https://leetcode.cn/problems/dungeon-game/description/ 图解分析&#xff1a; 代码 class Solution { public:int calculateMinimumHP(vector<vector<int>>& vv) {int row vv.size(), col …

Zookeeper 安装教程和使用指南

一、Zookeeper介绍 ZooKeeper 是 Apache 软件基金会的一个开源项目&#xff0c;主要基于 Java 语言实现。 Apache ZooKeeper 是一个开源的分布式应用程序协调服务&#xff0c;提供可靠的数据管理通知、数据同步、命名服务、分布式配置服务、分布式协调等服务。 关键特性 分布…

Nginx实战(安装部署、常用命令、反向代理、负载均衡、动静分离)

文章目录 1. nginx安装部署1.1 windows安装包1.2 linux-源码编译1.3 linux-docker安装 2. nginx介绍2.1 简介2.2 常用命令2.3 nginx运行原理2.3.1 mater和worker2.3.3 Nginx 的工作原理 2.4 nginx的基本配置文件2.4.1 location指令说明 3. nginx案例3.1 nginx-反向代理案例013.…

数据结构和算法|排序算法系列(三)|插入排序(三路排序函数std::sort)

首先需要你对排序算法的评价维度和一个理想排序算法应该是什么样的有一个基本的认知&#xff1a; 《Hello算法之排序算法》 主要内容来自&#xff1a;Hello算法11.4 插入排序 插入排序的整个过程与手动整理一副牌非常相似。 我们在未排序区间选择一个基准元素&#xff0c;将…

移动云以深度融合之服务,令“大”智慧贯穿云端

移动云助力大模型&#xff0c;开拓创新领未来。 云计算——AI模型的推动器。 当前人工智能技术发展的现状和趋势&#xff0c;以及中国在人工智能领域的发展策略和成就。确实&#xff0c;以 ChatGPT 为代表的大型语言模型在自然语言处理、文本生成、对话系统等领域取得了显著的…

项目管理:敏捷实践框架

一、初识敏捷 什么是敏捷(Agile)?敏捷是思维方式。 传统开发模型 央企,国企50%-60%需求分析。整体是由文档控制的过程管理。 传统软件开发面临的问题: 交付周期长:3-6个月甚至更长沟通效果差:文档化沟通不及时按时发布低:技术债增多无法发版团队士气弱:死亡行军不关注…

Vmware 17安装 CentOS9

前言 1、提前下载好需要的CentOS9镜像&#xff0c;下载地址&#xff0c;这里下载的是x86_64 2、提前安装好vmware 17&#xff0c;下载地址 &#xff0c;需要登录才能下载 安装 1、创建新的虚拟机 2、在弹出的界面中选择对应的类型&#xff0c;我这里选择自定义&#xff0c;点…

python command乱码怎么解决

python command乱码怎么解决&#xff1f;具体方法如下&#xff1a; 先引入import sys 再加一句&#xff1a;typesys.getfilesystemencoding() 然后在输出乱码的数据的后面加上“.decode(utf-8).encode(type)”。 比如输入“ss”乱码。 就写成print ss.decode(utf-8).encode(typ…

【话题】AIGC行业现在适合进入吗

大家好&#xff0c;我是全栈小5&#xff0c;欢迎阅读小5的系列文章&#xff0c;这是《话题》系列文章 目录 引言AIGC的发展阶段市场需求时机是否合适优势挑战 文章推荐 引言 在撰写关于当前是否适合进入AIGC&#xff08;人工智能生成内容&#xff09;行业的文章之前&#xff0…

从零实现Llama3中文版

1.前言 一个月前&#xff0c;Meta 发布了开源大模型 llama3 系列&#xff0c;在多个关键基准测试中优于业界 SOTA 模型&#xff0c;并在代码生成任务上全面领先。 此后&#xff0c;开发者们便开始了本地部署和实现&#xff0c;比如 llama3 的中文实现、llama3 的纯 NumPy 实现…

数据结构——链式二叉树知识点以及链式二叉树数据操作函数详解!!

引言&#xff1a;该博客将会详细的讲解二叉树的三种遍历方法&#xff1a;前序、中序、后序&#xff0c;也同时会讲到关于二叉树的数据操作函数。值得一提的是&#xff0c;这些函数几乎都是建立在一个函数思想——递归之上的。这次的代码其实写起来十分简单&#xff0c;用不了几…

告别红色波浪线:tsconfig.json 配置详解

使用PC端的朋友&#xff0c;请将页面缩小到最小比例&#xff0c;阅读最佳&#xff01; tsconfig.json 文件用于配置 TypeScript 项目的编译选项。如果配不对&#xff0c;就会在项目中显示一波又一波的红色波浪线&#xff0c;警告你这些地方的类型声明存在问题。 一般我们遇到这…