PyTorch -- RNN 快速实践

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachloss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())#############################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"p.grad.nomr()torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/30434.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用密钥对登录服务器

目录 1、使用密钥文件登录服务器 2、登录成功画面: 3、如若出现以下状况,则说明密钥文件登录失败 1、使用密钥文件登录服务器 首先需要上传pem文件 2、登录成功画面: 3、如若出现以下状况,则说明密钥文件登录失败 解决方法&…

嵌入式技术学习——Linux环境编程(高级编程)——shell编程

一、shell编程的基础介绍 1.为什么要进行shell编程? 在Linux系统中,虽然有各种各样的图形化接口工具,但是shell仍然是一个非常灵活的 工具。 Shell不仅仅是命令的收集,而且是一门非常棒的编程语言。 您可以通过使用shell使大量的任务自动化…

Django:如何将多个数据表内容合在一起返回响应

一.概要 Django写后端返回响应时,通常需要返回的可能不是一个数据表的内容,还包括了这个数据表的外键所关联的其他表的一些字段,那该如何做才能把他们放在一起返回响应呢? 二.处理方法 在这里我有三个数据表 第一个是航空订单&…

内聚性越高,模块独立性越强

内聚性(Cohesion)是衡量模块内部元素彼此关联程度的指标,而模块独立性(Coupling)则是指模块之间相互依赖的程度。这两个概念在软件工程中是评估设计质量的重要标准。 ### 内聚性: - **高内聚性**意味着模块…

内核学习——0、内核各类机制

1、应用读取驱动四种基本方式:阻塞、非阻塞、poll、异步通知 驱动构造file_operation结构体,里面有open、read、wirte等函数 查询:相当于应用程序非阻塞方式, O_NONBLOCK 休眠–唤醒:相当于应用程序阻塞方式 poll方式…

mfc140.dll电脑文件丢失的处理方法,这4种方法能快速修复mfc140.dll

mfc140.dll文件是一个非常重要的dll文件,如果它丢失了,那么会严重的影响程序的运行,这时候我们要找方法去修复mfc140.dll这个文件,那么你知道怎么修复么?如果不知道,那么不妨看看下面的mfc140.dll文件丢失的…

【DAMA】掌握数据管理核心:CDGA考试指南

引言:        在当今快速发展的数字化世界中,数据已成为组织最宝贵的资产之一。有效的数据管理不仅能够驱动业务决策,还能提升竞争力和市场适应性。DAMA国际一直致力于数据管理和数字化的研究、实践及相关知识体系的建设。秉承公益、志愿…

集合系列(二十六) -利用LinkedHashMap实现一个LRU缓存

一、什么是 LRU LRU是 Least Recently Used 的缩写,即最近最少使用,是一种常用的页面置换算法,选择最近最久未使用的页面予以淘汰。 简单的说就是,对于一组数据,例如:int[] a {1,2,3,4,5,6},…

git从master分支创建分支

1. 切换到主分支或你想从哪里创建新分支 git checkout master 2. 创建并切换到新的本地分支 develop git checkout -b develop 3. 将新分支推送到远程存储库 git push origin develop 4. 设置本地 develop 分支跟踪远程 develop 分支 git branch --set-upstream-toorigi…

Clickhouse Projection

背景 Clickhouse一个视图本质还是表,只支持一种order By,不然要维护太多的视图。 物化视图能力有限。 在设计聚合功能时,考虑使用AggregatingMergeTree表引擎,现在有了projections,打算尝试使用一下 操作 ADD PROJE…

利用冲激平衡法,设冲激响应h(t)的形式(通过求特征根 再转 齐次方程形式)

让我们详细解释一下所谓的“冲激平衡法”(或“冲激响应法”)以及为什么在这个方法中假设冲激响应 ( h(t) ) 的形式为特定的指数函数组合是合理的。 冲激平衡法的基本思想 冲激平衡法的基本思想是通过假设冲激响应 ( h(t) ) 的特定形式,并将…

项目经理真的不能太“拧巴”

前期的项目经理经常是“拧巴”的,就是心里纠结、思路混乱、行动迟缓。对于每天需要面对各种挑战、协调各方资源、确保项目顺利进行的项目经理来说,这种“拧巴”不仅会让自己陷入内耗中,还会让项目出大问题。 项目计划总是改来改去&#xff0…

编程奇境:C++之旅,从新手村到ACM/OI算法竞赛大门(中级武器:并查集)

我们都知道,朋友的朋友也可以是朋友,并查集就是这么一种武器,能够让自己广交天下之友。 并查集 并查集啊,想象一下你班上的同学们都在操场上自由活动。突然老师说:“大家找朋友手拉手围成圈玩个游戏!”这…

SpringBoot配置第三方专业缓存技术Ehcache

Ehcache缓存技术 我们刚才是用Springboot提供的默认缓存技术 我们用的是simple 是一个内存级的缓存 我们接下来要使用专业的缓存技术了 Ehcache 是一个流行的开源 Java 分布式缓存,由 Terracotta 公司开发和维护。它提供了一个快速、可扩展、易于集成的内存缓存…

LeetCode 每日一题 2748. 美丽下标对的数目

Hey编程小伙伴们👋,今天我要带大家一起解锁力扣上的一道有趣题目—— 美丽下标对的数目 - 力扣 (LeetCode)。这不仅是一次编程挑战,更是一次深入理解欧几里得算法判断互质的绝佳机会!🎉 问题简介 题目要求我们给定一…

如何制定适合不同行业的新版FMEA培训计划?

在快速变化的市场环境中,失效模式与影响分析(FMEA)作为一种预防性的质量控制工具,越来越受到企业的重视。然而,不同行业在FMEA应用上存在着明显的差异,因此制定适合不同行业的新版FMEA培训计划显得尤为重要…

Sui主网升级至V1.27.2版本

其他升级要点如下所示: 重点: #17245 增加了一个新的协议版本,并在开发网络上启用了Move枚举。 JSON-RPC #17245: 在返回的JSON-RPC结果中增加了对Move枚举值的支持。 GraphQL #17245: 增加了对Move枚举值和类型的支持。 CLI #179…

kubernetes node 节点管理

kubernetes node 节点管理 1 查看集群信息 kubectl cluster-info 2 查看节点信息 2.1 查看node信息 kubectl get nodes 2.2 查看node细致信息 kubectl get nodes -o wide 2.3 查看node描述详细信息 kubectl describe node <node-name> 2.4 查看节点资源使用情况…

明基的台灯值得入手吗?书客、柏曼真实横向测评对比

如今&#xff0c;近视问题在人群中愈发凸显&#xff0c;据2024年的最新统计数据揭示&#xff0c;我国儿童青少年的近视率已经飙升至惊人的52.7%。在学业日益繁重的背景下&#xff0c;学生们的视力健康成为了社会各界关注的焦点。近视不仅影响视力&#xff0c;还可能引发一系列严…

LeetCode80. 删除有序数组中的重复项 II题解

LeetCode80. 删除有序数组中的重复项 II题解 题目链接&#xff1a; https://leetcode.cn/problems/remove-duplicates-from-sorted-array-ii/ 题目描述&#xff1a; 给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素…