pytorch- RNN循环神经网络

目录

  • 1. why RNN
  • 2. RNN
  • 3. pytorch RNN layer
    • 3.1 基本单元
    • 3.2 nn.RNN
      • 3.2.1 函数说明
      • 3.2.2 单层pytorch实现
      • 3.2.3 多层pytorch实现
    • 3.3 nn.RNNCell
      • 3.3.1 函数说明
      • 3.3.2 单层pytorch实现
      • 3.3.3 多层pytorch实现
  • 4.完整代码

1. why RNN

以淘宝的评论为例,判断评论是正面还是负面的,如下图:
在这里插入图片描述
上图中每个单词用一个线性层来表示,最后再聚合,每个单词都有一个单独的w和b。
这种方法的问题:

  • 对于长句子甚至是一段文章来说,就很难表示了,因为要用很多线性层和参数表示
  • 没有语境信息
    比如:
    我不喜欢数学,如果没看到不,只看到喜欢,理解的意思就完全不一样了,因此对于一个句子来说,必须有一个语境信息,才能正确理解句子的意思。

为了解决上述问题,RNN增加了权值共享和一个用于保存语境信息的memory h

2. RNN

如下图:
第一个单词不仅考虑到了x输入还考虑到了初始化输入,通过这两个输入产生了一个语境信息h1,第二个单词不仅考虑当前单词的输入还要考虑上一个单词的语境信息h1,以此类推。
在这里插入图片描述
在这里插入图片描述
RNN的核心就是有个语境信息ht,这个语境信息根据当前的输入和上次的语境信息ht-1不断更新自我,并不断向前传。
展开图如下:
在这里插入图片描述

3. pytorch RNN layer

3.1 基本单元

下图展示了ht的计算过程,假设句子长度为5,batch是3,每个单词用100维向量表示,h0初始值用20维表示,最终得到h(t+1)维度为[3,20]
在这里插入图片描述
在这里插入图片描述
上图中rnn=nn.RNN(100,10),100是feature len,10表示hidden len。
输出参数中rnn.weight_hh_10.shape=》[hidden len, hidden len]
rnn.weight_ih_10.shape=》[hidden len, feature len]

3.2 nn.RNN

3.2.1 函数说明

在这里插入图片描述
input_size-输入x的维度
hidden_size-h的维度
num_layers-有几次,默认1
在这里插入图片描述
上图中forward函数的返回值中
ht[num layers, b, h dim]=》是最后时间戳所有memory(h)的状态
out[seq len, b, h dim]=》是所有时间错最后一个memory(h)的状态

3.2.2 单层pytorch实现

在这里插入图片描述

3.2.3 多层pytorch实现

在这里插入图片描述
上图为2层RNN,h变由1层的[1,3,20]变为][2,3,20]([num_layer,b, h dim]),out和1层一样是[10,3,20]
在这里插入图片描述
下图为4层RNN,pytorch代码实现,注意一下输出shape的变化
在这里插入图片描述

3.3 nn.RNNCell

3.3.1 函数说明

nn.RNNCell与nn.RNN的初始化参数是完全一致
在这里插入图片描述
但是输入输出就不一样了,如下图:
在这里插入图片描述

3.3.2 单层pytorch实现

从pytorch代码可以看出,nn.RNNCell是循环处理每个单词,每次自更新h1
在这里插入图片描述

3.3.3 多层pytorch实现

下图为2层nn.RNNCell的pytorch代码,注意1层的h dim与2层的input dim必须一致,下图都是30
从代码中也可以看出第1层的h1作为第2层的输入参与更新h2。
在这里插入图片描述

4.完整代码

