LSTM长短时记忆网络:推导与实现(pytorch)

LSTM长短时记忆网络:推导与实现(pytorch)

  • 背景
  • 推导
    • 遗忘门
    • 输入门
    • 输出门
  • LSTM的改进:GRU
  • 实现

背景

  人类不会每秒钟都从头开始思考。当你阅读这篇文章时,你会根据你对以前单词的理解来理解每个单词。你不会把所有东西都扔掉,重新开始思考。你的思想是持久的。

  传统的神经网络无法做到这一点,这是一个主要缺点。递归神经网络(Rnn)解决了这个问题,这是一种带有记忆的模型,可以将其认为是同一网络的多个副本,每个副本将消息传递给继任者。
在这里插入图片描述
  RNN的优势是,它们可能能够将先前的信息与当前任务联系起来,例如使用以前的视频帧为理解当前帧提供信息。但是在实际情况中,它不一定能做到这一点。

  有时我们只需要查看最近的信息即可执行当前任务。例如,考虑一个语言模型,它试图根据前一个单词来预测下一个单词。如果我们试图预测“云在天空中”中的最后一个词,我们不需要任何进一步的上下文,很明显下一个词将是天空。在这种情况下,相关信息与需要信息的地方之间的差距很小,RNN可以学习使用过去的信息。
  但在某些情况下,我们需要更多的背景信息。不妨试着预测经文中的最后一个字:“我在法国长大…我能说一口流利的法语。最近的信息表明,下一个词可能是一种语言的名称,但如果我们想缩小哪种语言的范围,我们需要从更远的地方开始了解法国的上下文。相关信息与需要信息的点之间的差距完全有可能变得非常大。
  不幸的是,随着这种差距的扩大,RNN变得无法学习连接信息。这就是传统RNN的缺点,很难处理长距离的依赖。

  长短期记忆网络(通常简称为“LSTM”)解决了这个问题,这是一种特殊的RNN,能够学习长期依赖关系。我们将推导LSTM中每一层的结构,并实现一个pytorch版本LSTM。

推导

  长短时记忆网络的思路很简单。传统RNN的隐藏层只有一个状态,即h,它对于短期的输入非常敏感,假如我们再增加一个状态,即c,让它来保存长期的状态,那么问题就解决了:
在这里插入图片描述
  新增加的状态c,称为单元状态(cell state)。我们把上图按照时间维度展开:
在这里插入图片描述
  在长短时记忆网络的前向计算中,通过门(gate)控制向量的变化。门实际上就是一层全连接层,它的输入是一个向量,输出是一个0到1之间的实数向量。那么门可以表示为: g ( x ) = σ ( W x + b ) g(x)=\sigma(Wx+b) g(x)=σ(Wx+b)  由于sigmod函数的性质,门的输出是0到1之间的实数向量,那么,当门输出为0时,任何向量与之相乘都会得到0向量;输出为1时,任何向量与之相乘都不会有任何改变。
  LSTM用两个门来控制单元状态c的内容,一个是遗忘门(forget gate),它决定了上一时刻的单元状态有多少保留到当前时刻ct;另一个是输入门(input gate),它决定了当前时刻网络的输入xt有多少保存到单元状态ct。LSTM用输出门(output gate)来控制单元状态ct有多少输出到LSTM的当前输出值ht。

遗忘门

  遗忘门通过门控制上一时刻的输入的单元状态 c t − 1 c_{t-1} ct1有多少被保留下来,门的权值通过上一时刻的输出值 h t − 1 h_{t-1} ht1和这一时刻的输入值 x t x_t xt得到,即: f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f) ft=σ(Wf[ht1,xt]+bf)  这个过程可以被分解为 f t = W f h h t − 1 + W f h x t f_t=W_{fh}h_{t-1}+W_{fh}x_t ft=Wfhht1+Wfhxt  在我们对门的实现中,我们都遵循这种方式。下图显示了遗忘门的计算:在这里插入图片描述

输入门

  输入门决定了将当前的输入有多少被保留到c中,门的权值计算方式与遗忘门相同,即: i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i) it=σ(Wi[ht1,xt]+bi)  通过门值,我们计算当前输入的单元状态 c ^ t \hat c_t c^t,它是通过上一次的输出和本次输入计算的: c ^ t = t a n h ( W c ⋅ [ h t − 1 , x t ] + b c ) \hat c_t=tanh(W_c\cdot [h_{t-1},x_t]+b_c) c^t=tanh(Wc[ht1,xt]+bc)  下图显示了输入门的计算:
在这里插入图片描述
  到这里,我们可以计算当前时刻的单元状态 c t c_t ct,它是由上一次的单元状态 c t − 1 c_{t-1} ct1乘以遗忘门权值 f t f_t ft,再用当前输入的单元状态 c ^ t \hat c_t c^t乘以输入门权值 i t i_t it,再将两个积加和产生的: c t = f t ⋅ c t − 1 + i t ⋅ c ^ t c_t=f_t\cdot c_{t-1}+i_t\cdot \hat c_t ct=ftct1+itc^t  这样,我们就把LSTM关于当前的记忆 c ^ t \hat c_t c^t和长期的记忆 c t − 1 c_{t-1} ct1组合在一起,形成了新的单元状态 c t c_t ct。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。

