动手学深度学习(Pytorch版)代码实践 -循环神经网络-51序列模型

51序列模型

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltT = 1000  # 总共产生1000个点
time = torch.arange(1, T + 1, dtype=torch.float32)
x = torch.sin(0.01 * time) + torch.normal(mean=0, std=0.2, size=(T,))
d2l.plot(time, [x], 'time', 'x', xlim=[1, 1000], figsize=(6, 3))
plt.show()tau = 4
features = torch.zeros((T - tau, tau)) # torch.Size([996, 4])
for i in range(tau):# features 矩阵的每一行将包含时间序列中连续 tau 个时间步的数据features[:, i] = x[i: T - tau + i]
"""
x = [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, ...]
features = [[x0, x1, x2, x3],[x1, x2, x3, x4],[x2, x3, x4, x5],[x3, x4, x5, x6],[x4, x5, x6, x7],[x5, x6, x7, x8],...
]
"""
labels = x[tau:].reshape((-1, 1))batch_size, n_train = 16, 600
# 只有前n_train个样本用于训练
train_iter = d2l.load_array((features[:n_train], labels[:n_train]),batch_size, is_train=True)# 初始化网络权重的函数
def init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)# 一个简单的多层感知机
def get_net():net = nn.Sequential(nn.Linear(4, 10),nn.ReLU(),nn.Linear(10, 1))net.apply(init_weights)return net# 平方损失。注意:MSELoss计算平方误差时不带系数1/2
loss = nn.MSELoss(reduction='none')def train(net, train_iter, loss, epochs, lr):trainer = torch.optim.Adam(net.parameters(), lr)for epoch in range(epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, 'f'loss: {d2l.evaluate_loss(net, train_iter, loss):f}')net = get_net()
train(net, train_iter, loss, 10, 0.01)
# epoch 8, loss: 0.044640
# epoch 9, loss: 0.045863
# epoch 10, loss: 0.045066# 单步预测 
onestep_preds = net(features) #对输入特征 features 进行单步预测。
d2l.plot([time, time[tau:]],[x.detach().numpy(), onestep_preds.detach().numpy()], 'time','x', legend=['data', '1-step preds'], xlim=[1, 1000],figsize=(6, 3))
"""
[time, time[tau:]]:表示 x 轴的数据。time 是完整的时间序列,time[tau:] 是从第 tau 个时间步开始的时间序列,长度为 T - tau。
[x.detach().numpy(), onestep_preds.detach().numpy()]:表示 y 轴的数据。x 是实际的时间序列数据,onestep_preds 是神经网络的预测结果。使用 detach().numpy() 将 PyTorch 张量转换为 NumPy 数组,以便绘图函数可以处理。
'time':x 轴的标签,表示时间。
'x':y 轴的标签,表示时间序列数据的值。
legend=['data', '1-step preds']:图例,分别标记实际数据和单步预测。
xlim=[1, 1000]:设置 x 轴的范围,从 1 到 1000。
figsize=(6, 3):设置图的大小,宽度为 6 英寸,高度为 3 英寸。
"""
plt.show()# 多步预测
multistep_preds = torch.zeros(T) # 用于存储多步预测的结果
multistep_preds[: n_train + tau] = x[: n_train + tau]
# 将 x 的前 n_train + tau 个元素复制到 multistep_preds 的对应位置。
# 这样做是为了在进行多步预测之前,保留训练集和前 tau 个时间步的数据。
for i in range(n_train + tau, T): # 从 n_train + tau 开始到 T,逐步进行预测。multistep_preds[i] = net(multistep_preds[i - tau:i].reshape((1, -1)))d2l.plot([time, time[tau:], time[n_train + tau:]],[x.detach().numpy(), onestep_preds.detach().numpy(),multistep_preds[n_train + tau:].detach().numpy()], 'time','x', legend=['data', '1-step preds', 'multistep preds'],xlim=[1, 1000], figsize=(6, 3))
plt.show()# 生成特征矩阵进行多步预测
max_steps = 64 # 最大预测步数为64
features = torch.zeros((T - tau - max_steps + 1, tau + max_steps))# tau 列是时间序列 x 的观测数据,而后 max_steps 列是基于前面列的预测结果。
# 列i(i<tau)是来自x的观测,其时间步从(i)到(i+T-tau-max_steps+1)
for i in range(tau):features[:, i] = x[i: i + T - tau - max_steps + 1]
# 列i(i>=tau)是来自(i-tau+1)步的预测,其时间步从(i)到(i+T-tau-max_steps+1)
for i in range(tau, tau + max_steps):features[:, i] = net(features[:, i - tau:i]).reshape(-1)steps = (1, 4, 16, 64)
d2l.plot([time[tau + i - 1: T - max_steps + i] for i in steps],[features[:, (tau + i - 1)].detach().numpy() for i in steps], 'time', 'x',legend=[f'{i}-step preds' for i in steps], xlim=[5, 1000],figsize=(6, 3))
plt.show()"""
1-step 预测:每一步预测只预测下一个时间步的数据。模型每次使用的是最近的观测数据进行预测。
4-step 预测:每一步预测预测接下来的四个时间步的数据。模型需要预测四步后的数据。
16-step 预测:每一步预测预测接下来的十六个时间步的数据。模型需要预测更远的未来数据。
64-step 预测:每一步预测预测接下来的六十四个时间步的数据。模型需要预测很远的未来数据。
"""

单步预测:
在这里插入图片描述

多步预测:
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/41542.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ERROR | Web server failed to start. Port 8080 was already in use.

错误提示&#xff1a; *************************** APPLICATION FAILED TO START ***************************Description:Web server failed to start. Port 8080 was already in use.Action:Identify and stop the process thats listening on port 8080 or configure thi…

2024年特种设备(电梯作业)题库考试题库

1.直接作用式液压电梯轿厢与柱塞&#xff08;缸筒&#xff09;之间的连接应为&#xff08;&#xff09;。 A.刚性连接 B.固定连接 C.法兰连接 D.挠性连接 答案&#xff1a;D 2.正常情况下&#xff0c;当电磁式继电器线圈得电时&#xff0c;其常开触点将&#xff08;&…

二进制求和、字符串相加-sting类题型

67. 二进制求和 - 力扣&#xff08;LeetCode&#xff09; 两个题目方法完全一样 用两个数据的末尾位相加&#xff0c;从末尾位开始逐位相加&#xff0c;记录进位&#xff1b; class Solution { public:string addBinary(string a, string b) {int end1 a.size() - 1;int end…

ASUS/华硕飞行堡垒9 FX506H FX706H系列 原厂win10系统 工厂文件 带F12 ASUS Recovery恢复

华硕工厂文件恢复系统 &#xff0c;安装结束后带隐藏分区&#xff0c;一键恢复&#xff0c;以及机器所有驱动软件。 系统版本&#xff1a;Windows10 原厂系统下载网址&#xff1a;http://www.bioxt.cn 需准备一个20G以上u盘进行恢复 请注意&#xff1a;仅支持以上型号专用…

昇思25天学习打卡营第18天|Pix2Pix实现图像转换

Pix2Pix概述 Pix2Pix是基于条件生成对抗网络实现的一种深度学习图像转换模型。Pix2Pix是将cGAN应用于有监督的图像到图像翻译&#xff0c;包括生成器和判别器。 基础原理 cGAN的生成器是将输入图片作为指导信息&#xff0c;由输入图像不断尝试生成用于迷惑判别器的“假”图像…

实现沉浸式体验的秘诀:深入了解折幕投影技术!

在当今多媒体技术的浪潮中&#xff0c;投影技术已蜕变成为超越传统内容展示范畴的非凡工具&#xff0c;它深度融合了互动性与沉浸感&#xff0c;成为连接观众与虚拟世界的桥梁。折幕投影技术&#xff0c;作为这一领域的璀璨明珠&#xff0c;更是以其独特而神奇的手法&#xff0…

lua入门(2) - 数据类型

前言 本文参考自: Lua 数据类型 | 菜鸟教程 (runoob.com) 希望详细了解的小伙伴还请查看上方链接: 八个基本类型 type - 函数查看数据类型: 测试程序: print(type("Hello world")) --> string print(type(10.4*3)) --> number print(t…

WEB安全-靶场

1 需求 2 语法 3 示例 男黑客|在线渗透测试靶场|网络安全培训基地|男黑客安全网 4 参考资料

中英双语介绍伦敦大学学院(University College London,UCL)

中文版 伦敦大学学院&#xff08;UCL&#xff09;简介 位置和周边环境 伦敦大学学院&#xff08;University College London&#xff0c;简称UCL&#xff09;位于英国伦敦市中心的布卢姆斯伯里&#xff08;Bloomsbury&#xff09;区。具体地址为&#xff1a; Gower Street, …

C语言 -- 扫雷游戏

C语言 – 扫雷游戏 游戏规则&#xff1a; 给定一个棋盘&#xff0c;玩家需要排查出所有隐藏的雷&#xff0c;也就是选择出所有不是雷的地方。 玩家选择位置&#xff0c;若此处有雷&#xff0c;玩家被炸死&#xff0c;游戏结束&#xff1b; 若此处无雷&#xff0c;此处提示周围一…

12.SQL注入-盲注基于时间(base on time)

SQL注入-盲注基于时间(base on time) boolian的盲注类型还有返回信息的状态&#xff0c;但是基于时间的盲注就什么都没有返回信息。 输入payload语句进行睡5秒中&#xff0c;通过开发这工具查看时间&#xff0c;如图所示&#xff0c;会在5秒钟后在执行&#xff0c;因此存在基于…

基于Java技术的篮球论坛系统

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言 Java 数据库 MySQL 技术 B/S模式、Java技术 工具 Visual Studio、MySQL数据库开发工具 系统展示 首页 用户注册界面 篮球论坛界面 个人中心界面 摘要 本…

Vite: 近几个版本的更新

概述 在 2021 年 2 月&#xff0c;尤大正式推出了 Vite 2.0 版本&#xff0c;可以说是 Vite 的一个重要转折点&#xff0c;自此之后 Vite 的用户量发生了非常迅速的增长&#xff0c;很快达到了每周 100 万的 npm 下载量。同时&#xff0c;Vite 的社区也越来越活跃&#xff0c;…

机器学习原理之 -- XGboost原理详解

XGBoost&#xff08;eXtreme Gradient Boosting&#xff09;是近年来在数据科学和机器学习领域中广受欢迎的集成学习算法。它在多个数据科学竞赛中表现出色&#xff0c;被广泛应用于各种机器学习任务。本文将详细介绍XGBoost的由来、基本原理、算法细节、优缺点及应用场景。 X…

14-38 剑和诗人12 - RAG+ 思维链 ⇒ 检索增强思维(RAT)

在快速发展的 NLP 和 LLM 领域&#xff0c;研究人员不断探索新技术来增强这些模型的功能。其中一种备受关注的技术是检索增强生成 (RAG) 方法&#xff0c;它将 LLM 的生成能力与从外部来源检索相关信息的能力相结合。然而&#xff0c;最近一项名为检索增强思维 (RAT) 的创新通过…

Go基础知识

目标 简单介绍一下 GO 语言的诞生背景&#xff0c;使用场景&#xff0c;目前使用方案简单介绍一下 GO的使用&#xff0c;GO的基础语法&#xff0c;简单过一下一些GO的语言例子着重介绍一下GO的特性&#xff0c;为什么大家都使用GO语言&#xff0c;GO的内存结构、为什么都说GO快…

【笔记】记一次在linux上通过在线安装mysql报错 CentOS 7 的官方镜像已经不再可用的解决方法+mysql配置

报错&#xff08;恨恨恨恨恨恨恨&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff09;&#xff1a; [rootlocalhost ~]# sudo yum install mysql-server 已加载插件&#xff1a;fastestmirror, langpacks Determining fastest mirrors Could not retrie…

Unity中使用VectorGraphics插件时,VectorUtils.RenderSpriteToTexture2D方法返回结果错误的解决方法

Unity中使用VectorGraphics插件时&#xff0c;如果使用VectorUtils.BuildSprite方法创建Sprite&#xff0c;那么得到的Sprite往往是一个三角网格数比较多的Sprite&#xff0c;如果想要得到使用贴图只有两个三角面的方形Sprite&#xff0c;可以使用该插件提供的VectorUtils.Rend…

数据库概念题总结

1、 2、简述数据库设计过程中&#xff0c;每个设计阶段的任务 需求分析阶段&#xff1a;从现实业务中获取数据表单&#xff0c;报表等分析系统的数据特征&#xff0c;数据类型&#xff0c;数据约束描述系统的数据关系&#xff0c;数据处理要求建立系统的数据字典数据库设计…

ctfshow-web入门-文件包含(web82-web86)条件竞争实现session会话文件包含

目录 1、web82 2、web83 3、web84 4、web85 5、web86 1、web82 新增过滤点 . &#xff0c;查看提示&#xff1a;利用 session 对话进行文件包含&#xff0c;通过条件竞争实现。 条件竞争这个知识点在文件上传、不死马利用与查杀这些里面也会涉及&#xff0c;如果大家不熟悉…