时间序列LSTM实现

这个代码参考了时间序列预测模型实战案例(三)(LSTM)(Python)(深度学习)时间序列预测(包括运行代码以及代码讲解)_lstm预测模型-CSDN博客

结合我之前所学的lstm-seq2seq里所学习到的知识对其进行预测

import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScalernp.random.seed(0)def calculate_mae(y_true, y_pred):# 平均绝对误差mae = np.mean(np.abs(y_true - y_pred))return maetrue_data = pd.read_csv(r"C:\Users\33746\Desktop\DailyDelhiClimateTrain.csv")  # 填你自己的数据地址target = 'meanpressure'# 这里加一些数据的预处理, 最后需要的格式是pd.seriestrue_data = np.array(true_data['meanpressure'])# 定义窗口大小
test_data_size = 32
# 训练集和测试集的尺寸划分
test_size = 0.15
train_size = 0.85
# 标准化处理
scaler_train = MinMaxScaler(feature_range=(0, 1))
scaler_test = MinMaxScaler(feature_range=(0, 1))
train_data = true_data[:int(train_size * len(true_data))]
test_data = true_data[-int(test_size * len(true_data)):]
print("训练集尺寸:", len(train_data))
print("测试集尺寸:", len(test_data))
train_data_normalized = scaler_train.fit_transform(train_data.reshape(-1, 1))
test_data_normalized = scaler_test.fit_transform(test_data.reshape(-1, 1))
# 转化为深度学习模型需要的类型Tensor
train_data_normalized = torch.FloatTensor(train_data_normalized).view(-1)
test_data_normalized = torch.FloatTensor(test_data_normalized).view(-1)def create_inout_sequences(input_data, tw, pre_len):inout_seq = []L = len(input_data)for i in range(L - tw):train_seq = input_data[i:i + tw]if (i + tw + 4) > len(input_data):breaktrain_label = input_data[i + tw:i + tw + pre_len]inout_seq.append((train_seq, train_label))return inout_seqpre_len = 4
train_window = 16
# 定义训练器的的输入
train_inout_seq = create_inout_sequences(train_data_normalized, train_window, pre_len)class LSTM(nn.Module):def __init__(self, input_dim=1, hidden_dim=350, output_dim=1):super(LSTM, self).__init__()self.hidden_dim = hidden_dimself.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.unsqueeze(1)h0_lstm = torch.zeros(1, self.hidden_dim).to(x.device)c0_lstm = torch.zeros(1, self.hidden_dim).to(x.device)out, _ = self.lstm(x, (h0_lstm, c0_lstm))out = out[:, -1]out = self.fc(out)return outlstm_model = LSTM(input_dim=1, output_dim=pre_len, hidden_dim=train_window)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)
epochs = 10
Train = False  # 训练还是预测if Train:losss = []lstm_model.train()  # 训练模式start_time = time.time()  # 计算起始时间for i in range(epochs):for seq, labels in train_inout_seq:lstm_model.train()optimizer.zero_grad()y_pred = lstm_model(seq)single_loss = loss_function(y_pred, labels)single_loss.backward()optimizer.step()print(f'epoch: {i:3} loss: {single_loss.item():10.8f}')losss.append(single_loss.detach().numpy())torch.save(lstm_model.state_dict(), 'save_model.pth')print(f"模型已保存,用时:{(time.time() - start_time) / 60:.4f} min")plt.plot(losss)# 设置图表标题和坐标轴标签plt.title('Training Error')plt.xlabel('Epoch')plt.ylabel('Error')# 保存图表到本地plt.savefig('training_error.png')
else:# 加载模型进行预测lstm_model.load_state_dict(torch.load('save_model.pth'))lstm_model.eval()  # 评估模式results = []reals = []losss = []test_inout_seq = create_inout_sequences(test_data_normalized, train_window, pre_len)for seq, labels in train_inout_seq:pred = lstm_model(seq)[0].item()results.append(pred)mae = calculate_mae(pred, labels.detach().numpy())  # MAE误差计算绝对值(预测值  - 真实值)reals.append(labels.detach().numpy())losss.append(mae)print("模型预测结果:", results)print("预测误差MAE:", losss)plt.style.use('ggplot')# 创建折线图plt.plot(results, label='real', color='blue')  # 实际值plt.plot(reals, label='forecast', color='red', linestyle='--')  # 预测值# 增强视觉效果plt.grid(True)plt.title('real vs forecast')plt.xlabel('time')plt.ylabel('value')plt.legend()plt.savefig('test——results.png')

train_data = true_data[:int(train_size * len(true_data))]

test_data = true_data[-int(test_size * len(true_data)):]是在干什么

这两行代码的作用是将数据集 true_data 划分为训练集 train_data 和测试集 test_data。具体来说:

