pytorch-RNN实战-正弦曲线预测

目录

  • 1. 正弦数据生成
  • 2. 构建网络
  • 3. 训练
  • 4. 预测
  • 5. 完整代码
  • 6. 结果展示

1. 正弦数据生成

曲线如下图:
在这里插入图片描述
代码如下图:

  • 50个点构成一个正弦曲线
  • 随机生成一个0~3之间的一个值(随机的原因是防止每次都从相同的点开始,50个点的正弦曲线一样,被模型记住),值的范围区间是[start, start+10]
  • 输入x范围[0,48],预测值y范围是[1,49]

在这里插入图片描述

2. 构建网络

下图是构建的网络,注意out维度扩展出一个维度,是为了和y维度一致
在这里插入图片描述

3. 训练

loss计算采用均方差MSE,优化器采用Adam
注意:hidden_prev的自更新
在这里插入图片描述

4. 预测

预测是循环一个点一个点的预测,每次预测的点的结果作为下次点的输入,直到预测出全部点,放到predictions中。
input = x[:,0,:] 去掉了x[1,seq,1]中的seq维度,变成[1,1]
在这里插入图片描述

5. 完整代码

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as pltnum_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr=0.01class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True,)for p in self.rnn.parameters():nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# [b, seq, h]out = out.view(-1, hidden_size)out = self.linear(out)out = out.unsqueeze(dim=0)return out, hidden_prevmodel = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(1, 1, hidden_size)for iter in range(6000):start = np.random.randint(3, size=1)[0]time_steps = np.linspace(start, start + 10, num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps, 1)x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()loss = criterion(output, y)model.zero_grad()loss.backward()# for p in model.parameters():#     print(p.grad.norm())# torch.nn.utils.clip_grad_norm_(p, 10)optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):input = input.view(1, 1, 1)(pred, hidden_prev) = model(input, hidden_prev)input = predpredictions.append(pred.detach().numpy().ravel()[0])x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())plt.scatter(time_steps[1:], predictions)
plt.show()

6. 结果展示

图中黄色点是预测点,蓝色为实际点,前面的曲线是start不随机预测的效果,说明曲线已经被模型记住了;后面的曲线是start随机预测的效果,基本趋势和真实点是一致的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44576.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《C++设计模式》状态模式

文章目录 一、前言二、实现一、UML类图二、实现 一、前言 状态模式理解最基本上的我觉得应该也是够用了,实际用的话,也应该用的是Boost.MSM状态机。 相关代码可以在这里,如有帮助给个star!AidenYuanDev/design_patterns_in_mode…

【PTA天梯赛】L1-005 考试座位号(15分)

作者&#xff1a;指针不指南吗 专栏&#xff1a;算法刷题 &#x1f43e;或许会很慢&#xff0c;但是不可以停下来&#x1f43e; 文章目录 题目题解try1 编译错误正确题解 总结 题目 题目链接 题解 try1 编译错误 #include<bits/stdc.h> using namespace std;typedef…

sdwan是硬件还是网络协议?

SD-WAN&#xff08;Software-Defined Wide Area Network&#xff0c;软件定义广域网&#xff09;并不是一个硬件产品或单一的网络协议&#xff0c;而是结合了软件、硬件和网络技术的一种解决方案。SD-WAN的核心在于其软件定义的特性&#xff0c;它通过软件来控制和管理广域网的…

ENSP软件中DHCP的相关配置以及终端通过域名访问服务器

新建拓扑 配置路由器网关IP 设备配置命令&#xff1a;<Huawei> Huawei部分为设备名 <>代表当下所在的模式&#xff0c;不同模式下具有不同的配置权限<Huawei> 第一级模式&#xff0c;最低级模式 查看所有参数<Huawei>system-view 键入系统视图…

探索未来:Transformer模型在智能环境监测的革命性应用

探索未来&#xff1a;Transformer模型在智能环境监测的革命性应用 在当今数字化时代&#xff0c;环境监测正逐渐从传统的人工检测方式转变为智能化、自动化的系统。Transformer模型&#xff0c;作为深度学习领域的一颗新星&#xff0c;其在自然语言处理&#xff08;NLP&#x…

鸿蒙开发:每天一个小bug----鸿蒙开发路由跳转踩坑

一、前言 报错内容显示找不到页面 &#xff0c;肯定我们页面没写对呗&#xff01; 可能是这几个原因:1.main_pages.json没配置路由 {"src": ["pages/02/UserInfoClass","pages/02/AppStorageCase02"] } 2.跳转路径没写对 错误&#xff1a;…

Excel第29享:基于sum嵌套sumifs的多条件求和

1、需求描述 如下图所示&#xff0c;现要统计12.17-12.23这一周各个人员的“上班工时&#xff08;a1&#xff09;”。 下图为系统直接导出的工时数据明细样例。 2、解决思路 首先&#xff0c;确定逻辑&#xff1a;“对多个条件&#xff08;日期、人员&#xff09;进行“工时”…

自动驾驶论文总结