输出门

  输出门的权值计算方式与上面两个门相同: o t = σ ( W o ⋅ [ h t − 1 , x t ] + b i ) o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_i) ot=σ(Wo[ht1,xt]+bi)  LSTM最终的输出,是由输出门和单元状态共同确定的: h t = o t ⋅ t a n h ( c t ) h_t=o_t\cdot tanh(c_t) ht=ottanh(ct)在这里插入图片描述
  至此,LSTM的推导就讲完了。

LSTM的改进:GRU

  LSTM也有许多缺点,因此提出了很多变体,GRU(Gated Recurrent Unit)就是比较成功的一种,针对LSTM有三个不同的门,参数较多,训练困难的缺点,GRU将LSTM中的输入门和遗忘门合二为一,称为更新门(update gate),控制前边记忆信息能够继续保留到当前时刻的数据量;另一个门称为重置门(reset gate),控制要遗忘多少过去的信息。其结构如下所示,读者可以自行对比:在这里插入图片描述

实现

  我们使用pytorch实现了一个LSTMlayer的模型:

class LstmLayer(torch.nn.Module):def __init__(self, bert_model, fea_dim, dropout):super(LstmLayer, self).__init__()self.fea_dim = fea_dim# 激活函数self.sigmod = nn.Sigmoid()self.tanh = nn.Tanh()# 遗忘门权重矩阵Wfh, Wfxself.Wfh = torch.nn.Linear(self.fea_dim, 1)self.Wfx = torch.nn.Linear(self.fea_dim, 1)# 输入门权重矩阵Wfh, Wfxself.Wih = torch.nn.Linear(self.fea_dim, 1)self.Wix = torch.nn.Linear(self.fea_dim, 1)# 单元状态更新权重矩阵Wch, Wcxself.Wch = torch.nn.Linear(self.fea_dim, self.fea_dim)self.Wcx = torch.nn.Linear(self.fea_dim, self.fea_dim)# 输出门权重矩阵Woh, Woxself.Woh = torch.nn.Linear(self.fea_dim, self.fea_dim)self.Wox = torch.nn.Linear(self.fea_dim, self.fea_dim)def forward(self,x_t,c_t,h_t):# 遗忘门fg = self.calc_gate(x_t,h_t, self.Wfx, self.Wfh,self.sigmod)c_t = fg * c_t# 输入门ig = self.calc_gate(x_t,h_t, self.Wix, self.Wih,self.sigmod)c_t_temp = ig * self.tanh(self.Wch(h_t) + self.Wcx(x_t))c_out = c_t_temp + c_t# 输出门og = self.calc_gate(x_t,h_t, self.Wox, self.Woh,self.sigmod)h_out = self.tanh(c_out) * ogreturn c_out,h_outdef calc_gate(self, x,h, Wx, Wh, activator):# 计算门权值gate_weight = Wx(x) + Wh(h)gate_out = activator(gate_weight)return gate_out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/18841.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Camunda 7.x 系列【64】实战篇之挂起、删除流程模型

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 2.7.9 本系列Camunda 版本 7.19.0 源码地址:https://gitee.com/pearl-organization/camunda-study-demo 前后端基于若依:https://gitee.com/y_project/RuoYi-Vue 流程设计器基于RuoYi-flowable:https://gi…

2024年6月1日(星期六)骑行禹都甸

2024年6月1日 (星期六)骑行禹都甸(韭葱花),早8:30到9:00,昆明氧气厂门口集合,9:30准时出发【因迟到者,骑行速度快者,可自行追赶偶遇。】 偶遇地点:昆明氧气厂门口集合 ,…

Linux系统维护

1. 批量安装部署 2. 初始化配置 3. 禁用Selinux 永久更改 SELinux 配置: 编辑 SELinux 配置文件:使用文本编辑器打开 /etc/selinux/config 文件: 在配置文件中,找到 SELINUX… 的行。将其值更改为以下选项之一: e…

TypeScript 学习笔记(七):TypeScript 与后端框架的结合应用

1. 引言 在前几篇学习笔记中,我们已经探讨了 TypeScript 的基础知识和在前端框架(如 Angular 和 React)中的应用。本篇将重点介绍 TypeScript 在后端开发中的应用,特别是如何与 Node.js 和 Express 结合使用,以构建强类型、可维护的后端应用。 2. TypeScript 与 Node.js…

YoloV8改进策略:BackBone|融合改进的HCANet网络中的多尺度前馈网络(MSFN)|二次创新|即插即用

本文使用HCANet网络中的多尺度前馈网络来提高Backbone的表征能力和检测精度。即插即用,方便大家移植自己的模型中。 论文指导 原论文中的表述 B. 多尺度前馈网络 在 V i T \mathrm{ViT} ViT 中的原始 FFN 是由两个线性层所构成,这样的设计仅用于单尺度特征聚合。但是,F…

2024 GIAC 全球互联网架构大会:拓数派向量数据库 PieCloudVector 架构设计与案例实践

