LSTM和GRU的介绍以及Pytorch源码解析

介绍一下LSTM模型的结构以及源码,用作自己复习的材料。 

LSTM模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

上次上一篇文章介绍了RNN序列模型,但是RNN模型存在比较严重的梯度爆炸和梯度消失问题。

本文介绍的LSTM模型解决的RNN的大部分缺陷。

首先展示LSTM的模型框架:

下面是LSTM模型的数学推导公式:

\begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array}

h_t表示t时刻的隐藏状态,c_t表示t时刻的记忆细胞状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。

i_t,f_t,g_t,o_t 分别是输入门、遗忘门、单元门和输出门。

这张图片比较好的介绍了各个门之间的交互关系以及输入输出,大家可以放大看一下。

接下来展示GRU的框架模型:

下面是GRU的数学推导公式:

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}

h_t表示t时刻的隐藏状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。r_t,n_t,z_t分别表示重置门更新门和新建门

上面的图片可以更直观的看到GRU中是如何迭代的。

接下来我们看一下源码中LSTM和GRU类的初始化(只介绍几个重要的参数):

torch.nn.LSTM(self, input_size, hidden_size, num_layers=1,bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None,dtype=None)
torch.nn.GRU(self, input_size, hidden_size, num_layers=1,bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向LSTM(GRU),默认关闭。

LSTM与GRU都是特殊的RNN,因此输入输出可以参考的上一篇介绍RNN的文章,在这里直接进行代码举例。

lstm1 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
lstm2 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)gru1 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=True)
gru2 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=False)tensor1 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)
tensor2 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)out_lstm1,(hn, cn) = lstm1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_lstm2,(hn, cn) = lstm2(tensor2)  # (batch_size * seq_len * (hidden_size * bidirectional))out_gru1,h_n = gru1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_gru2,h_n = gru2(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))print(out_lstm1.shape)  # torch.Size([5, 10, 80])
print(out_lstm2.shape)  # torch.Size([5, 10, 40])print(out_gru1.shape)  # torch.Size([5, 10, 50])
print(out_gru2.shape)  # torch.Size([5, 10, 25])

维度已经在注释中给大家标注上了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT-CAD-3D显示操作工具

