pytorch-RNN实战-正弦曲线预测

目录

  • 1. 正弦数据生成
  • 2. 构建网络
  • 3. 训练
  • 4. 预测
  • 5. 完整代码
  • 6. 结果展示

1. 正弦数据生成

曲线如下图:
在这里插入图片描述
代码如下图:

  • 50个点构成一个正弦曲线
  • 随机生成一个0~3之间的一个值(随机的原因是防止每次都从相同的点开始,50个点的正弦曲线一样,被模型记住),值的范围区间是[start, start+10]
  • 输入x范围[0,48],预测值y范围是[1,49]

在这里插入图片描述

2. 构建网络

下图是构建的网络,注意out维度扩展出一个维度,是为了和y维度一致
在这里插入图片描述

3. 训练

loss计算采用均方差MSE,优化器采用Adam
注意:hidden_prev的自更新
在这里插入图片描述

4. 预测

预测是循环一个点一个点的预测,每次预测的点的结果作为下次点的输入,直到预测出全部点,放到predictions中。
input = x[:,0,:] 去掉了x[1,seq,1]中的seq维度,变成[1,1]
在这里插入图片描述

5. 完整代码

import  numpy as np
import  torch
import  torch.nn as nn
import  torch.optim as optim
from    matplotlib import pyplot as pltnum_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr=0.01class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True,)for p in self.rnn.parameters():nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# [b, seq, h]out = out.view(-1, hidden_size)out = self.linear(out)out = out.unsqueeze(dim=0)return out, hidden_prevmodel = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(1, 1, hidden_size)for iter in range(6000):start = np.random.randint(3, size=1)[0]time_steps = np.linspace(start, start + 10, num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps, 1)x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()loss = criterion(output, y)model.zero_grad()loss.backward()# for p in model.parameters():#     print(p.grad.norm())# torch.nn.utils.clip_grad_norm_(p, 10)optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):input = input.view(1, 1, 1)(pred, hidden_prev) = model(input, hidden_prev)input = predpredictions.append(pred.detach().numpy().ravel()[0])x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())plt.scatter(time_steps[1:], predictions)
plt.show()

6. 结果展示

图中黄色点是预测点,蓝色为实际点,前面的曲线是start不随机预测的效果,说明曲线已经被模型记住了;后面的曲线是start随机预测的效果,基本趋势和真实点是一致的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44576.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《C++设计模式》状态模式

文章目录 一、前言二、实现一、UML类图二、实现 一、前言 状态模式理解最基本上的我觉得应该也是够用了,实际用的话,也应该用的是Boost.MSM状态机。 相关代码可以在这里,如有帮助给个star!AidenYuanDev/design_patterns_in_mode…

【PTA天梯赛】L1-005 考试座位号(15分)

作者&#xff1a;指针不指南吗 专栏&#xff1a;算法刷题 &#x1f43e;或许会很慢&#xff0c;但是不可以停下来&#x1f43e; 文章目录 题目题解try1 编译错误正确题解 总结 题目 题目链接 题解 try1 编译错误 #include<bits/stdc.h> using namespace std;typedef…

sdwan是硬件还是网络协议?

SD-WAN&#xff08;Software-Defined Wide Area Network&#xff0c;软件定义广域网&#xff09;并不是一个硬件产品或单一的网络协议&#xff0c;而是结合了软件、硬件和网络技术的一种解决方案。SD-WAN的核心在于其软件定义的特性&#xff0c;它通过软件来控制和管理广域网的…

ENSP软件中DHCP的相关配置以及终端通过域名访问服务器

新建拓扑 配置路由器网关IP 设备配置命令&#xff1a;<Huawei> Huawei部分为设备名 <>代表当下所在的模式&#xff0c;不同模式下具有不同的配置权限<Huawei> 第一级模式&#xff0c;最低级模式 查看所有参数<Huawei>system-view 键入系统视图…

鸿蒙开发:每天一个小bug----鸿蒙开发路由跳转踩坑

一、前言 报错内容显示找不到页面 &#xff0c;肯定我们页面没写对呗&#xff01; 可能是这几个原因:1.main_pages.json没配置路由 {"src": ["pages/02/UserInfoClass","pages/02/AppStorageCase02"] } 2.跳转路径没写对 错误&#xff1a;…

Excel第29享:基于sum嵌套sumifs的多条件求和

1、需求描述 如下图所示&#xff0c;现要统计12.17-12.23这一周各个人员的“上班工时&#xff08;a1&#xff09;”。 下图为系统直接导出的工时数据明细样例。 2、解决思路 首先&#xff0c;确定逻辑&#xff1a;“对多个条件&#xff08;日期、人员&#xff09;进行“工时”…

形态学图像处理

1 工具 1.1 灰度腐蚀和膨胀 当平坦结构元b的原点是(x,y)时&#xff0c;它在(x,y)处对图像f的灰度腐蚀定义为&#xff0c;图像f与b重合区域中的最小值。结构元b在位置(x,y)处对图像f的腐蚀写为&#xff1a; 类似地&#xff0c;当b的反射的原点是(x,y)时&#xff0c;平坦结构元…

react学习——24redux实现求和案例(精简版)

