【python量化】多种Transformer模型用于股价预测(Autoformer, FEDformer和PatchTST等)_neuralforecast

bb1fee63b7d3f7f1db42af482660a610.png

写在前面

在本文中,我们利用Nixtla的NeuralForecast框架,实现多种基于Transformer的时序预测模型,包括:Transformer, Informer, Autoformer, FEDformer和PatchTST模型,并且实现将它们应用于股票价格预测的简单例子

1

NeuralForecast

neuralforecast 是一个旨在为时间序列预测提供一个丰富的、高度可用和鲁棒的神经网络模型集合的工具库。这个库集成了从传统的多层感知器(MLP)和递归神经网络(RNN)到最新的模型如N-BEATS、N-HiTS、TFT,以及其他高级架构,以适应多样化的预测需求。它的关键功能包括对静态、历史和未来的外生变量的支持,提高了模型在实际应用中的灵活性。库中的模型提供了良好的预测可解释性,允许用户绘制趋势、季节性以及外生预测组件。neuralforecast 还实现了概率预测,通过简单的适配器支持量化损失和参数分布,增加了预测结果的置信度。此外,它提供了自动模型选择功能,通过并行自动超参数调整来高效确定最优的模型配置。库的简洁接口设计与SKLearn兼容,确保了易用性,并且训练和评估损失的计算能够适应不同的比例,这为不同规模的数据集提供了灵活性。最后,neuralforecast 包含了一个广泛的模型集合,包括但不限于LSTM、RNN、TCN、N-BEATS、N-HiTS、ESRNN以及各种基于Transformer的预测模型等,都是以即插即用的方式实现,方便用户直接应用于各种时间序列预测场景。这些特性使得neuralforecast 成为那些寻求高效、精确且可解释时间序列预测模型的研究人员和实践者的有力工具。本文将利用neuralforecast 实现各种Transformer模型,并展示将它们应用于股票价格预测的简单例子。

2

环境配置

本地环境:

Python 3.8
IDE:Pycharm

库版本:

Pandas version: 2.0.3
Matplotlib version: 3.7.1
Neuralforecast version: 1.6.4

为了使用最新的其他模型,也可以直接fork neuralforecast的源码:

git clone https://github.com/Nixtla/neuralforecast.git
cd neuralforecast
pip install -e .

3

代码实现

步骤 1: 导入所需的库
  • 导入库:首先,导入处理数据所需的 pandas 库,绘图所需的 matplotlib.pyplot 库,以及 neuralforecast 中的多个模块。这些模块包括各种预测模型和评估指标函数。
import pandas as pd
from neuralforecast.models import VanillaTransformer, Informer, Autoformer, FEDformer, PatchTST
from neuralforecast.core import NeuralForecast
import matplotlib.pyplot as plt
from neuralforecast.losses.numpy import mae, rmse, mse
步骤 2: 数据准备
  • 读取数据:使用 pandas从 CSV 文件加载数据。这个数据集包含股票的每日收盘价。

  • 数据预处理:重命名列以符合模型的输入要求(例如,将日期列重命名为 ‘ds’,将收盘价列重命名为 ‘y’)。此外,将日期列转换为日期时间格式,并为数据集添加一个唯一标识符,这对于使用neuralforecast进行时间序列预测是必要的。

df = pd.read_csv('./000001_Daily_Close.csv')
df['unique_id'] = 1
df = df.rename(columns={'date': 'ds', 'Close': 'y'})
df['ds'] = pd.to_datetime(df['ds'])
步骤 3: 定义预测模型
  • 初始化模型:定义一个模型列表,每个模型都是 neuralforecast 库中的一个类的实例。对于每个模型,指定预测范围(horizon)、输入窗口大小(input_size)以及其他训练参数(如 max_steps, val_check_steps)。

  • 模型配置:这些参数决定了模型的训练方式,包括训练持续时间、评估频率和早停机制等。每个模型都有一些公共的参数以及它们自身的参数可以调整,这里均使用它们默认的参数进行模型初始化。

