【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。

一、引言

单站点多变量输入多变量输出单步预测问题----基于LSTM实现。

多输入就是输入多个特征变量

多输出就是同时预测出多个标签的结果

单步就是利用过去N天预测未来1天的结果

二、实现过程

2.1 读取数据集

df=pd.read_csv("data.csv", parse_dates=["Date"], index_col=[0])
print(df.shape)
print(df.head())
fea_num = len(df.columns)

df:

图片

2.2 划分数据集

# 拆分数据集为训练集和测试集
test_split=round(len(df)*0.20)
df_for_training=df[:-test_split]
df_for_testing=df[-test_split:]# 绘制训练集和测试集的折线图
plt.figure(figsize=(10, 6))
plt.plot(train_data, label='Training Data')
plt.plot(test_data, label='Testing Data')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Training and Testing Data')
plt.legend()
plt.show()

共5203条数据,8:2划分:训练集4162,测试集1041。

训练集和测试集:

图片

2.3 归一化

# 将数据归一化到 0~1 范围(整体一起做归一化)
scaler = MinMaxScaler(feature_range=(0,1))
df_for_training_scaled = scaler.fit_transform(df_for_training)
df_for_testing_scaled=scaler.transform(df_for_testing)

2.4 构造LSTM数据集(时序-->监督学习)

def createXY(data, win_size, target_feature_idxs):passwin_size = 12 # 时间窗口
target_feature_idxs = [0, 1, 2, 3, 4] # 指定待预测特征列索引
trainX, trainY = createXY(df_for_training_scaled, win_size, target_feature_idxs)
testX, testY = createXY(df_for_testing_scaled, win_size, target_feature_idxs)
print("训练集形状:", trainX.shape, trainY.shape)
print("测试集形状:", testX.shape, testY.shape)# 将数据集转换为 LSTM 模型所需的形状(样本数,时间步长,特征数)
trainX = np.reshape(trainX, (trainX.shape[0], win_size, fea_num))
testX = np.reshape(testX, (testX.shape[0], win_size, fea_num))print("trainX Shape-- ",trainX.shape)
print("trainY Shape-- ",trainY.shape)
print("testX Shape-- ",testX.shape)
print("testY Shape-- ",testY.shape)

滑动窗口设置为12:

取出df_for_training_scaled第【1-12】行第【1-5】列的12条数据作为trainX[0],取出df_for_training_scaled第【13】行第【1-5】列的1条数据作为trainY[0];依此类推。最终构造出的训练集数量(4150)比划分时候的训练集数量(4162)少一个滑动窗口(12)。

图片

trainX是一个(4150,12,5)的三维数组,三个维度分布表示(样本数量,步长,特征数),每一个样本比如trainX[0]是一个(12,5)二维数组表示(步长,特征数),这也是LSTM模型每一步的输入。

图片

trainY是一个(4150,5)的二维数组,二个维度分布表示(样本数量,标签数),每一个样本比如trainY[0]是一个(5,)一维数组表示(标签数,),这也是LSTM模型每一步的输出。

图片

2.5 建立模拟合模型

# 输入维度
input_shape = Input(shape=(trainX.shape[1], trainX.shape[2]))
# LSTM层
lstm_layer = LSTM(128, activation='relu')(input_shape)
# 全连接层
dense_1 = Dense(64, activation='relu')(lstm_layer)
dense_2 = Dense(32, activation='relu')(dense_1)
# 输出层
output_1 = Dense(1, name='Open')(dense_2)
output_2 = Dense(1, name='High')(dense_2)
output_3 = Dense(1, name='Low')(dense_2)
output_4 = Dense(1, name='Close')(dense_2)
output_5 = Dense(1, name='AdjClose')(dense_2)
model = Model(inputs = input_shape, outputs = [output_1, output_2, output_3, output_4, output_5])
model.compile(loss='mse', optimizer='adam')
model.summary()

这是一个多输入多输出的 LSTM 模型,接受包含12个时间步长和5个特征的输入序列,在经过一层128个神经元的 LSTM 层和5个全连接层后,输出5个单独的预测结果,分别是 Open、High、 Low、Close和 AdjClose。

图片

进行训练,这里[trainY[:,i] for i in range(trainY.shape[1])]把原来的trainY做了转置,是一个(5,4150)的二维数组,分别表示(标签数,样本数)。相当于建立了5个通道,每个通道是(4150,)的一维数组。

history = model.fit(trainX, [trainY[:,i] for i in range(trainY.shape[1])], epochs=20, batch_size=32)

