RNN预测正弦时间点

 

import torch.nn as nn
import torch
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
# net = nn.RNN(100,10) #100个单词,每个单词10个维度
# print(net._parameters.keys())
#序列时间点预测num_time_steps =50
input_size =1
hidden_size =16
output_size = 1
lr=0.01
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=1,batch_first=True,  #[b,seq,feature]   batch_first=False [seq,b,feature] ,)self.linear = nn.Linear(hidden_size,output_size)def forward(self,x,hidden_prev):# hidden_prev=h0 表示最后一个Ht的输出,out是表示[h0,h1,h2,h3....]每一个时间t的输出out,hidden_prev = self.rnn(x,hidden_prev)#[1,seq,h] => [seq,h]out = out.view(-1,hidden_size)out = self.linear(out) #[seq,h] => [seq,1]out = out.unsqueeze(dim=0)  #=>[1,seq,1]return out,hidden_prevmodel =Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr)hidden_prev = torch.zeros(1,1,hidden_size) #[b,1,10]for iter in range(6000):start = np.random.randint(10,size=1)[0]time_steps = np.linspace(start,start+10,num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps,1)x = torch.tensor(data[:-1]).float().view(1,num_time_steps-1,1)y = torch.tensor(data[1:]).float().view(1,num_time_steps-1,1)output,hidden_prev = model(x,hidden_prev)hidden_prev =hidden_prev.detach()loss = criterion(output,y)model.zero_grad()loss.backward()optimizer.step()if iter%100 == 0:print("Iteration:{} loss{}".format(iter,loss.item()))predictions = []
input = x[:,0,:]
for _ in range(x.shape[1]):input = input.view(1,1,1)(pred,hidden_prev) = model(input,hidden_prev)input = predpredictions.append(pred.detach().numpy().ravel()[0])x= x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1],x.ravel(),s=90)
plt.plot(time_steps[:-1],predictions)plt.scatter(time_steps[1:],predictions)
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/732807.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

消息中间件面试题-参考回答

面试官:RabbitMQ-如何保证消息不丢失 候选人: 嗯!我们当时MYSQL和Redis的数据双写一致性就是采用RabbitMQ实现同步 的,这里面就要求了消息的高可用性,我们要保证消息的不丢失。主要从三 个层面考虑 第一个是开启生产者…

学习SVN

学习SVN 摘要1.简介2.下载安装3.SVN生命周期4.SVN Server搭建5.SVN Client使用6.git与SVN的区别 SVN 官网 Github SVN 源码 摘要 本篇博客对SVN的基础使用进行总结,以便加深理解和记忆 1.简介 SVN是Apache Subversion的缩写,是一个开源的源码版本控制…

java-ssm-jsp-基于ssm的宝文理学生社团管理系统

java-ssm-jsp-基于ssm的宝文理学生社团管理系统 获取源码——》公主号:计算机专业毕设大全

应对高并发的软件架构之道

在去年年终总结的时候,我提出了这样的困惑,究竟什么是真正的技术能力,是对于各种底层技术的钻研吗?钻研是好事,但实践下来,深入钻研并不在实际工作中有用,且钻研的越深,忘得越快&…

AIGC安全研究简述(附资料下载)

2023 AIGC技术实践及展望资料合集(29份).zip 2023 AIGC大型语言模型(LLM)实例代码合集.zip 2023大模型与AIGC峰会(公开)PPT汇总(25份).zip AIGC的安全研究是一个复杂且重要的领域,涉及多个关键…

Leetcode : 1137. 高度检查器

学校打算为全体学生拍一张年度纪念照。根据要求,学生需要按照 非递减 的高度顺序排成一行。 排序后的高度情况用整数数组 expected 表示,其中 expected[i] 是预计排在这一行中第 i 位的学生的高度(下标从 0 开始)。 给你一个整数…

一篇搞懂什么是LRU缓存|一篇搞懂LRU缓存的实现|LRUCache详解和实现

LRUCache 文章目录 LRUCache前言项目代码仓库什么时候会用到缓存(Cache)缓存满了,怎么办?什么是LRUCacheLRUCache的实现LRUCache对应的OJ题实现LRUCache对应的STL风格实现 前言 这里分享我的一些博客专栏,都是干货满满的。 手撕数据结构专栏…

什么是测试用例?如何设计?

在学习或者实际的测试工作中经常都会提到“测试用例”这个词,没错,测试用例是测试工作的核心,不管要做的是什么样的测试,在真正动手执行测试之前,我们都需要先根据软件需求来设计测试用例,之后再依据设计好…