代码解释:

  1. train_data = true_data[:int(train_size * len(true_data))]:

    • train_size 是一个比例值(例如0.8,表示80%),表示训练集占总数据集的比例。

    • len(true_data) 返回 true_data 的长度。

    • train_size * len(true_data) 计算出训练集的大小。

    • int(train_size * len(true_data)) 将计算结果转换为整数,表示训练集的实际样本数量。

    • true_data[:int(train_size * len(true_data))] 使用切片操作从 true_data 的开头到训练集的样本数量,提取出训练集。

  2. test_data = true_data[-int(test_size * len(true_data)):]:

    • test_size 是另一个比例值(例如0.2,表示20%),表示测试集占总数据集的比例。

    • test_size * len(true_data) 计算出测试集的大小。

    • int(test_size * len(true_data)) 将计算结果转换为整数,表示测试集的实际样本数量。

    • true_data[-int(test_size * len(true_data)):] 使用切片操作从 true_data 的末尾到测试集的样本数量,提取出测试集。

示例:

假设 true_data 有100个样本,train_size 为0.8,test_size 为0.2:

  • train_size * len(true_data) = 0.8 * 100 = 80

  • test_size * len(true_data) = 0.2 * 100 = 20

那么:

  • train_data = true_data[:80] 会提取 true_data 的前80个样本作为训练集。

  • test_data = true_data[-20:] 会提取 true_data 的后20个样本作为测试集。

注意事项:

  • 数据顺序:如果 true_data 是有序的(例如时间序列数据),这种划分方式可能会导致训练集和测试集之间存在时间上的重叠,从而引入数据泄露。在这种情况下,应该按照时间顺序进行划分。

  • 随机性:如果 true_data 是无序的,这种划分方式是合理的。但如果数据没有被打乱,建议在划分之前先打乱数据,以确保训练集和测试集能够代表整体数据的分布。

  • 假设 input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]tw = 3pre_len = 2

  • 第一次循环:

    • i = 0

    • train_seq = [1, 2, 3]

    • train_label = [4, 5]

    • inout_seq = [([1, 2, 3], [4, 5])]

  • 第二次循环:

    • i = 1

    • train_seq = [2, 3, 4]

    • train_label = [5, 6]

    • inout_seq = [([1, 2, 3], [4, 5]), ([2, 3, 4], [5, 6])]

  • 以此类推,直到 i = 7 时,train_seq = [8, 9, 10]train_label = [],此时 i + tw + pre_len 超出 input_data 的范围,循环结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/880507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Meta Sapiens 人体AI模型

Meta 一直是开发图像和视频模型的领导者,现在他们又增加了一个新东西:Meta Sapiens。和Homo sapiens一样,这个模型也是关于人类的。它旨在执行与人类相关的任务,例如理解身体姿势、识别身体部位、预测深度,甚至确定皮肤…

算法课习题汇总(3)

循环日程表 设有N个选手进行循环比赛,其中N2M,要求每名选手要与其他N−1名选手都赛一次,每名选手每天比赛一次,循环赛共进行N−1天,要求每天没有选手轮空。 例如4个人进行比赛: 思路: 把表格…

Spring MVC 基本配置步骤 总结

1.简介 本文记录Spring MVC基本项目拉起配置步骤。 2.步骤 在pom.xml中导入依赖&#xff1a; <dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</artifactId><version>6.0.6</version><scope>…

通过WebTopo在ARMxy边缘计算网关上实现系统集成

随着工业互联网技术的发展&#xff0c;边缘计算成为了连接物理世界与数字世界的桥梁&#xff0c;其重要性日益凸显。边缘计算网关作为数据采集、处理与传输的核心设备&#xff0c;在智能制造、智慧城市等领域发挥着关键作用。 1. BL340系列概述 BL340系列是基于全志科技T507-…

MATLAB仿真实现图像去噪

摘要 数字图像处理是一门新兴技术&#xff0c;随着计算机硬件的发展&#xff0c;其处理能力的不断增强&#xff0c;数字图像的实时处理已经成为可能。由于数字图像处理的各种算法的出现&#xff0c;图像处理学科在飞速发展的同时逐渐向其他学科交叉渗透。数字图像处理是一种通过…

【目标检测】隐翅虫数据集386张VOC+YOLO

隐翅虫数据集&#xff1a;图片来自网页爬虫&#xff0c;删除重复项后整理标注而成 数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;386 标注…

电子电路的基础知识

电子电路是现代电子技术的基础&#xff0c;由电子元件&#xff08;如电阻、电容、电感、二极管、晶体管等&#xff09;和无线电元件通过一定方式连接而成的电路系统。 以下是对电子电路的详细概述&#xff1a; 一、定义与分类 定义&#xff1a;电子电路是指由电子器件和有关无…

240925-GAN生成对抗网络

GAN生成对抗网络 GAN&#xff0c;顾名思义&#xff0c;gan……咳咳&#xff0c;就是干仗嘛&#xff08;听子豪兄的课讲说这个名字还真的源于中文这个字&#xff09;&#xff0c;对应的就有两方&#xff0c;放在这里就是有两个网络互相对抗互相学习。类比武林高手切磋&#xff…