models = [VanillaTransformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3,scaler_type='standard'),Informer(h=horizon,  # Forecasting horizoninput_size=input_size,  # Input sizemax_steps=train_steps,  # Number of training iterationsval_check_steps=check_steps,  # Compute validation loss every 100 stepsearly_stop_patience_steps=3,  # Number of validation iterations before early stoppingscaler_type='standard'),  # Stop training if validation loss does not improveFEDformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),Autoformer(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),PatchTST(h=horizon,input_size=input_size,max_steps=train_steps,val_check_steps=check_steps,early_stop_patience_steps=3),]
步骤 4: 模型训练与交叉验证
  • 创建 NeuralForecast 实例:使用 NeuralForecast 类整合所有的模型。这个类提供了一个统一的接口来训练和评估多个模型。

  • 执行交叉验证:使用 cross_validation 方法对每个模型进行训练和评估。这个方法自动进行时间序列的交叉验证,分割数据集并评估模型在不同时间窗口上的性能。

nf = NeuralForecast(models=models,freq='B')Y_hat_df = nf.cross_validation(df=df,val_size=100,test_size=100,n_windows=None)
步骤 5: 数据筛选
  • 筛选数据点:通过选择特定的“cutoff”点来过滤 Y_hat_df 中的预测。这种筛选基于预测范围 horizon,确保评估是在均匀间隔的时间点上进行。
Y_plot = Y_hat_df
cutoffs = Y_hat_df['cutoff'].unique()[::horizon]
Y_plot = Y_plot[Y_hat_df['cutoff'].isin(cutoffs)]
步骤 6: 绘图与性能评估
  • 绘制预测结果:使用 matplotlib 绘制真实数据与每个模型的预测结果。这有助于直观地比较不同模型的预测准确性。

  • 计算评估指标:对每个模型,计算和打印均方根误差(RMSE)、平均绝对误差(MAE)和均方误差(MSE)等性能指标。这些指标提供了量化模型性能的方式。

plt.figure(figsize=(20, 5))
plt.plot(Y_plot['ds'], Y_plot['y'], label='True')
for model in models:plt.plot(Y_plot['ds'], Y_plot[model], label=model)rmse_value = rmse(Y_hat_df['y'], Y_hat_df[model])mae_value = mae(Y_hat_df['y'], Y_hat_df[model])mse_value = mse(Y_hat_df['y'], Y_hat_df[model])print(f'{model}: rmse {rmse_value:.4f} mae {mae_value:.4f} mse {mse_value:.4f}')plt.xlabel('Datestamp')
plt.ylabel('Close')
plt.grid()
plt.legend()
plt.show()
步骤 7: 结果展示
  • 展示图表:最后,显示绘制的图表。图表展示了不同模型在整个时间序列上的预测表现,允许直观地评估和比较模型。

5d185d6c7ec0781a5971ebf64ad56ad5.png

VanillaTransformer: rmse 56.5187 mae 38.8573 mse 3194.3650
Informer: rmse 52.2324 mae 39.1110 mse 2728.2239
FEDformer: rmse 48.9400 mae 35.9884 mse 2395.1237
Autoformer: rmse 58.5010 mae 45.7157 mse 3422.3614
PatchTST: rmse 48.5870 mae 36.1392 mse 2360.6968

在对比基于 Transformer 的各种模型在股票价格预测任务上的表现时,从可视化以及评估结果中,我们发现 FEDformer 和 PatchTST 在所有评估指标(RMSE、MAE、MSE)上表现最为出色,这可能归因于它们在处理长期依赖关系和捕获时间序列数据中的复杂模式方面的优势。相较之下,虽然 Informer 显示了合理的性能,但其表现略逊于 FEDformer 和 PatchTST。VanillaTransformer 和 Autoformer 的性能相对较差。这些结果强调了根据特定任务的需求选择合适的模型架构的重要性,同时也表明了在实际应用中进行模型选择时需要考虑到模型的特定优势和潜在的局限性。

4

总结

本文展示了如何使用 neuralforecast 实现多种 Transformer 模型(包括 Informer, Autoformer, FEDformer 和 PatchTST),并将它们应用于股票价格预测的简单示例。通过这个演示,我们可以看到 Transformer 模型在处理时间序列数据方面的潜力和灵活性。虽然我们的实验是初步的,但它为进一步的研究和应用提供了一个基础。读者可以在此基础上进行更深入的模型调优、特征工程和超参数实验,以提升预测性能。此外,这些模型的应用不限于股票价格预测,还可以扩展到其他领域的时间序列分析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/715273.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Libero集成开发环境中Identify应用与提高

Libero集成开发环境中Identify应用与提高 Identify的安装

操作系统原理与实验——实验三优先级进程调度

实验指南 运行环境: Dev c 算法思想: 本实验是模拟进程调度中的优先级算法,在先来先服务算法的基础上,只需对就绪队列到达时间进行一次排序。第一个到达的进程首先进入CPU,将其从就绪队列中出队后。若此后队首的进程的…

多租户 TransmittableThreadLocal 线程安全问题

在一个多租户项目中,用户登录时,会在自定义请求头拦截器AsyncHandlerInterceptor将该用户的userId,cstNo等用户信息设置到TransmittableThreadLocal中,在后续代码中使用.代码如下: HeaderInterceptor 请求头拦截器 public class HeaderInterceptor implements Asyn…

阿里云国际云服务器全局流量分析功能详细介绍

进行全局流量分析时,内网DNS解析会作为一个整体模块,其他模块的边缘虚框颜色会置灰,示意作为一个整体进行全局分析,左侧Region可以展开/汇总,也可以单独选中某个Region模块进行分析(这时其他Region的流量线…

加密与安全_探索签名算法

文章目录 概述应用常用数字签名算法CodeDSA签名ECDSA签名小结 概述 在非对称加密中,使用私钥加密、公钥解密确实是可行的,而且有着特定的应用场景,即数字签名。 数字签名的主要目的是确保消息的完整性、真实性和不可否认性。通过使用私钥加…

云服务器购买教程

在购买云服务器之前,建议仔细评估自身需求和预算,并与多个云服务提供商进行比较,以确保选择到最适合的解决方案。购买云服务器的具体步骤可能因所选云服务提供商而异。以下以实际操作的方式介绍如何购买一款云服务器。 云服务器购买常见问题…

Linux进程——信号详解(上)

文章目录 信号入门生活角度的信号技术应用角度的信号用kill -l命令可以察看系统定义的信号列表信号处理常见方式概述 产生信号通过键盘进行信号的产生,ctrlc向前台发送2号信号通过系统调用异常软件条件 信号入门 生活角度的信号 你在网上买了很多件商品&#xff0…

前端面试练习24.3.2-3.3

HTMLCSS部分 一.说一说HTML的语义化 在我看来,它的语义化其实是为了便于机器来看的,当然,程序员在使用语义化标签时也可以使得代码更加易读,对于用户来说,这样有利于构建良好的网页结构,可以在优化用户体…

vue3项目中如何一个vue组件中的一个div里面的图片铺满整个屏幕样式如何设置

在Vue 3项目中,要使一个div内的图片铺满整个屏幕,你需要确保几个关键点:div元素和图片元素的样式设置正确,以及确保它们能够覆盖整个视口(viewport)。以下是一个简单的步骤和代码示例,帮助你实现…

【JavaSE】实用类——String、日期等

目录 String类常用方法String类的equals()方法String中equals()源码展示 “”和equals()有什么区别呢? StringBuffer类常用构造方法常用方法代码示例 面试题:String类、StringBuffer类和StringBuilder类的区别?日期类Date类Calendar类代码示例…

【vue3】命令式组件封装,message封装示例;(函数式组件?)

仅做代码示例;当然改进的地方还是不少的,仅作为该类组件封装方式的初步启发; 理想大成肯定是想要像 饿了么 这些组件库一样。 有的人叫这函数式组件,有的人叫这命令式组件,我个人还是偏向于命令式组件的称呼。因为以vu…

Django配置静态文件

Django配置静态文件 目录 Django配置静态文件静态文件配置调用方法 一般我们将html文件都放在默认templates目录下 静态文件放在static目录下 static目录大致分为 js文件夹css文件夹img文件夹plugins文件夹 在浏览器输入url能够看到对应的静态资源,如果看不到说明…

支持向量机算法(带你了解原理 实践)

引言 在机器学习和数据科学中,分类问题是一种常见的任务。支持向量机(Support Vector Machine, SVM)是一种广泛使用的分类算法,因其出色的性能和高效的计算效率而受到广泛关注。本文将深入探讨支持向量机算法的原理、特点、应用&…

13. Springboot集成Protobuf

目录 1、前言 2、Protobuf简介 2.1、核心思想 2.2、Protobuf是如何工作的? 2.3、如何使用 Protoc 生成代码? 3、Springboot集成 3.1、引入依赖 3.2、定义Proto文件 3.3、Protobuf生成Java代码 3.4、配置Protobuf的序列化和反序列化 3.5、定义…

【中英对照】【自译】【精华】麻省理工学院MIT技术双月刊(Bimonthly MIT Technology Review)2024年3/4月刊内容概览

一、说明 Notation 仅供学习、参考,请勿用于商业行为。 二、本期封面、封底 Covers 本期杂志购于新加坡樟宜机场Changi Airport Singapore,售价为20.50新元。 本期仍然关注伦敦的AI大会。(笔者十分想去,在伦敦和MIT校园均设有会…

IDEA的安装教程

1、下载软件安装包 官网下载:https://www.jetbrains.com/idea/ 2、开始安装IDEA软件 解压安装包,找到对应的idea可执行文件,右键选择以管理员身份运行,执行安装操作 3、运行之后,点击NEXT,进入下一步 4、…

GraphPad Prism 10: 你的数据,我们的魔法 mac/win版

GraphPad Prism 10是GraphPad Software公司推出的一款功能强大的数据分析和可视化软件。它集数据整理、统计分析、图表制作和报告生成于一体,为科研工作者、学者和数据分析师提供了一个高效、便捷的工作平台。 GraphPad Prism 10软件获取 Prism 10拥有丰富的图表类…

2023义乌最全“电商+跨境+直播”数据总结篇章!

值得收藏|2023义乌最全“电商跨境直播”数据总结篇章! 麦琪享资讯2024-01-20 14:28浙江 新年伊始,央视就把镜头对准了义乌电商,以电商的蓬勃之势展现这座国际商城的开放与活力。 过去的一年 义乌电商量质齐升 实力出圈 跑出了…

nginx 根据参数动态代理

一、问题描述 nginx反向代理配置一般都是配置静态地址,比如: server {listen 80;location / {proxy_pass http://myapp1;proxy_set_header Host $host;proxy_set_header X-Real-IP $remote_addr;}} 这个反向代理表示访问80端口跳转到 http://myapp1 …

腾讯云优惠券领取入口_先领取再下单_2024腾讯云优惠攻略

腾讯云优惠代金券领取入口共三个渠道,腾讯云新用户和老用户均可领取8888元代金券,可用于云服务器等产品购买、续费和升级使用,阿腾云atengyun.com整理腾讯云优惠券(代金券)领取入口、代金券查询、优惠券兑换码使用方法…