1、目录结构 2、count/index.js import React, {Component} from "react"; //引入store,用于获取数据 import store from ../../redux/store export default class Count extends Component {state {count:store.getState()}componentDidMount() {//监测redux中的…

传言称 iPhone 16 Pro 将支持 40W 快速充电和 20W MagSafe

目前&#xff0c;iPhone 15 和 iPhone 15 Pro 机型使用合适的 USB-C 电源适配器可实现高达 27W 的峰值充电速度&#xff0c;而 Apple 和授权第三方的官方 MagSafe 充电器可以高达 15W 的功率为 iPhone 15 机型进行无线充电。所有四款 iPhone 15 机型均可使用 20W 或更高功率的电…

PHP计件工资系统小程序源码

解锁高效管理新姿势&#xff01;全面了解计件工资系统 &#x1f525; 开篇&#xff1a;为什么计件工资系统成为企业新宠&#xff1f; 在这个效率至上的时代&#xff0c;企业如何精准激励员工&#xff0c;提升生产力成为了一大挑战。计件工资系统应运而生&#xff0c;它以其公…

【小沐学Python】在线web数据可视化Python库:Bokeh

文章目录 1、简介2、安装3、测试3.1 创建折线图3.2 添加和自定义渲染器3.3 添加图例、文本和批注3.4 自定义您的绘图3.5 矢量化字形属性3.6 合并绘图3.7 显示和导出3.8 提供和筛选数据3.9 使用小部件3.10 嵌入Bokeh图表到Flask应用程序 结语 1、简介 https://bokeh.org/ https…

算法力扣刷题记录 四十【226.翻转二叉树】

前言 继续二叉树其余操作&#xff1a; 记录 四十【226.翻转二叉树】 一、题目阅读 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 示例 1&#xff1a; 输入&#xff1a;root [4,2,7,1,3,6,9] 输出&#xff1a;[4,7,2,9,6,3,1]示例…

CAS介绍

CAS是计算机科学中的一个概念&#xff0c;全称是Compare-And-Swap&#xff08;比较并交换&#xff09;&#xff0c;它是一种原子操作&#xff0c;用于多线程环境下的同步机制。在Java中&#xff0c;你可以使用java.util.concurrent.atomic包下的类&#xff0c;如AtomicInteger来…

绝对值不等式运用(C++)

货仓选址 用数学公式表达题意&#xff0c;假设有位置a1~an,假设选址在x位置处&#xff0c;则有&#xff1a; 如何让这个最小&#xff0c;我们把两个式子整合一下&#xff0c;利用绝对值不等式&#xff1a; 我们知道&#xff1a; 如下图所示&#xff1a;到A&#xff0c;B两点&…

用python生成词频云图(python实例二十一)

目录 1.认识Python 2.环境与工具 2.1 python环境 2.2 Visual Studio Code编译 3.词频云图 3.1 代码构思 3.2 代码实例 3.3 运行结果 4.总结 1.认识Python Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强的可读性&a…

[ICS] Inferno(地狱) ETH/IP未授权访问,远程控制工控设备利用工具

项目地址:https://github.com/MartinxMax/Inferno Inferno $ ./Install.sh $ python Inferno.py -h 模拟服务端 $ sudo python3 -m pip install --upgrade cpppo $ $ python -m cpppo.server.enip SCADAINT[1000] ADMININT[2] -v 创建一个EtherNet/IP设备 扫描设备 $ pyth…

QT--SQLite

配置类相关的表&#xff0c;所以我使用sqlite,且QT自带该组件&#xff1b; 1.安装 sqlite-tools-win-x64-3460000、SQLiteExpert5.4.31.575 使用SQLiteExpert建好数据库.db文件&#xff0c;和对应的表后把db文件放在指定目录 ./db/program.db&#xff1b; 2.选择sql组件 3.新…

YOLOv10改进 | Conv篇 | 全新的SOATA轻量化下采样操作ADown(参数量下降百分之二十,附手撕结构图)

一、本文介绍 本文给大家带来的改进机制是利用2024/02/21号最新发布的YOLOv9其中提出的ADown模块来改进我们的Conv模块&#xff0c;其中YOLOv9针对于这个模块并没有介绍&#xff0c;只是在其项目文件中用到了&#xff0c;我将其整理出来用于我们的YOLOv10的项目&#xff0c;经…

【人工智能】-- 反向传播

个人主页&#xff1a;欢迎来到 Papicatch的博客 课设专栏 &#xff1a;学生成绩管理系统 专业知识专栏&#xff1a; 专业知识 文章目录 &#x1f349;引言 &#x1f349;反向传播 &#x1f348;定义 &#x1f348;反向传播的作用 &#x1f34d;参数优化 &#x1f34d;学…

Qt Creator仿Visual Studio黑色主题

转自本人博客&#xff1a;Qt Creator仿Visual Studio黑色主题 1.演示 配置文件和步骤在后面&#xff0c;先看成品&#xff0c;分别是QWidget和QML的代码编写界面&#xff1a; 2. 主题配置文件 下载链接&#xff1a;QtCreator _theme_VS_dark.xml 也可以自己新建一个xml文件&…