详解pytorch中循环神经网络(RNN、LSTM、GRU)的维度

详解pytorch中循环神经网络(RNN、LSTM、GRU)的维度

  • RNN
    • torch.nn.rnn详解
    • RNN输入输出维度
  • LSTM
    • torch.nn.LSTM详解
    • LSTM输入输出维度
  • GRU
    • torch.nn.GRU详解
    • GRU输入输出维度
  • 三种RNN的示例

首先如果你对RNN、LSTM、GRU不太熟悉,可点击查看。

RNN

torch.nn.rnn详解

torch.nn.RNN(input_size,
hidden_size,
num_layers=1,
nonlinearity=‘tanh’,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理
在这里插入图片描述

参数详解

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • nonlinearity– 使用的非线性。可以是’tanh’或’relu’。默认:‘tanh’

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0.0

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

RNN输入输出维度

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

可以看到输入是xh_0,h_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是outputh_n。h_n存了每一层的t时刻的隐藏状态值

# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, h_0=None):if batch_first:x = x.transpose(0, 1)seq_len, batch_size, _ = x.size()if h_0 is None:h_0 = torch.zeros(num_layers, batch_size, hidden_size)...return output, h_n

输入:
x的输入维度:(batch_size, sequence_length, input_size) [前提:batch_first=True]
h_0的维度:(D∗num_layers, hidden_size) [可以为None]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]
h_n的维度:(D∗num_layers, hidden_size)

LSTM

torch.nn.LSTM详解

torch.nn.LSTM(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
proj_size=0,
device=None,
dtype=None)

原理:

参数详解:
相比于RNN多了proj_size参数,少了nonlinearity参数

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0dropout

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

  • proj_size – 如果,将使用具有相应大小投影的LSTM 。默认值:0

LSTM输入输出维度

LSTM= nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = LSTM(input, (h0, c0))

输入是x,此外h_0c_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是output和一个元组(h_n, c_n)