dev containers plugins for vscode构建虚拟开发环境

0. 需求说明 自用笔记本构建一套开发环境&#xff0c;用docker 虚拟插件 dev containers,实现开发环境的构建&#xff0c;我想构建一套LLMs的环境&#xff0c;由于环境配置太多&#xff0c;不想污染本地环境&#xff0c;所以选择隔离技术 1. 环境准备 vscodedocker 2. 步骤…

韦东山FreeRTOS笔记

介绍 这篇文章是我学习FreeRTOS的笔记 学的是哔哩哔哩韦东山老师的课程 在学习FreeRTOS之前已经学习过江协的标准库和一丢丢的超子说物联网的HAL了。他们讲的都很不错 正在更新&#xff0c; 大家可以在我的Gitee仓库中下载笔记源文件、项目资料等 笔记源文件可以在Notion…

idea.vmoptions 最佳配置

1. 推荐的 idea64.exe.vmoptions 配置&#xff1a; -Xms1024m -Xmx4096m -XX:ReservedCodeCacheSize512m -XX:UseG1GC -XX:SoftRefLRUPolicyMSPerMB50 -XX:CICompilerCount4 -XX:HeapDumpOnOutOfMemoryError -XX:-OmitStackTraceInFastThrow -Dsun.io.useCanonCachesfalse -Dj…

微服务JSR303解析部署使用全流程

目录 1、什么是JSR303校验 2、小试牛刀 【2.1】添加依赖 【2.2】添加application.yml配置文件修改端口 【2.3】创建实体类User 【2.4】创建控制器 【2.5】创建启动类 【注意】不必创建前端页面 3、规范返回值格式&#xff1a; 3.1添加ResultCode工具类 3.2添加Resul…

NASA数据集:ATLAS/ICESat-2 L3B 南极和北极网格陆地冰高,第 3 版

目录 简介 摘要 代码 引用 网址推荐 0代码在线构建地图应用 机器学习 ATLAS/ICESat-2 L3B Gridded Antarctic and Arctic Land Ice Height V003 简介 ATLAS/ICESat-2 L3B 南极和北极网格陆地冰高&#xff0c;第 3 版 ATL14 和 ATL15 将 ATLAS/ICESat-2 L3B 年度陆地冰…

【蓝桥杯省赛真题55】Scratch找不同游戏 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

scratch找不同游戏 第十五届青少年蓝桥杯scratch编程选拔赛真题解析 PS&#xff1a;其实这题在选拔赛里面就出现过类似的题目&#xff0c;只是难度提升了一点&#xff0c;具体可以见【蓝桥杯选拔赛真题84】Scratch找不同游戏 第十五届蓝桥杯scratch图形化编程 少儿编程创意编…

java日志门面之JCL和SLF4J

文章目录 前言一、JCL1、JCL简介2、快速入门3、 JCL原理 二、SLF4J1、SLF4J简介2、快速入门2.1、输出动态信息2.2、异常信息的处理 3、绑定日志的实现3.1、slf4j实现slf4j-simple和logback3.2、slf4j绑定适配器实现log4j3.2、Slf4j注解 4、桥接旧的日志框架4.1、log4j日志重构为…

通过队列实现栈

请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09;。 实现 MyStack 类&#xff1a; void push(int x) 将元素 x 压入栈顶。int pop() 移除并返回栈顶元素。int to…

Android源码管理

文章目录 需求及场景需求困难疑惑点 源码管理方式及过程基本仓库管理方式 常用源码git 命令git init添加.gitignoregit add allgit add 文件名称git commit -a -m "提交内容说明"git statusgit loggit reset --hardgit clean -fd实际场景&#xff0c;从一个项目切换到…

大屏走马灯与echarts图表柱状图饼图开发小结

一、使用ant-design-vue的走马灯(a-carousel)注意事项 <!-- 左边的轮播图片 --><a-carousel :after-change"handleCarouselChange" autoplay class"carousel" :transition"transitionName"><div v-for"(item, index) in it…

论文阅读【时间序列】ModerTCN (ICLR2024)

【时间序列】ModerTCN (ICLR2024) 原文链接&#xff1a;ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis 代码仓库&#xff1a;ModerTCN 简易版本实现代码可以参考&#xff1a;&#xff08;2024 ICLR&#xff09;ModernTCN&#xff1a;A Mod…

解决hbase和hadoop的log4j依赖冲突的警告

一、运行hbase的发现依赖冲突的警告 这警告不影响使用 二、重命名log4j文件 进入HBase的lib包下&#xff0c;将HBase的log4j文件重命名&#xff0c;改成备份&#xff0c;这样再次运行hbase的时候&#xff0c;就没有依赖冲突了。 三、冲突成功解决