pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt+FFmpeg开发视频播放器笔记(三):音视频流解析封装

音频解析 音频解码是指将压缩的音频数据转换为可以再生的PCM(脉冲编码调制)数据的过程。 FFmpeg音频解码的基本步骤如下: 初始化FFmpeg解码器(4.0版本后可省略): 调用av_register_all()初始化编解码器。 调用avcodec_register_all()注册所有编解码器。 打开输入的音频流:…

pthread_cond_signal 和pthread_cond_wait

0、pthread_join()函数作用: pthread_join() 函数会一直阻塞调用它的线程,直至目标线程执行结束(接收到目标线程的返回值),阻塞状态才会解除。如果 pthread_join() 函数成功等到了目标线程执行结束(成功获取…

运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。

一、问题描述 运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。 二、问题分析 在idea中,运行一个springboot项目,在使用大量的库和依赖的时候,会出现报错“命令行过长”&…

Java | Leetcode Java题解之第406题根据身高重建队列

题目&#xff1a; 题解&#xff1a; class Solution {public int[][] reconstructQueue(int[][] people) {Arrays.sort(people, new Comparator<int[]>() {public int compare(int[] person1, int[] person2) {if (person1[0] ! person2[0]) {return person2[0] - perso…

Java项目实战II基于Java+Spring Boot+MySQL的车辆管理系统(开发文档+源码+数据库)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 "随着…

Arthas jvm(查看当前JVM的信息)

文章目录 二、命令列表2.1 jvm相关命令2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 二、命令列表 2.1 jvm相关命令 2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 基础语法&#xff1a; jvm [arthas18139]$ jvmRUNTIME …

【Delphi】通过 LiveBindings Designer 链接控件示例

本教程展示了如何使用 LiveBindings Designer 可视化地创建控件之间的 LiveBindings&#xff0c;以便创建只需很少或无需源代码的应用程序。 在本教程中&#xff0c;您将创建一个高清多设备应用程序&#xff0c;该应用程序使用 LiveBindings 绑定多个对象&#xff0c;以更改圆…

十七、RC振荡电路

振荡电路 1、振荡电路的组成、作用、起振的相位条件以及振荡电路起振和平衡幅度条件&#xff0c; 2、RC电路阻抗与频率、相位与频率的关系曲线; 3、RC振荡电路的相位条件分析和振荡频率

【yolo算法打架行为检测行人检测】

yolo打架行为检测 yolo算法打架行为检测yolo行人检测 yolo算法打架行为检测 数据集和模型YOLO算法打架行为检测数据集1万数据集 分两个类别&#xff1a;正常&#xff0c;打架行为&#xff1b; train: ../train/images val: ../valid/images test: ../test/images nc: 2 names…

一次RPC调用过程是怎么样的?

注册中心 RPC&#xff08;Remote Procedure Call&#xff09;翻译成中文就是 {远程过程调用}。RPC 框架起到的作用就是为了实现&#xff0c;调用远程方法时&#xff0c;能够做到和调用本地方法一样&#xff0c;让开发人员更专注于业务开发&#xff0c;不用去考虑网络编程等细节…

演示jvm锁存在的问题

文章目录 1、AlbumInfoApiController --》testLock()2、redis添加键值对3、AlbumInfoServiceImpl --》testLock() 没有加锁4、使用ab工具测试4.1、安装 ab 工具4.2、查看 redis 中的值 5、添加本地锁 synchronized6、集群情况下问题演示 jvm锁&#xff1a;synchronized lock 只…

尚品汇-H5移动端整合系统(五十五)

目录&#xff1a; &#xff08;1&#xff09;运行前端页面 &#xff08;2&#xff09;启动前端页面 &#xff08;3&#xff09;添加搜索分类接口 &#xff08;4&#xff09;购物车模块修改 &#xff08;5&#xff09;登录模块 &#xff08;6&#xff09;订单模块 &#…

Golang | Leetcode Golang题解之第423题从英文中重建数字

题目&#xff1a; 题解&#xff1a; func originalDigits(s string) string {c : map[rune]int{}for _, ch : range s {c[ch]}cnt : [10]int{}cnt[0] c[z]cnt[2] c[w]cnt[4] c[u]cnt[6] c[x]cnt[8] c[g]cnt[3] c[h] - cnt[8]cnt[5] c[f] - cnt[4]cnt[7] c[s] - cnt[6]…

【Verilog学习日常】—牛客网刷题—Verilog快速入门—VL16

使用8线-3线优先编码器Ⅰ实现16线-4线优先编码器 描述 ②请使用2片该优先编码器Ⅰ及必要的逻辑电路实现16线-4线优先编码器。优先编码器Ⅰ的真值表和代码已给出。 可将优先编码器Ⅰ的代码添加到本题答案中&#xff0c;并例化。 优先编码器Ⅰ的代码如下&#xff1a; module…

[python]从零开始的PySide安装配置教程

一、PySide是什么&#xff1f; PySide 是 Qt for Python 项目的一部分&#xff0c;它提供了与 PyQt 类似的功能&#xff0c;使开发者能够使用 Python 编程语言来构建基于 Qt 的图形用户界面 (GUI) 应用程序。PySide 是由 Qt 公司官方维护的&#xff0c;而 PyQt 则是由第三方开发…

【Pyside】pycharm2024配置conda虚拟环境

知识拓展 Pycharm 是一个由 JetBrains 开发的集成开发环境&#xff08;IDE&#xff09;&#xff0c;它主要用于 Python 编程语言的开发。Pycharm 提供了代码编辑、调试、版本控制、测试等多种功能&#xff0c;以提高 Python 开发者的效率。 Pycharm 与 Python 的关系 Pycharm 是…

【JavaEE】——多线程(join阻塞,计算,引用,状态)

阿华代码&#xff0c;不是逆风&#xff0c;就是我疯&#xff0c;你们的点赞收藏是我前进最大的动力&#xff01;&#xff01;希望本文内容能够帮助到你&#xff01; 目录 一&#xff1a;join等待线程结束 1&#xff1a;知识回顾 2&#xff1a;join的功能就是“阻塞等待” …

java之斗地主部分功能的实现

今天我们要实现斗地主中发牌和洗牌这两个功能&#xff0c;该如何去实现呢&#xff1f; 1.创建牌类&#xff1a;52张牌每一张牌包含两个属性:牌的大小和牌的花色。 故我们优先创建一个牌的类(Card)&#xff1a;包含大小和花色。 public class Card { //单张牌的大小及类型/…

无人机+自组网:中继通信增强技术详解

无人机与自组网技术的结合&#xff0c;特别是通过中继通信增强技术&#xff0c;为无人机在复杂环境中的通信提供了稳定、高效、可靠的解决方案。以下是对该技术的详细解析&#xff1a; 一、无人机自组网技术概述 无人机自组网技术是一种利用无人机作为节点&#xff0c;通过无…

proteus仿真学习(1)

一&#xff0c;创建工程 一般选择默认模式&#xff0c;不配置pcb文件 可以选用芯片型号也可以不选 不选则从零开始布局&#xff0c;没有初始最小系统。选用则有初始最小系统以及基础的main函数 本次学习使用从零开始&#xff0c;不配置固件 二&#xff0c;上手软件 1.在元件…