import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as Fdef main():rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=1)print(rnn)x = torch.randn(10, 3, 100)out, h = rnn(x, torch.zeros(1, 3, 20))print(out.shape, h.shape)rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=4)print(rnn)x = torch.randn(10, 3, 100)out, h = rnn(x, torch.zeros(4, 3, 20))print(out.shape, h.shape)# print(vars(rnn))print('rnn by cell')cell1 = nn.RNNCell(100, 20)h1 = torch.zeros(3, 20)for xt in x:h1 = cell1(xt, h1)print(h1.shape)cell1 = nn.RNNCell(100, 30)cell2 = nn.RNNCell(30, 20)h1 = torch.zeros(3, 30)h2 = torch.zeros(3, 20)for xt in x:h1 = cell1(xt, h1)h2 = cell2(h1, h2)print(h2.shape)print('Lstm')lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)print(lstm)x = torch.randn(10, 3, 100)out, (h, c) = lstm(x)print(out.shape, h.shape, c.shape)print('one layer lstm')cell = nn.LSTMCell(input_size=100, hidden_size=20)h = torch.zeros(3, 20)c = torch.zeros(3, 20)for xt in x:h, c = cell(xt, [h, c])print(h.shape, c.shape)print('two layer lstm')cell1 = nn.LSTMCell(input_size=100, hidden_size=30)cell2 = nn.LSTMCell(input_size=30, hidden_size=20)h1 = torch.zeros(3, 30)c1 = torch.zeros(3, 30)h2 = torch.zeros(3, 20)c2 = torch.zeros(3, 20)for xt in x:h1, c1 = cell1(xt, [h1, c1])h2, c2 = cell2(h1, [h2, c2])print(h2.shape, c2.shape)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43298.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

matplotlib颜色对照表

matplotlib的色彩设置: #------------------------------------------------------------------------------------------------------------------------------- #-------------------------------------------------------------------------------------------------------…

【JavaWeb】登录校验-会话技术(二)JWT令牌