2.6 进行预测

进行预测,上面我们分析过模型每一步的输入是一个(12,5)二维数组表示(步长,特征数),模型每一步的输出是是一个(5,)一维数组表示(标签数,)

prediction_test = model.predict(testX)

如果直接model.predict(testX),testX的形状是(1029,12,5),是一个批量预测,输出prediction_test是一个(5,1029,1)的三维数组,prediction_test[0]就是第一个标签的预测结果,prediction_test[1]就是第二个标签的预测结果...多输出就是同时预测出多个标签的结果

图片

2.7 预测效果展示

分析一下第一个变量open的效果,i=0:

prediction_train = model.predict(trainX)
prediction_train0=model.predict(trainX)[i]
prediction_train_copies_array = ...
pred_train=...
original_train_copies_array = trainY
original_train=...
print("train Pred Values-- ", pred_train)
print("\ntrain Original Values-- ", original_train)
plt.plot(df_for_training.index[win_size:,], original_train, color = 'red', label = '真实值')
plt.plot(df_for_training.index[win_size:,], pred_train, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

训练集真实值与预测值:

图片

prediction_test = model.predict(testX)
prediction_test0=model.predict(testX)[i]
prediction_test_copies_array = ...
pred_test=...
original_test_copies_array = testY
original_test=...
print("\ntest Original Values-- ", original_test)
plt.plot(df_for_testing.index[win_size:,], original_test, color = 'red', label = '真实值')
plt.plot(df_for_testing.index[win_size:,], pred_test, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

测试集真实值与预测值:

图片

2.8 评估指标

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/863990.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS-实例-div 水平居中 垂直靠上

1 需求 2 语法 3 示例 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>表格水平居中、垂直靠上示例…

AI 上车的一些随想

​最近一直在做AI大模型上车的战略企划工作&#xff0c;听了好多供应商的宣讲&#xff0c;自己也查阅了大量书籍、资料。信息输入呈现爆炸性增长&#xff0c;受限于专业知识水平&#xff0c;仅能在应用层面上有所思考。纯个人观点&#xff0c;仅供参考。 车自古以来都是移动工…

2024年06月CCF-GESP编程能力等级认证Scratch图形化编程四级真题解析

本文收录于《Scratch等级认证CCF-GESP图形化真题解析》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 一、单选题(共 10 题,每题 2 分,共 30 分) 第1题 小杨父母带他到某培训机构给他报名参加 CCF 组织的 GESP 认证考试的第 1 级,那他可以选择的认证语言有几…

前端面试题(基础篇十四)

一、DOMContentLoaded 事件和 Load 事件的区别&#xff1f; 当初始的 HTML 文档被完全加载和解析完成之后&#xff0c;DOMContentLoaded 事件被触发&#xff0c;而无需等待样式表、图像和子框架的加载完成。 Load 事件是当所有资源加载完成后触发的。 二、简述一下你对 HTML 语…

机器学习 中数据是如何处理的?

数据处理是将数据从给定形式转换为更可用和更理想的形式的任务&#xff0c;即使其更有意义、信息更丰富。使用机器学习算法、数学建模和统计知识&#xff0c;整个过程可以自动化。这个完整过程的输出可以是任何所需的形式&#xff0c;如图形、视频、图表、表格、图像等等&#…

理想汽车提出3DRealCar:首个大规模3D真实汽车数据集

理想提出3DRealCar&#xff0c;这是第一个大规模 3D 实车数据集&#xff0c;包含 2500 辆在真实场景中拍摄的汽车。我们希望 3DRealCar 可以成为促进汽车相关任务的宝贵资源。 理想汽车提出3DRealCar&#xff1a;首个大规模3D真实汽车数据集! 我们精心策划的高质量3DRealCar数…

全球点赞第一起名大师颜廷利:是金子总会“花光”的

在物质世界的繁华背后&#xff0c;隐藏着一个深刻的真理&#xff1a;有形之物的分享会逐渐减少&#xff0c;而无形之物的传递却能不断增值。金钱、货币、银两这些商业领域的实体&#xff0c;往往激发出人类对更多财富的渴望和对资源枯竭的恐惧。这种恐惧源于资源的有限性&#…

【数据结构】(C语言):二叉搜索树

二叉搜索树&#xff1a; 树不是线性的&#xff0c;是层级结构。基本单位是节点&#xff0c;每个节点最多2个子节点。有序。每个节点&#xff0c;其左子节点都比它小&#xff0c;其右子节点都比它大。每个子树都是一个二叉搜索树。每个节点及其所有子节点形成子树。可以是空树。…

PointNet++论文导读

PointNet论文导读 主要改进网络结构&#xff1a;非均匀采样下的特征学习的鲁棒性利用点特征传播处理数据集分割 论文链接:https://arxiv.org/abs/1612.00593 主要改进 PointNet的基本思想是学习每个点的空间编码&#xff0c;然后将所有单个点的特征聚合成一个全局点云标签&am…

Apache Ranger 2.4.0 集成hadoop 3.X(Kerbos)

1、安装Ranger 参照上一个文章 2、修改配置 把各种plugin转到统一目录&#xff08;源码编译的target目录下拷贝过来&#xff09;&#xff0c;比如 tar zxvf ranger-2.4.0-hdfs-plugin.tar.gz tar zxvf ranger-2.4.0-hdfs-plugin.tar.gz vim install.properties POLICY_MG…

论文阅读:Simple and Efficient Heterogeneous Graph Neural Network

Yang, Xiaocheng, Mingyu Yan, Shirui Pan, Xiaochun Ye and Dongrui Fan. “Simple and Efficient Heterogeneous Graph Neural Network.” AAAI Conference on Artificial Intelligence (2022). 论文地址&#xff1a;[PDF] Simple and Efficient Heterogeneous Graph Neural…

Bytebase 2.20.0 - 支持为工单事件配置飞书个人通知

&#x1f680; 新功能 支持 Databricks。支持 SQL Server 的 TLS/SSL 连接。支持为工单事件配置飞书个人通知。支持限制用户注册的邮箱域名。 &#x1f514; 重大变更 将分类分级同步设置从数据库配置移至工作空间的全局配置。 SQL 编辑器只读模式下只允许执行 Redis 的只读…

着色器预热?为什么 Flutter 需要?为什么原生 App 不需要?那 Compose 呢?Impeller 呢?

依旧是来自网友的问题&#xff0c;这个问题在一定程度上还是很意思的&#xff0c;因为大家可能会想&#xff0c;Flutter 使用 skia&#xff0c;原生 App 是用 skia &#xff0c;那为什么在 Flutter 上会有着色器预热&#xff08;Shader Warmup&#xff09;这样的说法&#xff1…

论文1--ViT

Vision Transformer (ViT) 论文&#xff1a;https://arxiv.org/abs/2010.11929代码&#xff1a;GitHub - google-research/vision_transformer 1.背景 &#xff08;1&#xff09;transformer在NLP很强&#xff0c;但在CV的应用还非常有限&#xff0c;在此之前只有目标检测中…

【轻量化】YOLOv8 更换骨干网络之 MobileNetv4 | 《号称最强轻量化网络》

论文地址:https://arxiv.org/pdf/2404.10518 代码地址:https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/mobilenet.py 文章速览 文章摘要 MobileNetV4引入了一个名为Universal Inverted Bottleneck (UIB) 的新搜索模块,这个模块融合…

民用无人机企业招标投标需要资质证书详解

一、基础资质 在民用无人机企业的招标投标过程中&#xff0c;基础资质是首要考虑的因素。这些资质通常包括企业注册资质、税务登记证、组织机构代码证等。 1.1 企业注册资质 企业应具备合法的注册资质&#xff0c;即营业执照。该执照应包含企业名称、注册地址、法定代表人、…

idea集成uglifycss压缩混淆css

Uglifycss介绍 https://www.npmjs.com/package/uglifycss 命令行 $ uglifycss [options] [filename] [...] > output 选项&#xff1a; --max-line-len n每个字符添加一个换行符&#xff08;大约&#xff09;; 表示无换行符&#xff0c;并且是默认值n0 --expand-vars扩…

考研数学|《660》一刷的错题,二刷还错,怎么能做对?

660这本习题册的难度不小&#xff0c;它不仅考察你对知识点的掌握程度&#xff0c;还考察你的解题思路和方法。很多题目会同时涉及多个知识点&#xff0c;而且对概念的挖掘非常深入&#xff0c;甚至在一些容易出错的地方还设置了陷阱&#xff0c;这对于基础不扎实的同学来说&am…

商城积分系统的代码实现(下)-- 积分订单的退款与结算

一、接着上文 用户在消耗积分的时候&#xff0c;需要根据一定的逻辑&#xff0c;除了扣减账户的当前余额&#xff0c;还需要依次消费积分订单的余额。 private void updatePointsOrderByUse(Integer schoolId, Long userId, String pointsType, int usingPoints) {List<Po…