RNN介绍及Pytorch源码解析

介绍一下RNN模型的结构以及源码,用作自己复习的材料。 

RNN模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

RNN的模型图如下:

源码注释中写道,RNN的数学公式:

h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

h_t表示在t时刻的隐藏状态,

x_t表示在t时刻的输入,

h_{t-1}表示前一层在时间t-1的隐藏状态,或者是在时间“0”的初始隐藏状态。

接下来我们看一下源码中RNN类的初始化(只介绍几个重要的参数):

torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向RNN,默认关闭。

下面是输入部分:

这是Pytorch官方文档中给出的解释:

输入分为input和h_0,当没有提供h_0的时候,h_0默认为0。

当batch_size == Ture时,输入的维度一般为(batch_size * seq_len * emb_dim)。

下面是输出部分:

其中output的维度为(batch_size * seq_len * hidden_size * bidirectional

其中bidirectional表示RNN是双向还是单向的,单向为1,双向为2。

下面使用代码举例:

import torch
import torch.nn as nn
rnn1 = nn.RNN(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
rnn2 = nn.RNN(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)
tensor1 = torch.randn(5,10,20)
tensor2 = torch.randn(5,10,20)
out1,h_n = rnn1(tensor1)
out2,h_n = rnn2(tensor2)
print(out1.shape) # torch.Size([5, 10, 80])
print(out2.shape) # torch.Size([5, 10, 40])

可以看到当bidirectional=True时,输出的特征维度是hidden_size * 2;

可以看到当bidirectional=False时,输出的特征维度是hidden_size * 1;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222308.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES6学习(三):Set和Map容器的使用

Set容器 set的结构类似于数组,但是成员是唯一且不会重复的。 创建的时候需要使用new Set([])的方法 创建Set格式数据 let set1 new Set([])console.log(set1, set1)let set2 new Set([1, 2, 3, 4, 5])console.log(set2, set2) 对比看看Set中唯一 let set3 new Set([1, 1,…

多架构容器镜像构建实战

最近在一个国产化项目中遇到了这样一个场景,在同一个 Kubernetes 集群中的节点是混合架构的,也就是说,其中某些节点的 CPU 架构是 x86 的,而另一些节点是 ARM 的。为了让我们的镜像在这样的环境下运行,一种最简单的做法…

Rust语言基础语法使用

1.安装开发工具: RustRover JetBrains: Essential tools for software developers and teams 下载: RustRover: Rust IDE by JetBrains 下载成功后安装并启动RustRover 安装中文语言包插件 重启RustRover生效

vue3引入echarts正确姿势

echarts文档地址: echarts官网地址 echarts配置手册 echarts 模板地址 1、安装 (1)安装echarts包 npm install echarts --save 或者 cnpm install echarts --save(2)安装vue echarts工具包 npm install echart…

持续集成交付CICD:Jenkins使用基于SaltStack的CD流水线下载Nexus制品

目录 一、理论 1.salt常用命令 二、实验 1.SaltStack环境检查 2.Jenkins使用基于SaltStack的CD流水线下载Nexus制品 二、问题 1.salt未找到命令 2.salt简单测试报错 3. wget输出日志过长 一、理论 1.salt常用命令 (1)salt 命令 该 命令执行s…

蓝牙物联网智慧物业解决方案

蓝牙物联网智慧物业解决方案是一种利用蓝牙技术来提高物业管理和服务效率的解决方案。它通过将蓝牙技术与其他智能设备、应用程序和云服务相结合,为物业管理和服务提供更便捷、高效和智能化的支持。 蓝牙物联网智慧物业解决方案包括: 1、设备管理&#…

JupyterNotebook安装依赖 使用conda环境

怎么使用JupyterNotebook安装依赖?(使用conda环境?) 预装的jupyter中有conda,可以进入终端执行相关命令 预装的Jupyter,目标用户是轻度的Jupyter用户,如果想使用conda的多环境等高级功能&#x…

Unity中的ShaderToy

文章目录 前言一、ShaderToy网站二、ShaderToy基本框架1、我们可以在ShaderToy网站中,这样看用到的GLSL文档2、void mainImage 是我们的程序入口,类似于片断着色器3、fragColor作为输出变量,为屏幕每一像素的颜色,alpha一般赋值为…

微信小程序单图上传和多图上传

图片上传主要用到 1、wx.chooseImage(Object object) 从本地相册选择图片或使用相机拍照。 参数 Object object 属性类型默认值必填说明countnumber9否最多可以选择的图片张数sizeTypeArray.<string>[original, compressed]否所选的图片的尺寸sourceTypeArray.<s…

从开源项目中学习如何自定义 Spring Boot Starter 小组件

前言 今天参考的开源组件Graceful Response——Spring Boot接口优雅响应处理器。 具体用法可以参考github以及官方文档。 基本使用 引入Graceful Response组件 项目中直接引入如下maven依赖&#xff0c;即可使用其相关功能。 <dependency><groupId>com.feiniaoji…

螺旋矩阵算法(leetcode第59题)

题目描述&#xff1a; 给你一个正整数 n &#xff0c;生成一个包含 1 到 n2 所有元素&#xff0c;且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。示例 1&#xff1a;输入&#xff1a;n 3 输出&#xff1a;[[1,2,3],[8,9,4],[7,6,5]] 示例 2&#xff1a;输入&#…

电商API代码如何接入写

要接入一个API&#xff0c;通常需要遵循以下步骤&#xff1a; 1. 【了解API文档】&#xff1a;首先&#xff0c;需要了解你想要接入的API的文档。这些文档通常会包含关于如何使用该API的重要信息&#xff0c;比如基本的请求格式、可用的端点&#xff08;endpoints&#xff09;…

Process On在线绘制流程图

目录 一.ProcessOn 1.1.介绍 1.2.直接网上使用 二.绘制门诊流程图 三.绘制住院流程图 四.绘制药库采购入库流程图 五.绘制OA会议流程图 今天就到这里了哦!!!希望能帮到你哦&#xff01;&#xff01;&#xff01; 一.ProcessOn 1.1.介绍 ProcessOn&#xff08;流程&#…

Linux,Web网站服务(一)

1.准备工作 为了避免发生端口冲突&#xff0c;程序冲突等现象&#xff0c;建议卸载使用RPM方式安装的httpd [rootnode01 ~]# rpm -e http --nodeps 挂载光盘到/mnt目录 [rootnode01 ~]# mount /dev/cdrom /mnt Apache的配置及运行需要apr.pcre等软件包的支持&#xff0c;因此…

牛客网SQL训练2—SQL基础进阶

文章目录 一、基本查询二、数据过滤三&#xff1a;函数四&#xff1a;分组聚合五&#xff1a;子查询六&#xff1a;多表连接七&#xff1a;组合查询八&#xff1a;技能专项-case when使用九&#xff1a;多表连接-窗口函数十&#xff1a;技能专项-having子句十一&#xff1a;技能…

Web信息收集,互联网上的裸奔者

Web信息收集&#xff0c;互联网上的裸奔者 1.资产信息收集2.域名信息收集3.子域名收集4.单点初步信息收集网站指纹识别服务器类型(Linux/Windows)网站容器(Apache/Nginx/Tomcat/IIS)脚本类型(PHP/JSP/ASP/ASPX)数据库类型(MySQL/Oracle/Accees/SqlServer) 5.单点深入信息收集截…

Linux-----10、查找命令

# 查找命令 # 1、 命令查找 Linux下一切皆文件&#xff01; which 命令 &#xff1a;找出命令的绝对路径 whereis 命令 &#xff1a;找出命令的路径以及文档手册信息 [rootheima ~]# which mkdir /usr/bin/mkdir[rootheima ~]# whereis mkdir mkdir: /usr/bin/mkdir /usr/…

一款计算机顶会爬取解析系统 paper info

一款计算机顶会爬取解析系统 paper info 背景项目实现的功能 技术方案架构设计项目使用的技术选型 使用方法本地项目部署使用ChatGPT等大模型创建一个ChatGPT助手使用阿里云 顶会数据量 百度网盘pfd文件json文件 Q&A github链接 &#xff1a;https://github.com/codebricki…

(反序列化)小记录

目录 [CISCN 2023 华北]ez_date 绕过MD5和sha1强相关绕过 date()绕过 payload [FSCTF 2023]ez_php [CISCN 2023 华北]ez_date <?php error_reporting(0); highlight_file(__FILE__); class date{public $a;public $b;public $file;public function __wakeup(){if(is_a…

【快刊】Springer旗下2区,仅1个月左右录用,国人极其友好!

计算机类 • 好刊解读 今天小编带来Springer旗下计算机领域好刊&#xff0c;如您有投稿需求&#xff0c;可作为重点关注&#xff01;后文有相关领域真实发表案例&#xff0c;供您投稿参考~ 01 期刊简介 ✅出版社&#xff1a;Springer ✅影响因子&#xff1a;4.0-5.0 ✅期刊…