JWT令牌 介绍 JWT全称:JSON Web Token (官网:https://jwt.io/) 定义了一种简洁的、自包含的格式,用于在通信双方以json数据格式安全的传输信息。由于数字签名的存在,这些信息是可靠的。 简洁&#xff1a…

vue和react你怎么选择?

在选择Vue和React之间,其实没有一个绝对的“最佳选择”,因为这取决于你的项目需求、团队熟悉度、开发环境、以及你对这两个框架的个人偏好。下面是一些可以帮助你做出决策的因素: 1. 学习曲线 Vue:Vue的学习曲线相对平缓&#xf…

借助软件资产管理系统,优化Solidworks软件许可证管理

在当今数字化的企业环境中,软件许可证的有效管理对于业务的顺畅运行至关重要。然而,IT 运维部门常常面临着诸如用户部门 SW 许可证不够用、使用紧张等问题,而由于缺乏可靠的数据支持,难以准确判断许可证的短缺程度,这给…

MFC引用C#生成的dll,将dll放置到非exe程序目录,如何操作?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

信创:鲲鹏(arm64)+麒麟(kylin v10)离线部署k8s和kubesphere(含离线部署新方式)

本文将详细介绍,如何基于鲲鹏CPU(arm64)和操作系统 Kylin V10 SP2/SP3,利用 KubeKey 制作 KubeSphere 和 Kubernetes 离线安装包,并实战部署 KubeSphere 3.3.1 和 Kubernetes 1.22.12 集群。 服务器配置 主机名IPCPUOS用途master-1192.168.10…

【linux高级IO(二)】多路转接之select详解

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:Linux从入门到精通⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学更多操作系统知识   🔝🔝 Linux高级IO 1. 前言2. 初识s…

SCI丨返修一作+通讯

中科四区,JCR2 返修转让一作通讯,5个月左右录用 题目:通过机器学习算法XXXXXXXxxx混凝土力学性能的可靠方法

苍穹外卖--完善登录功能:进行MD5加密

目标 TODO:使用MD5加密方式对明文密码。 实现 password DigestUtils.md5DigestAsHex(password.getBytes());

Face_recognition实现人脸识别

这里写自定义目录标题 欢迎使用Markdown编辑器一、安装人脸识别库face_recognition1.1 安装cmake1.2 安装dlib库1.3 安装face_recognition 二、3个常用的人脸识别案例2.1 识别并绘制人脸框2.2 提取并绘制人脸关键点2.3 人脸匹配及标注 欢迎使用Markdown编辑器 本文基于face_re…

双向链表+Map实现LRU

LRU: LRU是Least Recently Used的缩写,即最近最少使用,是一种常用的页面置换算法,选择最近最久未使用的页面予以淘汰。 核心思想: 基于Map实现k-v存储,双向链表中使用一个虚拟头部和虚拟尾部,虚拟头部的…

BioXcell—InVivoMAb anti-West Nile/dengue virus E protein

研发背景: 西尼罗河病毒(WNV)是一种由蚊虫类介导传播的黄病毒,与引起人类感染性流行病的登革热病毒、黄热病病毒和日本脑炎病毒密切相关。 WNV和登革热病毒(DENV)同属黄病毒科(Flaviviridae)黄热病毒属,是具有小包膜单…

【多模态】41、VILA | 打破常规多模态模型训练策略,在预训练阶段就微调 LLM 被证明能取得更好的效果!

论文:VILA: On Pre-training for Visual Language Models 代码:https://github.com/NVlabs/VILA 出处:NVLabs 时间:2024.05 贡献: 证明在预训练阶段对 LLM 进行微调能够提升模型对上下文任务的效果在 SFT 阶段混合…

Centos7离线安装ElasticSearch7.4.2

一、官网下载相关的安装包 ElasticSearch7.4.2: elasticsearch-7.4.2-linux-x86_64.tar.gz 下载中文分词器: elasticsearch-analysis-ik-7.4.2.zip 二、上传解压文件到服务器 上传到目录:/home/data/elasticsearch 解压文件&#xff1…

免费无限白嫖阿里云服务器

今天,我来分享一个免费且无限使用阿里云服务器的方法,零成本!这适用于日常测试学习,比如测试 Shell 脚本、学习 Docker 安装、MySQL 等等。跟着我的步骤,你将轻松拥有一个稳定可靠的服务器,为你的学习和实践…

错误记录-SpringCloud-OpenFeign测试远程调用

文章目录 1,org.springframework.beans.factory.UnsatisfiedDependencyException: Error creating bean with name memberController: Unsatisfied dependency expressed through field couponFeign2, Receiver class org.springframework.cloud.netflix…

几种不同的方式禁止IP访问网站(PHP、Nginx、Apache设置方法)

1、PHP禁止IP和IP段访问 <?//禁止某个IP$banned_ip array ("127.0.0.1",//"119.6.20.66","192.168.1.4");if ( in_array( getenv("REMOTE_ADDR"), $banned_ip ) ){die ("您的IP禁止访问&#xff01;");}//禁止某个IP段…

SAP S4 环境下,KSU1 Ob52 转为前台操作,不产生传输请求号

参考 OB52/KSV1/KSU1等与Client状态相关的前台操作-y_q_yang se16n 编辑表 T811FLAGS 新增行 CYCLES MAINTENANCE X 即可

WAIC 2024 AI盛宴大会亮点回顾

&#x1f389; 刚刚落幕的WAIC 2024大会&#xff0c;简直是科技迷们的狂欢节&#xff01;这场全球瞩目的人工智能盛会&#xff0c;不仅汇聚了全球顶尖的智慧&#xff0c;还带来了无数令人惊叹的创新成果。让我们一起回顾那些抓人眼球的亮点吧&#xff01; &#x1f525; 超燃开…

Java面试八股之MySQL的redo log和undo log

MySQL的redo log和undo log 在MySQL的InnoDB存储引擎中&#xff0c;redo log和undo log是两种重要的日志&#xff0c;它们各自服务于不同的目的&#xff0c;对数据库的事务处理和恢复机制至关重要。 Redo Log&#xff08;重做日志&#xff09; 功能 redo log的主要作用是确…