pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt+FFmpeg开发视频播放器笔记(三):音视频流解析封装

音频解析 音频解码是指将压缩的音频数据转换为可以再生的PCM(脉冲编码调制)数据的过程。 FFmpeg音频解码的基本步骤如下: 初始化FFmpeg解码器(4.0版本后可省略): 调用av_register_all()初始化编解码器。 调用avcodec_register_all()注册所有编解码器。 打开输入的音频流:…

pthread_cond_signal 和pthread_cond_wait

0、pthread_join()函数作用: pthread_join() 函数会一直阻塞调用它的线程,直至目标线程执行结束(接收到目标线程的返回值),阻塞状态才会解除。如果 pthread_join() 函数成功等到了目标线程执行结束(成功获取…

【LLM】Ollama:本地大模型使用

本指南将详细介绍如何在Linux系统上使用Ollama进行本地大模型的快速部署与管理。通过Docker容器化技术,您可以轻松部署Ollama及其WebUI,实现通过浏览器访问和管理大型语言模型。 快速部署,Web访问 安装Docker 根据操作系统下载并安装 Dock…

运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。

一、问题描述 运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。 二、问题分析 在idea中,运行一个springboot项目,在使用大量的库和依赖的时候,会出现报错“命令行过长”&…

Java | Leetcode Java题解之第406题根据身高重建队列

题目&#xff1a; 题解&#xff1a; class Solution {public int[][] reconstructQueue(int[][] people) {Arrays.sort(people, new Comparator<int[]>() {public int compare(int[] person1, int[] person2) {if (person1[0] ! person2[0]) {return person2[0] - perso…

Java项目实战II基于Java+Spring Boot+MySQL的车辆管理系统(开发文档+源码+数据库)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 "随着…

不同语言的switch/case语句

多路分支选择switch/case是if/else if/else的好看版本&#xff0c;switch中default语句对应if的else语句。每个case对应汇编代码的label, 编译器插入多条jmp语句实现不同分支跳转. 条件类型 C/C条件类型必须是整型、字符或枚举&#xff0c;不能是字符串。 C#除了C/C支持的类型&…

Arthas jvm(查看当前JVM的信息)

文章目录 二、命令列表2.1 jvm相关命令2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 二、命令列表 2.1 jvm相关命令 2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 基础语法&#xff1a; jvm [arthas18139]$ jvmRUNTIME …

LangChain4j支持的API类型

本文描述了底层的大语言模型&#xff08;LLM&#xff09;API。高级的LLM API参见AI服务。 1 LLM API的类型 1.1 LanguageModel 非常简单—&#xff0c;接受一个String作为输入&#xff0c;并返回一个String作为输出。 该API现正逐渐被聊天API&#xff08;第二种API类型&…

【Delphi】通过 LiveBindings Designer 链接控件示例

本教程展示了如何使用 LiveBindings Designer 可视化地创建控件之间的 LiveBindings&#xff0c;以便创建只需很少或无需源代码的应用程序。 在本教程中&#xff0c;您将创建一个高清多设备应用程序&#xff0c;该应用程序使用 LiveBindings 绑定多个对象&#xff0c;以更改圆…

十七、RC振荡电路

振荡电路 1、振荡电路的组成、作用、起振的相位条件以及振荡电路起振和平衡幅度条件&#xff0c; 2、RC电路阻抗与频率、相位与频率的关系曲线; 3、RC振荡电路的相位条件分析和振荡频率

【数据结构】线性数据结构-顺序栈

栈&#xff08;Stack&#xff09;是一种基本的数据结构&#xff0c;具有以下特点&#xff1a; 后进先出&#xff08;LIFO, Last In First Out&#xff09;&#xff1a;栈内的数据项遵循后进先出的原则&#xff0c;即最后存入的项最先被取出。 操作限制&#xff1a;栈通常只允许…

【yolo算法打架行为检测行人检测】

yolo打架行为检测 yolo算法打架行为检测yolo行人检测 yolo算法打架行为检测 数据集和模型YOLO算法打架行为检测数据集1万数据集 分两个类别&#xff1a;正常&#xff0c;打架行为&#xff1b; train: ../train/images val: ../valid/images test: ../test/images nc: 2 names…

一次RPC调用过程是怎么样的?

注册中心 RPC&#xff08;Remote Procedure Call&#xff09;翻译成中文就是 {远程过程调用}。RPC 框架起到的作用就是为了实现&#xff0c;调用远程方法时&#xff0c;能够做到和调用本地方法一样&#xff0c;让开发人员更专注于业务开发&#xff0c;不用去考虑网络编程等细节…

项目实战 (15)--- 代码区块重构及相关技术落地

目录 背景 思想与技术方案 概述 路由及socket 封装模式 技术描述 方案 1 方案2 各类连接资源管理封装 vector db connection cache management web socket 管理 service 封装 代码实现 service 层 router层 resource层 小结 背景 到目前为止,视频搜索系统功…

演示jvm锁存在的问题

文章目录 1、AlbumInfoApiController --》testLock()2、redis添加键值对3、AlbumInfoServiceImpl --》testLock() 没有加锁4、使用ab工具测试4.1、安装 ab 工具4.2、查看 redis 中的值 5、添加本地锁 synchronized6、集群情况下问题演示 jvm锁&#xff1a;synchronized lock 只…

golang学习笔记27——golang 实现 RPC 模块

推荐学习文档 golang应用级os框架&#xff0c;欢迎stargolang应用级os框架使用案例&#xff0c;欢迎star案例&#xff1a;基于golang开发的一款超有个性的旅游计划app经历golang实战大纲golang优秀开发常用开源库汇总想学习更多golang知识&#xff0c;这里有免费的golang学习笔…

golang学习笔记3-变量的声明

声明&#xff1a;本人已有C&#xff0c;C,Python基础&#xff0c;只写本人认为的重点&#xff0c;方便自己回顾。 一、变量的三种声明方式 func main() {//方式1&#xff0c;指定数据类型&#xff0c;声明后若不赋值&#xff0c;使用默认值//比如int的默认值是0&#xff0c;st…

尚品汇-H5移动端整合系统(五十五)

目录&#xff1a; &#xff08;1&#xff09;运行前端页面 &#xff08;2&#xff09;启动前端页面 &#xff08;3&#xff09;添加搜索分类接口 &#xff08;4&#xff09;购物车模块修改 &#xff08;5&#xff09;登录模块 &#xff08;6&#xff09;订单模块 &#…

Golang | Leetcode Golang题解之第423题从英文中重建数字

题目&#xff1a; 题解&#xff1a; func originalDigits(s string) string {c : map[rune]int{}for _, ch : range s {c[ch]}cnt : [10]int{}cnt[0] c[z]cnt[2] c[w]cnt[4] c[u]cnt[6] c[x]cnt[8] c[g]cnt[3] c[h] - cnt[8]cnt[5] c[f] - cnt[4]cnt[7] c[s] - cnt[6]…