预测 光栅化 代表性论文 Motion Prediction of Traffic Actors for Autonomous Driving using Deep Convolutional Networks (Uber)MultiPath (Waymo) 问题 渲染信息丢失感受野有限高计算复杂度 图神经网络 VectorMap (waymo 2020)LaneGCN (uber 2020) Transformer mm…

C#利用NPOI在已有多个Sheet的Excel中的其中一个Sheet插入或保存数据

在使用NPOI库处理Excel文件&#xff08;尤其是.xlsx文件&#xff0c;即Excel 2007及以上版本&#xff09;时&#xff0c;你可以很容易地读取、修改或向已存在的Excel文件中的特定Sheet添加数据。以下是一个基本的步骤说明和示例代码&#xff0c;展示如何在C#中使用NPOI向已包含…

[Linux][Shell][Shell数学运算]详细讲解

目录 0.前置知识1.if参数2.Unix Shell里面比较字符写法 1.算数运算符号2.常见算数运算命令0.常用算数运算命令1.双括号(())2.let命令3.expr命令4.bc命令5.中括号[]6.awk计算 0.前置知识 1.if参数 参数意义-b当file存在并且是块⽂件时返回true-c当file存在并且是字符⽂件时返回…

形态学图像处理

1 工具 1.1 灰度腐蚀和膨胀 当平坦结构元b的原点是(x,y)时&#xff0c;它在(x,y)处对图像f的灰度腐蚀定义为&#xff0c;图像f与b重合区域中的最小值。结构元b在位置(x,y)处对图像f的腐蚀写为&#xff1a; 类似地&#xff0c;当b的反射的原点是(x,y)时&#xff0c;平坦结构元…

react学习——24redux实现求和案例(精简版)

1、目录结构 2、count/index.js import React, {Component} from "react"; //引入store,用于获取数据 import store from ../../redux/store export default class Count extends Component {state {count:store.getState()}componentDidMount() {//监测redux中的…

从像素角度出发使用OpenCV检测图像是否为彩色

从像素角度出发使用OpenCV检测图像是否为彩色 使用OpenCV检测图像是否为彩色&#xff08;从像素角度出发&#xff09;引言基本概念从像素角度检测图像是否为彩色代码实现1. 读取图像2. 获取图像的形状3. 遍历图像的每个像素4. 基于RGB通道的判断测试代码 5.优化代码性能6.使用N…

传言称 iPhone 16 Pro 将支持 40W 快速充电和 20W MagSafe

目前&#xff0c;iPhone 15 和 iPhone 15 Pro 机型使用合适的 USB-C 电源适配器可实现高达 27W 的峰值充电速度&#xff0c;而 Apple 和授权第三方的官方 MagSafe 充电器可以高达 15W 的功率为 iPhone 15 机型进行无线充电。所有四款 iPhone 15 机型均可使用 20W 或更高功率的电…

PHP计件工资系统小程序源码

解锁高效管理新姿势&#xff01;全面了解计件工资系统 &#x1f525; 开篇&#xff1a;为什么计件工资系统成为企业新宠&#xff1f; 在这个效率至上的时代&#xff0c;企业如何精准激励员工&#xff0c;提升生产力成为了一大挑战。计件工资系统应运而生&#xff0c;它以其公…

golang interface指针实现

在 Go 语言中&#xff0c;接口(interface)是一种类型&#xff0c;它定义了一组方法的集合。任何实现了接口中所有方法的类型都会自动满足该接口。当涉及到指针接收者时&#xff0c;情况会稍微复杂一些&#xff0c;因为需要考虑到值接收者和指针接收者之间的区别。 下面是一个简…

【小沐学Python】在线web数据可视化Python库:Bokeh

文章目录 1、简介2、安装3、测试3.1 创建折线图3.2 添加和自定义渲染器3.3 添加图例、文本和批注3.4 自定义您的绘图3.5 矢量化字形属性3.6 合并绘图3.7 显示和导出3.8 提供和筛选数据3.9 使用小部件3.10 嵌入Bokeh图表到Flask应用程序 结语 1、简介 https://bokeh.org/ https…

算法力扣刷题记录 四十【226.翻转二叉树】

前言 继续二叉树其余操作&#xff1a; 记录 四十【226.翻转二叉树】 一、题目阅读 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 示例 1&#xff1a; 输入&#xff1a;root [4,2,7,1,3,6,9] 输出&#xff1a;[4,7,2,9,6,3,1]示例…

包管理器-npm、yarn、cnpm、pnpm的比较

1. npm (node package manage) 1.1本地安装 使用命令&#xff1a;npm install 包名 或 npm i 包名 本地安装的包出现在当前目录下的node_module目录中 如果本地安装的包带有CLI&#xff0c;npm 会将它的CLI脚本放置到node_modules/.bin下&#xff0c;使用npx命令即可调用。 …

Perl伪哈希探秘:深入理解Perl中的高级数据结构

&#x1f310; Perl伪哈希探秘&#xff1a;深入理解Perl中的高级数据结构 在Perl的世界里&#xff0c;数据结构是编程的基础。除了传统的数组和哈希&#xff0c;Perl还提供了一种特殊的数据结构——伪哈希&#xff08;Pseudo-Hashes&#xff09;。伪哈希是一种灵活的键值对集合…