5月24-25日,msup 和高可用架构联合举办了第11届 GIAC 全球互联网架构大会。会议聚焦“共话AI技术的最新进展、架构实践和未来趋势”主题,邀请了 100 余位行业内的领军人物和革新者,分享”Agent/RAG 技术、云原生、基座大模型“等多个热门技术…

浏览器修改后端返回值

模拟接口响应和网页内容 通过本地覆盖可以模拟接口返回值和响应头,无需 mock 数据工具,比如(Requestly),无需等待后端支持,快速复现在一些数据下的 BUG 等。在 DevTools 可以直接修改你想要的 Fetch/XHR 接…

event.preventDefault()使用指南

event.preventDefault(); 是 JavaScript 中用于阻止默认事件行为的方法。具体而言,它在处理 HTML 元素(如链接和表单)的事件时非常有用。下面是详细的解释和示例,说明它的作用和使用场景。 解释 在 HTML 中,许多元素…

将四种算法的预测结果绘制在一张图中

​ 声明:文章是从本人公众号中复制而来,因此,想最新最快了解各类智能优化算法及其改进的朋友,可关注我的公众号:强盛机器学习,不定期会有很多免费代码分享~ 之前的一期推文中,我们推出了…

RPA在抖音等短视频创作开发的应用

相较于一般人对Ai的漠视或仅仅停留在逗比对话而言,在凭此谋生的专业的行当,或AI应用相对宽泛的领域。融合Ai的自动化辅助办公(创作、演示等)的进步日新月异,这方面的知识还是应尽快了解。 RPA是Robotic process autom…

【ROS2问题记录】ros2 bag play xx.db3失败

报错内容: nvidiaoceanstar:~/yolov8_ros2-Tensorrt$ ros2 bag play rosbag2_2024_04_24-13_55_03_0.db3 /opt/ros/foxy/bin/ros2:6: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html fr…

NoSQL是什么?NoSQL数据库存在SQL注入攻击?

一、NoSQL是什么? NoSQL(Not Only SQL)是一种非关系型数据库的概念。与传统的关系型数据库不同,NoSQL数据库使用不同的数据模型来存储和检索数据。NOSQL数据库通常更适合处理大规模的非结构化和半结构化数据,且能够…

CPU对代码执行效率的优化,CPU的缓存、指令重排序

目录 一、CPU对代码执行效率的优化 1. 指令流水线(Instruction Pipelining) 2. 超标量架构(Superscalar Architecture) 3. 动态指令重排序(Dynamic Instruction Reordering) 4. 分支预测(…

【RuoYi】使用代码生成器完成CRUD操作

一、前言 前面,介绍了如何下载和启动我们的RuoYi框架。为了让小伙伴们认识到ruoyi的强大,那么这篇博客就介绍一下如何使用ruoyi的代码生成器,自动生成前端页面以及后端的对应数据库表的CRUD操作!!!真的很强…

LWIP_TCP 协议

目录 1 TCP 协议简介 1.1 TCP 协议简介 1.2 TCP 的建立连接 1.3 TCP 终止连接 1.4 TCP 报文结构 1.5 lwIP 的 TCP 报文首部数据结构 1.6 lwIP 的 TCP 连接状态图 1 TCP 协议简介 1.1 TCP 协议简介 TCP(Transmission Control Protocol 传输控制协议&#xff0…

MySQL实战行转列(或称为PIVOT)实战sales的表记录了不同产品在不同月份的销售情况,进行输出

有一个sales的表,它记录了不同产品在不同月份的销售情况: productJanuaryFebruaryMarchProduct AJanuary10Product AFebruary20Product BJanuary5Product BFebruary15Product CJanuary8Product CFebruary12 客户需求展示为如下的样子: pro…

斯坦福报告解读4:图解有趣的推理基准(中)

《人工智能指数报告》由斯坦福大学、AI指数指导委员会及业内众多大佬Raymond Perrault、Erik Brynjolfsson 、James Manyika等人员和组织合著,该报告已被公认为最权威、最具信誉人工智能数据与洞察来源之一。 2024年版《人工智能指数报告》是迄今为止最为详尽的一份…

linux下常用的终端命令

文章目录 1. MV移动文件、重命名文件1.1 移动文件:mv [选项] 源文件或目录 目标文件或目录1.2 文件重命名 2. 查找:文件,内容,统计文件2.1 find查找文件2.2 Linux查找文件内容 3. 查看当前用户4. linux修改文件所属用户和组5. 复制…

Token验证流程、代码示例、优缺点和安全策略,一文告诉你。

Token和Session都是用于身份验证和授权的机制,而且Token渐渐成为主流,有不少小伙伴对token的认识不全,这里给大家分享下。 一、什么是Token Token是一种用于身份验证和授权的令牌,通常用于在客户端和服务器之间进行安全的通信。…

SQLITE存储时间数据报警语法错误,syntax error

使用sqllite数据库,有一个时间数据current_time需要插入表中,如下 current_time time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 不能直接将时间戳格式化为字符串并嵌入到SQL语句中,如下: sql f"INSER…