输入: x的输入维度:(batch_size, sequence_length, input_size)`
[前提:batch_first=True]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]

具体可参考官方文档:nn.LSTM
在这里插入图片描述

GRU

torch.nn.GRU详解

torch.nn.GRU(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理:
在这里插入图片描述

参数详解:
与上文LSTM相比,缺少了proj_size参数,与RNN相比也缺少了nonlinearity参数

GRU输入输出维度

gru= nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = gru(input, h0)

与RNN一致见上文,相比LSTM少了c_n

三种RNN的示例

import torch
import torch.nn as nnrnn = nn.RNN(10, 20, 2, batch_first=True) # (input_size, hidden_size, num_layer)
lstm = nn.LSTM(10, 20, 2, batch_first=True)
gru = nn.GRU(10, 20, 2, batch_first=True)input = torch.randn(5, 3, 10)  # (batchsize, seq, input_size)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)output_rnn, h_n = rnn(input)
output_lstm, (hn, cn) = lstm(input)
output_gru, h_n2 = gru(input)
print("输入维度:", input.shape)
print(f"RNN 输出维度:{output_rnn.shape}, h_n维度:{h_n.shape}" )
print("LSTM 输出维度:", output_lstm.shape)
print("GRU 输出维度:", output_gru.shape)"""
输入维度: torch.Size([5, 3, 10])
RNN 输出维度:torch.Size([5, 3, 20]), h_n维度:torch.Size([2, 5, 20])
LSTM 输出维度: torch.Size([5, 3, 20])
GRU 输出维度: torch.Size([5, 3, 20])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/12546.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python数据可视化:层次聚类热图clustermap()

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 python数据可视化: 层次聚类热图 clustermap() [太阳]选择题 请问关于以下代码表述错误的选项是? import seaborn as sns import matplotlib.pyplot as plt import n…

代码随想录—— 填充每个节点的下一个右侧节点指针(Leetcode116)

题目链接 层序遍历 /* // Definition for a Node. class Node {public int val;public Node left;public Node right;public Node next;public Node() {}public Node(int _val) {val _val;}public Node(int _val, Node _left, Node _right, Node _next) {val _val;left _…

开源的全自动生成视频文案、视频素材、视频字幕、视频背景音乐的AI项目

网址 https://github.com/harry0703/MoneyPrinterTurbo 只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。 如果用来做视频,可以先收藏一下,值得本地…

51 单片机[2-1]:点亮一个LED

一、在 Keil5 中新建项目 打开 Keil5 ,点击 Project —— new μVision Project 新建文件夹 KeilProject ,以后的项目都在这个文件夹下,再建一个文件夹 2-1 点亮一个LED。在该文件夹下创建名为 Project 的文件,并保存。推荐起这…

Spring Boot:异常处理

Spring Boot 前言使用自定义错误页面处理异常使用 ExceptionHandler 注解处理异常使用 ControllerAdvice 注解处理异常使用配置类处理异常使用自定义类处理异常 前言 在 Spring Boot 中,异常处理是一个重要的部分,可以允许开发者优雅地处理应用程序中可…

复利效应(应用于成长)

应用 每个人在智力、知识、经验上,复利效应都一样,只要能积累的东西,基本上最终都会产生复利效应。 再来看一下复利公式:FP*(1i)^n P本金;i利率;n持有期限。在使用时,一定要注意4个限定条件&a…

AI图书推荐:ChatGPT等生成式AI在高等教育中的应用

自2022年11月以来,ChatGPT及其在高等教育各个层面的影响已成为所有教育对话的核心内容。Chan和Colloton所著的书籍是首批全面探讨ChatGPT与生成式人工智能(GenAI)在高等教育中应用及影响的作品之一。 该书深入研究了针对专业环境定制的AI素养…

基础学习-Git(分布式版本控制系统)

学习视频推荐 http://【黑马程序员Git全套教程,完整的git项目管理工具教程,一套精通git】 https://www.bilibili.com/video/BV1MU4y1Y7h5/?p5&share_sourcecopy_web&vd_source2b85bd9be9213709642d908906c3d863 1、Git环境配置 安装Git Git下…

wireshark_概念

ARP (Address Resolution Protocol)协议,即地址解析协议。该协议的功能就是将IP地址解析成MAC地址。 混杂模式 抓取经过网卡的所有数据包,包括发往本网卡和非发往本网卡的。 非混杂模式 只抓取目标地址是本网卡的数据包,对于发往…

《控制系统实验与综合设计》综合四至六(含程序和题目)

1.电机模型辨识实验 1.1 实验目的 (1)掌握一阶系统阶跃响应的特点,通过实验加深对直流电解模型的理解; (2)掌握系统建模过程中参数的整定,体会参数变化对系统的影响; &#xff0…

单片机开发板上外设资源讲解

单片机开发电路板上简单外设 开发板上各基础外设LED灯按键:数码管介绍液晶屏矩阵键盘扫描的概念LED点阵屏实时时钟蜂鸣器存储器 温度传感器&单总线 开发板上各基础外设 LED灯 中文名:发光二极管 外文名:Light Emitting Diode 简称&…

杨校老师项目之基于单片机STC89C52的智能环境监测系统【嵌入式】

获取全套资料: 有偿获取:mryang511688 技术:C语言、单片机等 摘要: 此设计可分为三个主要部分。此中的温度和湿度的检测功能,通过操纵单总线型温湿度传感器DHT11以数字形式显示,实现了切确测得温湿度的功能…

如何管理多个版本的Node.js

我们如何在本地管理多个版本的Node.js,有没有那种不需要重新安装软件再修改配置文件和环境变量的方法?经过我的查找,还真有这种方式,那就是nvm(Node Version Manager)。 下面我就给大家介绍下NVM的使用 1…

vs2019 c++中模板 enable_if_t 的使用

&#xff08;1&#xff09; 该模板的定义如下&#xff1a; template <bool _Test, class _Ty void> struct enable_if {}; // no member "type" when !_Testtemplate <class _Ty> struct enable_if<true, _Ty> { // type is _Ty for _Testusing …

Golang | Leetcode Golang题解之第89题格雷编码

题目&#xff1a; 题解&#xff1a; func grayCode(n int) []int {ans : make([]int, 1<<n)for i : range ans {ans[i] i>>1 ^ i}return ans }

MSR810-LM快速配置通过LTE模块上网

正文共&#xff1a;1111 字 13 图&#xff0c;预估阅读时间&#xff1a;1 分钟 之前买了一个无线版本的MSR810-W&#xff08;淘了一台二手的H3C企业路由器&#xff0c;就用它来打开网络世界的大门&#xff09;&#xff0c;并整理了一份快速配置&#xff08;脚本案例来了&#x…

三菱FX3U-4AD模拟量电压输入采集实例

硬件&#xff1a;&#xff30;&#xff2c;&#xff23;模块 &#xff26;&#xff38;&#xff13;&#xff27;&#xff21;-&#xff12;&#xff14;&#xff2d;&#xff34; &#xff1b;&#xff21;&#xff0f;&#xff24;模块&#xff26;&#xff38;&#xff13…

SQL——SERVER的建表主要操作

目录 一&#xff1a;数据存储问题 1.表的相关数据 2.表&#xff0c;字段&#xff0c;记录 二&#xff1a;建表 1.创建表头 2. 数据类型 3.保存数据 4.数据冗余 5.使用命令重置表 7.设置主键 一&#xff1a;数据存储问题 1.表的相关数据 表是数据库的基本单位&…

交互原型设计工具 Axure RP 9 for Mac 正式激活版

Axure RP 9 Pro Mac版是Mac平台上的一款专为快速原型设计而生的应用&#xff0c;Axure RP 9 Pro Mac版可以辅助产品经理快速设计完整的产品原型&#xff0c;并结合批注&#xff0c;说明以及流程图&#xff0c;框架图等元素&#xff0c;将产品完整地表述给各方面设计人员&#x…

Android Studio(AS)使用别人的项目与gradle包并运行项目

一、问题描述 在进行AS开发时&#xff0c;我们可能会使用到别人的项目&#xff0c;但发现别人把项目发给我们后会发现gradle项目同步失败o(≧口≦)o&#xff0c;此时计有三&#xff1a; 1.横行霸道、豪取抢夺&#xff1a;直接空降到项目人那里&#xff0c;强他的电脑占为己有…