PyTorch -- RNN 快速实践

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachloss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())#############################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"p.grad.nomr()torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/30434.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用密钥对登录服务器

目录 1、使用密钥文件登录服务器 2、登录成功画面: 3、如若出现以下状况,则说明密钥文件登录失败 1、使用密钥文件登录服务器 首先需要上传pem文件 2、登录成功画面: 3、如若出现以下状况,则说明密钥文件登录失败 解决方法&…

嵌入式技术学习——Linux环境编程(高级编程)——shell编程

一、shell编程的基础介绍 1.为什么要进行shell编程? 在Linux系统中,虽然有各种各样的图形化接口工具,但是shell仍然是一个非常灵活的 工具。 Shell不仅仅是命令的收集,而且是一门非常棒的编程语言。 您可以通过使用shell使大量的任务自动化…

mfc140.dll电脑文件丢失的处理方法,这4种方法能快速修复mfc140.dll

mfc140.dll文件是一个非常重要的dll文件,如果它丢失了,那么会严重的影响程序的运行,这时候我们要找方法去修复mfc140.dll这个文件,那么你知道怎么修复么?如果不知道,那么不妨看看下面的mfc140.dll文件丢失的…

【DAMA】掌握数据管理核心:CDGA考试指南

引言:        在当今快速发展的数字化世界中,数据已成为组织最宝贵的资产之一。有效的数据管理不仅能够驱动业务决策,还能提升竞争力和市场适应性。DAMA国际一直致力于数据管理和数字化的研究、实践及相关知识体系的建设。秉承公益、志愿…

集合系列(二十六) -利用LinkedHashMap实现一个LRU缓存

一、什么是 LRU LRU是 Least Recently Used 的缩写,即最近最少使用,是一种常用的页面置换算法,选择最近最久未使用的页面予以淘汰。 简单的说就是,对于一组数据,例如:int[] a {1,2,3,4,5,6},…

SpringBoot配置第三方专业缓存技术Ehcache

Ehcache缓存技术 我们刚才是用Springboot提供的默认缓存技术 我们用的是simple 是一个内存级的缓存 我们接下来要使用专业的缓存技术了 Ehcache 是一个流行的开源 Java 分布式缓存,由 Terracotta 公司开发和维护。它提供了一个快速、可扩展、易于集成的内存缓存…

如何制定适合不同行业的新版FMEA培训计划?

在快速变化的市场环境中,失效模式与影响分析(FMEA)作为一种预防性的质量控制工具,越来越受到企业的重视。然而,不同行业在FMEA应用上存在着明显的差异,因此制定适合不同行业的新版FMEA培训计划显得尤为重要…

Sui主网升级至V1.27.2版本

其他升级要点如下所示: 重点: #17245 增加了一个新的协议版本,并在开发网络上启用了Move枚举。 JSON-RPC #17245: 在返回的JSON-RPC结果中增加了对Move枚举值的支持。 GraphQL #17245: 增加了对Move枚举值和类型的支持。 CLI #179…

明基的台灯值得入手吗?书客、柏曼真实横向测评对比

如今,近视问题在人群中愈发凸显,据2024年的最新统计数据揭示,我国儿童青少年的近视率已经飙升至惊人的52.7%。在学业日益繁重的背景下,学生们的视力健康成为了社会各界关注的焦点。近视不仅影响视力,还可能引发一系列严…

LeetCode80. 删除有序数组中的重复项 II题解

LeetCode80. 删除有序数组中的重复项 II题解 题目链接: https://leetcode.cn/problems/remove-duplicates-from-sorted-array-ii/ 题目描述: 给你一个有序数组 nums ,请你 原地 删除重复出现的元素,使得出现次数超过两次的元素…

渲染农场深度解析:原理理解、配置要点与高效使用策略

许多设计领域的新手可能对“渲染农场”这一概念感到陌生。渲染农场是一种强大的计算资源集合,它通过高性能的CPU和GPU以及专业的渲染引擎,为设计项目提供必要的渲染支持。这种平台由多台计算机或渲染节点组成,形成一个分布式网络,…

从零基础到学完CCIE要多久?

思科认证的CCIE是网络工程师追求的顶级认证之一。 对于刚入门的初学者来说,从零基础到通过CCIE认证,这条路需要多长时间? 这个问题的答案因人而异,取决于多种因素。 这不仅是一个关于时间的问题,更是一个关于规划、学习…

操作系统真象还原:输入输出系统

第10章-输入输出系统 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 10.1 同步机制–锁 10.1.1 排查GP异常,理解原子操作 线程调度工作的核心内容就是线程的上下文保护+上下文恢复 。 根本原因是访问公共资源需要多个操作&…

【教学类-64-04】20240619彩色鱼骨图(一)6.5*1CM 6根棒子720种

背景需求: 幼儿益智早教玩具❗️鱼骨拼图 - 小红书在家也能自制的木棒鱼骨拼图,你也收藏起来试一试吧。 #母婴育儿 #新手爸妈 #玩具 #宝宝玩具怎么选 #早教 #早教玩具 #幼儿早教 #益智早教 #玩具 #宝宝早教 #益智拼图 #宝宝拼图 #玩不腻的益智玩具 #儿童…

vscode插件开发之 - Treeview视图

一些测试类插件,往往需要加载测试文件,并执行这些测试文件。以playwright vscode为例,该插件可以显示目录下所有的测试文件。如下图所示,显示了tests目录下的所有xxx.spec.js文件,那么如何在vscode插件开发中显示TreeV…

[Python学习篇] Python公共操作

公共运算符 运算符描述支持的容器类型合并字符串、列表、元组*复制字符串、列表、元组in元素是否存在字符串、列表、元组、字典not in元素是否不存在字符串、列表、元组、字典 示例: 字符串 str1 ab str2 cd print(str1 str2) # abcd print(str1 * 3) # ab…

Go语言day1

下载go语言的安装程序: All releases - The Go Programming Language 配置go语言的环境变量: 写第一个go语言 在E:\go_workspace当前窗口使用cmd命令: 输入 go run test.go

炭熄卡顿、延迟高、联机报错的解决方法一览

炭熄在制作中巧妙地结合了程序随机生成的元素,为玩家呈现出了一个充满未知与惊险的开放世界,是一款独具匠心的中式民俗恐怖题材游戏。在这款游戏中,玩家将化身为一位意外闯入村子的青年,面对种种鬼怪、努力活下来。游戏将于6月24日…

分页插件结合collection标签后分页数量不准确的问题

问题1:不使用collection 聚合分页正确 简单列子 T_ATOM_DICT表有 idname1原子12原子23原子34原子45原子56原子6 T_ATOM_DICT_AUDIT_ROUTE表审核记录表有 idaudit1拒绝1通过4拒绝 我要显示那些原子审核了,我把两个表inner join 就是那些原子审核过了 idnameaudit1原子1拒绝…

iOS原生APP开发的技术难点

iOS原生APP开发的技术难点主要体现在以下几个方面,总而言之,iOS原生APP开发是一项技术难度较高的工作,需要开发者具备扎实的编程基础、丰富的开发经验和良好的学习能力。北京木奇移动技术有限公司,专业的软件外包开发公司&#xf…