QT-CAD-3D显示操作工具 一、效果展示二、核心程序三、程序链接 一、效果展示 二、核心程序 TDF_LabelSequence DxfReader::transfer(DocumentPtr doc, TaskProgress* progress) {TDF_LabelSequence seqLabel;Handle_XCAFDoc_ShapeTool shapeTool doc->xcaf().shapeTool();…

大数据技术13:HBase分布式列式数据库

前言:2007年Powerset的工作人员,通过google的论文开发出了BigTable的java版本,即HBASE。2008年HBASE贡献给了Apache。HBase 需要依赖 JDK 环境。 一、Hadoop的局限 HBase 是一个构建在 Hadoop 文件系统之上的面向列的数据库管理系统。 要想…

微服务学习:Gateway服务网关

一,Gateway服务网关的作用: 路由请求:Gateway服务网关可以根据请求的URL或其他标识符将请求路由到特定的微服务。 负载均衡:Gateway服务网关可以通过负载均衡算法分配请求到多个实例中,从而平衡各个微服务的负载压力。…

爬虫的基本介绍 , 什么是爬虫 , 爬虫的主要功能

走进爬虫 1. 什么是爬虫? 本节课程的内容是介绍什么是爬虫?爬虫有什么用?以及爬虫是如何实现的?从这三点一起来寻找答案! 1.1 初识网络爬虫 网络爬虫(又被称为网页蜘蛛,网络机器人&#xff…

PythonStudio:一款国人写的python及窗口开发编辑IDE,可以替代pyqt designer等设计器了

本款软件只有十几兆,功能算是强大的,国人写的,很不错的python界面IDE.顶部有下载链接。下面有网盘下载链接,或者从官网直接下载。 目前产品免费,以后估计会有收费版本。主页链接:PythonStudio-硅量实验室 作…

阿里云Centos8安装Dockers详细过程

一、卸载旧版本 较旧的 Docker 版本称为 docker 或 docker-engine 。如果已安装这些程序,请卸载它们以及相关的依赖项。 yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \do…

服务器数据恢复-raid5多块磁盘掉线导致上层卷无法挂载的数据恢复案例

服务器数据恢复环境: 一台服务器中有一组由24块FC硬盘组建的raid5磁盘阵列,linux操作系统ext3文件系统,服务器上层部署有oracle数据库。 服务器故障&检测: raid5阵列中有两块硬盘出现故障掉线,导致服务器上层卷无法…

大文件加密传输助力企业数据交互安全

在当前信息时代,数据成为企业的关键资产和竞争优势。企业为提高效率和创新能力,需要与内外部合作伙伴进行数据交换与协作。然而,在大量数据在网络上传输时,数据安全成为企业不可忽视的挑战。如何确保数据的机密性、完整性和可用性…

【Linux】信号--信号初识/信号的产生方式/信号的保存

文章目录 一、信号初步理解1.生活角度的信号2.技术应用角度的信号 二、信号的产生方式1.通过终端按键产生信号2.调用系统函数向进程发信号3.硬件异常产生信号4.由软件条件产生信号5.进程退出时的核心转储问题 三、信号的保存1.信号其他相关常见概念2.信号在内核中的表示3.sigse…

ubuntu debian mini安装系统 有线选项消失或ens33 ethernet 未托管解决方法

nmcli device status#修改NetworkManager.conf如下 sed s/false/true/ /etc/NetworkManager/NetworkManager.confsed -i s/false/true/ /etc/NetworkManager/NetworkManager.conf#重启生效systemctl restart NetworkManager

智能优化算法应用:基于蝠鲼觅食算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于蝠鲼觅食算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于蝠鲼觅食算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蝠鲼觅食算法4.实验参数设定5.算法结果6.…

phpMyAdmin的常见安装位置

nginx的日志显示有人一直在尝试访问phpMyAdmin的setup.php,用了各种位置。 其实我只有一个nginx,别的什么也没有。 47.99.136.156 - - [01:44:37 0800] "GET http://abc.com:80/phpMyAdmin/scripts/setup.php HTTP/1.0" 404 162 "-"…

生成树基本实验

背景 某公司的二层交换网络中,为了提高网络可靠性,故在二层交换网络中增加冗余链路。为了阻 止冗余链路可能带来的广播风暴,MAC地址漂移等负面影响,需要在交换机之间部署生成树 协议。 实验 一.配置stp en 开启 stp en stp …

PPINN Parareal physics-informed neural network for time-dependent PDEs

论文阅读:PPINN Parareal physics-informed neural network for time-dependent PDEs PPINN Parareal physics-informed neural network for time-dependent PDEs简介方法PPINN加速分析 实验确定性常微分方程随机常微分方程Burgers 方程扩散反应方程 总结 PPINN Par…

R语言【rgbif】——什么是多值传参?如何在rgbif中一次性传递多个值?多值传参时的要求有哪些?

rgbif版本:3.7.8.1 什么是多值传参? 您是否在使用rgbif时设想过,给某个参数一次性传递许多个值,它将根据这些值独立地进行请求,各自返回独立的结果。 rgbif支持这种工作模式,但是具体的细节需要进一步地…

新版Spring Security6.2 - Digest Authentication

前言: 书接上文,上次翻译basic的这页,这次翻译Digest Authentication这页。 摘要认证-Digest Authentication 官网的警告提示:不应在应用程序中使用摘要式身份验证,因为它不被认为是安全的。最明显的问题是您必须以…

IDEA中Terminal配置为bash

简介 我们日常命令行都是使用Linux的bash指令,但是我们的开发基本都是基于Windows上的IDEA进行开发的,对此我们可以通过将IDEA将终端Terminal改为git bash自带的bash.exe解决问题。 配置步骤 安装GIT 这步无需多说了,读者可自行到官网下载…

大模型时代-从0开始搭建大模型

开发一个简单模型的步骤; 搭建一个大模型的过程可以分为以下几个步骤: 数据收集和处理模型设计模型训练模型评估模型优化 下面是一个简单的例子,展示如何使用Python和TensorFlow搭建一个简单的大模型。 数据收集和处理 首先,我…

Python接口自动化 —— Json 数据处理实战(详解)

简介   上一篇说了关于json数据处理,是为了断言方便,这篇就带各位小伙伴实战一下。首先捋一下思路,然后根据思路一步一步的去实现和实战,不要一开始就盲目的动手和无头苍蝇一样到处乱撞,撞得头破血流后而放弃了。不仅…

作业12.11

1 完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到其他界面 如果账号和密码不匹配&…