【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。

一、引言

单站点多变量输入多变量输出单步预测问题----基于LSTM实现。

多输入就是输入多个特征变量

多输出就是同时预测出多个标签的结果

单步就是利用过去N天预测未来1天的结果

二、实现过程

2.1 读取数据集

df=pd.read_csv("data.csv", parse_dates=["Date"], index_col=[0])
print(df.shape)
print(df.head())
fea_num = len(df.columns)

df:

图片

2.2 划分数据集

# 拆分数据集为训练集和测试集
test_split=round(len(df)*0.20)
df_for_training=df[:-test_split]
df_for_testing=df[-test_split:]# 绘制训练集和测试集的折线图
plt.figure(figsize=(10, 6))
plt.plot(train_data, label='Training Data')
plt.plot(test_data, label='Testing Data')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Training and Testing Data')
plt.legend()
plt.show()

共5203条数据,8:2划分:训练集4162,测试集1041。

训练集和测试集:

图片

2.3 归一化

# 将数据归一化到 0~1 范围(整体一起做归一化)
scaler = MinMaxScaler(feature_range=(0,1))
df_for_training_scaled = scaler.fit_transform(df_for_training)
df_for_testing_scaled=scaler.transform(df_for_testing)

2.4 构造LSTM数据集(时序-->监督学习)

def createXY(data, win_size, target_feature_idxs):passwin_size = 12 # 时间窗口
target_feature_idxs = [0, 1, 2, 3, 4] # 指定待预测特征列索引
trainX, trainY = createXY(df_for_training_scaled, win_size, target_feature_idxs)
testX, testY = createXY(df_for_testing_scaled, win_size, target_feature_idxs)
print("训练集形状:", trainX.shape, trainY.shape)
print("测试集形状:", testX.shape, testY.shape)# 将数据集转换为 LSTM 模型所需的形状(样本数,时间步长,特征数)
trainX = np.reshape(trainX, (trainX.shape[0], win_size, fea_num))
testX = np.reshape(testX, (testX.shape[0], win_size, fea_num))print("trainX Shape-- ",trainX.shape)
print("trainY Shape-- ",trainY.shape)
print("testX Shape-- ",testX.shape)
print("testY Shape-- ",testY.shape)

滑动窗口设置为12:

取出df_for_training_scaled第【1-12】行第【1-5】列的12条数据作为trainX[0],取出df_for_training_scaled第【13】行第【1-5】列的1条数据作为trainY[0];依此类推。最终构造出的训练集数量(4150)比划分时候的训练集数量(4162)少一个滑动窗口(12)。

图片

trainX是一个(4150,12,5)的三维数组,三个维度分布表示(样本数量,步长,特征数),每一个样本比如trainX[0]是一个(12,5)二维数组表示(步长,特征数),这也是LSTM模型每一步的输入。

图片

trainY是一个(4150,5)的二维数组,二个维度分布表示(样本数量,标签数),每一个样本比如trainY[0]是一个(5,)一维数组表示(标签数,),这也是LSTM模型每一步的输出。

图片

2.5 建立模拟合模型

# 输入维度
input_shape = Input(shape=(trainX.shape[1], trainX.shape[2]))
# LSTM层
lstm_layer = LSTM(128, activation='relu')(input_shape)
# 全连接层
dense_1 = Dense(64, activation='relu')(lstm_layer)
dense_2 = Dense(32, activation='relu')(dense_1)
# 输出层
output_1 = Dense(1, name='Open')(dense_2)
output_2 = Dense(1, name='High')(dense_2)
output_3 = Dense(1, name='Low')(dense_2)
output_4 = Dense(1, name='Close')(dense_2)
output_5 = Dense(1, name='AdjClose')(dense_2)
model = Model(inputs = input_shape, outputs = [output_1, output_2, output_3, output_4, output_5])
model.compile(loss='mse', optimizer='adam')
model.summary()

这是一个多输入多输出的 LSTM 模型,接受包含12个时间步长和5个特征的输入序列,在经过一层128个神经元的 LSTM 层和5个全连接层后,输出5个单独的预测结果,分别是 Open、High、 Low、Close和 AdjClose。

图片

进行训练,这里[trainY[:,i] for i in range(trainY.shape[1])]把原来的trainY做了转置,是一个(5,4150)的二维数组,分别表示(标签数,样本数)。相当于建立了5个通道,每个通道是(4150,)的一维数组。