动态加权平衡损失:深度神经网络的类不平衡学习和置信度校准

系列文章目录 文章目录 系列文章目录前言一、研究目的二、研究方法创新点处理类不平衡的大多数方法交叉熵损失函数Brier Score 三、DWB Loss总结 前言 Dynamically Weighted Balanced Loss: ClassImbalanced Learning and Confidence Calibration of Deep Neural Networks 下载…

2024年3月10日 十二生肖 今日运势

小运播报:2024年3月10日,星期日,农历二月初一 (甲辰年丁卯月癸酉日),法定节假日。 红榜生肖:龙、牛、蛇 需要注意:鸡、狗、兔 喜神方位:东南方 财神方位:…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Image)

Image为图片组件,常用于在应用中显示图片。Image支持加载PixelMap、ResourceStr和DrawableDescriptor类型的数据源,支持png、jpg、jpeg、bmp、svg、webp和gif类型的图片格式。 说明: 该组件从API Version 7开始支持。后续版本如有新增内容&am…

作业 字符数组-统计和加密

字串中数字个数 描述 输入一行字符&#xff0c;统计出其中数字字符的个数。 输入 一行字符串&#xff0c;总长度不超过255。 输出 输出为1行&#xff0c;输出字符串里面数字字符的个数。 样例 #include <iostream> #include<string.h> using namespace std; int m…

AI绘画提示词案例(宠物

目录 1. 雪地猫猫&#xff1a;1.1 提示词&#xff1a;1.2 效果&#xff1a; 2. 趴地猫猫&#xff1a;2.1 提示词&#xff1a;2.2 效果&#xff1a; 3. 长城萨摩耶&#xff1a;3.1 提示词&#xff1a;3.2 效果&#xff1a; 4. 沙发猫猫&#xff1a;4.1 提示词&#xff1a;4.2 效…

[BT]小迪安全2023学习笔记(第21天:Web攻防-JWT)

第21天 JSON Web Token&#xff08;JWT&#xff09; JWT是一种紧凑且自包含的方式&#xff0c;用于在网络上安全地传输信息作为JSON对象。这些信息可以被验证和信任&#xff0c;因为它们是数字签名的。JWT通常用于身份验证和信息交换&#xff0c;下面是一个简化的JWT示例&…

Mysql:如何自定义导出表结构

为了方便将mysql表结构信息快速录入到word或Excel表格中&#xff0c;最终实现如下效果&#xff1a; 对于word,则可将Excel表格复制粘贴即可。 废话不多少&#xff0c;开干。 准备准建&#xff1a;navicat 或sqlyog 第一步&#xff1a;编辑sql&#xff0c;如&#xff1a; SE…

P5461 赦免战俘

来自-赦免战俘 - 洛谷 参考&#xff1a;题解 P5461 【赦免战俘】 - 洛谷专栏 代码&#xff1a; #include <iostream> #include <math.h> //利用pow()函数算次方 using namespace std; int a[1500][1500]; //因为最大每边顶多有2^101024人&#xff0c;所以1500…

HTML 01

1.html使用标签来表达 结束标签多一个/ <strong>文字内容</strong> <hr> 包裹内容就是双标签&#xff0c;换行等是单标签 浏览器中显示内容&#xff1a; 2.html的骨架是网页模板 <!DOCTYPE html> <html lang"en"> <head>&l…

Full GC的认识、预防和定位

(/≧▽≦)/~┴┴ 嗨~我叫小奥 ✨✨✨ &#x1f440;&#x1f440;&#x1f440; 个人博客&#xff1a;小奥的博客 &#x1f44d;&#x1f44d;&#x1f44d;&#xff1a;个人CSDN ⭐️⭐️⭐️&#xff1a;传送门 &#x1f379; 本人24应届生一枚&#xff0c;技术和水平有限&am…

【leetcode】429. N 叉树的层序遍历

题目描述 给定一个 N 叉树&#xff0c;返回其节点值的_层序遍历_。&#xff08;即从左到右&#xff0c;逐层遍历&#xff09;。 树的序列化输入是用层序遍历&#xff0c;每组子节点都由 null 值分隔&#xff08;参见示例&#xff09;。 示例 1&#xff1a; 输入&#xff1a;…

使用Python编写简单学生管理系统

学完python基础&#xff0c;把学过的知识运用起来做一个简单的学生管理系统 1、需求分析 需求&#xff1a;进入系统显示系统功能界面&#xff0c;功能如下&#xff1a; ① 添加学员信息 ② 删除学员信息 ③ 修改学员信息 ④ 查询学员信息(只查询某个学员) ⑤ 遍历所有学…