history = model.fit(trainX, [trainY[:,i] for i in range(trainY.shape[1])], epochs=20, batch_size=32)

2.6 进行预测

进行预测,上面我们分析过模型每一步的输入是一个(12,5)二维数组表示(步长,特征数),模型每一步的输出是是一个(5,)一维数组表示(标签数,)

prediction_test = model.predict(testX)

如果直接model.predict(testX),testX的形状是(1029,12,5),是一个批量预测,输出prediction_test是一个(5,1029,1)的三维数组,prediction_test[0]就是第一个标签的预测结果,prediction_test[1]就是第二个标签的预测结果...多输出就是同时预测出多个标签的结果

图片

2.7 预测效果展示

分析一下第一个变量open的效果,i=0:

prediction_train = model.predict(trainX)
prediction_train0=model.predict(trainX)[i]
prediction_train_copies_array = ...
pred_train=...
original_train_copies_array = trainY
original_train=...
print("train Pred Values-- ", pred_train)
print("\ntrain Original Values-- ", original_train)
plt.plot(df_for_training.index[win_size:,], original_train, color = 'red', label = '真实值')
plt.plot(df_for_training.index[win_size:,], pred_train, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

训练集真实值与预测值:

图片

prediction_test = model.predict(testX)
prediction_test0=model.predict(testX)[i]
prediction_test_copies_array = ...
pred_test=...
original_test_copies_array = testY
original_test=...
print("\ntest Original Values-- ", original_test)
plt.plot(df_for_testing.index[win_size:,], original_test, color = 'red', label = '真实值')
plt.plot(df_for_testing.index[win_size:,], pred_test, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

测试集真实值与预测值:

图片

2.8 评估指标

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/863990.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS-实例-div 水平居中 垂直靠上

1 需求 2 语法 3 示例 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>表格水平居中、垂直靠上示例…

AI 上车的一些随想

​最近一直在做AI大模型上车的战略企划工作&#xff0c;听了好多供应商的宣讲&#xff0c;自己也查阅了大量书籍、资料。信息输入呈现爆炸性增长&#xff0c;受限于专业知识水平&#xff0c;仅能在应用层面上有所思考。纯个人观点&#xff0c;仅供参考。 车自古以来都是移动工…

2024年06月CCF-GESP编程能力等级认证Scratch图形化编程四级真题解析

本文收录于《Scratch等级认证CCF-GESP图形化真题解析》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 一、单选题(共 10 题,每题 2 分,共 30 分) 第1题 小杨父母带他到某培训机构给他报名参加 CCF 组织的 GESP 认证考试的第 1 级,那他可以选择的认证语言有几…

前端面试题(基础篇十四)

一、DOMContentLoaded 事件和 Load 事件的区别&#xff1f; 当初始的 HTML 文档被完全加载和解析完成之后&#xff0c;DOMContentLoaded 事件被触发&#xff0c;而无需等待样式表、图像和子框架的加载完成。 Load 事件是当所有资源加载完成后触发的。 二、简述一下你对 HTML 语…

机器学习 中数据是如何处理的?

数据处理是将数据从给定形式转换为更可用和更理想的形式的任务&#xff0c;即使其更有意义、信息更丰富。使用机器学习算法、数学建模和统计知识&#xff0c;整个过程可以自动化。这个完整过程的输出可以是任何所需的形式&#xff0c;如图形、视频、图表、表格、图像等等&#…

HarmonyOS开发实战:JSON组件使用方式-API

JSON类组件 模块介绍JSONValue提供eftool中的JSON相关对象的类型定义JSONObject提供类Java的JSON对象的系列方法以及相互转换JSONArray提供类Java的JSON数组的系列方法以及相互转换JSONArrayList提供类Java的JSON数组的系列方法以及相互转换JSONUtil提供JSON转换一系列判断方法…

理想汽车提出3DRealCar:首个大规模3D真实汽车数据集

理想提出3DRealCar&#xff0c;这是第一个大规模 3D 实车数据集&#xff0c;包含 2500 辆在真实场景中拍摄的汽车。我们希望 3DRealCar 可以成为促进汽车相关任务的宝贵资源。 理想汽车提出3DRealCar&#xff1a;首个大规模3D真实汽车数据集! 我们精心策划的高质量3DRealCar数…

全球点赞第一起名大师颜廷利:是金子总会“花光”的

在物质世界的繁华背后&#xff0c;隐藏着一个深刻的真理&#xff1a;有形之物的分享会逐渐减少&#xff0c;而无形之物的传递却能不断增值。金钱、货币、银两这些商业领域的实体&#xff0c;往往激发出人类对更多财富的渴望和对资源枯竭的恐惧。这种恐惧源于资源的有限性&#…

【数据结构】(C语言):二叉搜索树

二叉搜索树&#xff1a; 树不是线性的&#xff0c;是层级结构。基本单位是节点&#xff0c;每个节点最多2个子节点。有序。每个节点&#xff0c;其左子节点都比它小&#xff0c;其右子节点都比它大。每个子树都是一个二叉搜索树。每个节点及其所有子节点形成子树。可以是空树。…

VRRP和IPVS

1.VRRP VRRP(Virtual Router Redundancy Protocol,简称VRRP,虚拟路由冗余协议)是一种选择协议,它可以把一个虚拟路由器的责任动态分配到局域网上的VRRP路由器中的一台。控制虚拟路由器IP地址的VRRP路由器称为主路由器,它负责转发数据包到这些虚拟IP地址。 VRRP一旦主路由…

PointNet++论文导读

PointNet论文导读 主要改进网络结构&#xff1a;非均匀采样下的特征学习的鲁棒性利用点特征传播处理数据集分割 论文链接:https://arxiv.org/abs/1612.00593 主要改进 PointNet的基本思想是学习每个点的空间编码&#xff0c;然后将所有单个点的特征聚合成一个全局点云标签&am…

Apache Ranger 2.4.0 集成hadoop 3.X(Kerbos)

1、安装Ranger 参照上一个文章 2、修改配置 把各种plugin转到统一目录&#xff08;源码编译的target目录下拷贝过来&#xff09;&#xff0c;比如 tar zxvf ranger-2.4.0-hdfs-plugin.tar.gz tar zxvf ranger-2.4.0-hdfs-plugin.tar.gz vim install.properties POLICY_MG…

防火墙防御体系结构类型

防火墙防御体系结构类型 防火墙是网络安全的核心组件&#xff0c;用于保护网络和系统免受未经授权的访问和各种网络攻击。防火墙防御体系结构类型多样化&#xff0c;每种类型都针对不同的安全需求和应用场景&#xff0c;提供不同层次的保护。以下是几种常见的防火墙防御体系结…

【车载开发系列】NXP开发环境介绍

【车载开发系列】NXP开发环境介绍 【车载开发系列】NXP开发环境介绍 【车载开发系列】NXP开发环境介绍一. 开发环境1、S32 Design Studio for S32 Platform2、S32 Design Studio for ARM3、S32 Design Studio IDE 二. NXP开发环境支持的单片机1&#xff09;Kinetis系列2&#x…

力扣3152.特殊数组 II

力扣3152.特殊数组 II 满足条件为0 &#xff0c; 不满足为1 最终如果区间和为0 则为特殊数组 class Solution {public:vector<bool> isArraySpecial(vector<int>& nums, vector<vector<int>>& queries) {int n nums.size();vector<int&…

论文阅读:Simple and Efficient Heterogeneous Graph Neural Network

Yang, Xiaocheng, Mingyu Yan, Shirui Pan, Xiaochun Ye and Dongrui Fan. “Simple and Efficient Heterogeneous Graph Neural Network.” AAAI Conference on Artificial Intelligence (2022). 论文地址&#xff1a;[PDF] Simple and Efficient Heterogeneous Graph Neural…

Java集合框架性能优化与选择指南

Java集合框架性能优化与选择指南 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01; 引言 Java集合框架是每位Java开发者日常工作中不可或缺的一部分。正确选择和…

Bytebase 2.20.0 - 支持为工单事件配置飞书个人通知

&#x1f680; 新功能 支持 Databricks。支持 SQL Server 的 TLS/SSL 连接。支持为工单事件配置飞书个人通知。支持限制用户注册的邮箱域名。 &#x1f514; 重大变更 将分类分级同步设置从数据库配置移至工作空间的全局配置。 SQL 编辑器只读模式下只允许执行 Redis 的只读…

着色器预热?为什么 Flutter 需要?为什么原生 App 不需要?那 Compose 呢?Impeller 呢?

依旧是来自网友的问题&#xff0c;这个问题在一定程度上还是很意思的&#xff0c;因为大家可能会想&#xff0c;Flutter 使用 skia&#xff0c;原生 App 是用 skia &#xff0c;那为什么在 Flutter 上会有着色器预热&#xff08;Shader Warmup&#xff09;这样的说法&#xff1…