使用LSTM神经网络对股票日线行情进行回归训练(Pytorch版)

版权声明:本文为博主原创文章,如需转载请贴上原博文链接:使用LSTM神经网络对股票日线行情进行回归训练(Pytorch版)-CSDN博客


前言:近期在尝试使用lstm对股票日线数据进行拟合,初见成型但是效果不甚理想(过拟合/欠拟合),原因可能有如下几点:①前传网络的架构不够完善;②网络参数组合不合理;③使用CPU而非GPU进行的训练(性能低导致效果不佳);④Epoch数过少。暂且先记录验证结果以便后续做出相对应的调整,完整代码见文末。


目录

〇、各依赖包版本

一、Pytorch中LSTM网络参数说明

二、网络结构构建及参数组合选择

2.1 网络结构构建

2.2 参数组合选择

三、验证结果展示

3.1 优化器ADAM&SGD的验证效果情况

3.2 Batch_Size=64&320的验证效果情况

3.3 num_layers=1、2和dropout=0、0.2、0.4的验证效果情况

参考文献


〇、各依赖包版本

mplfinance==0.12.9b7
pandas==1.1.5
SQLAlchemy==1.4.41
SQLAlchemy_Utils==0.41.2
tushare==1.2.85
backtrader==1.9.78.123
akshare==1.10.42
torch==1.13.1
numpy==1.21.6
matplotlib==3.5.3

一、Pytorch中LSTM网络参数说明

图1.1 lstm参数

具体参照LSTM:其中`bias`和`proj_size`使用默认值,其余参数可在代码中自行修改。

二、网络结构构建及参数组合选择

import torch# 选择在cpu或gpu上跑训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cuda:0')	# 如果知道gpu有多少核,可以直接指定# 如果只有cpu,则以下代码可不执行
# 加载模型到device上
model = SimpleLSTM()
model.to(device)# 加载数据到device上
X, y = X.to(device), y.to(device)

2.1 网络结构构建

        LSTM网络核心的便是网络结构的构建,参考网络上有关lstm网络的文章,多数都是构建最普遍的网络结构,如下:

import torch.nn as nnclass SimpleLSTM(nn.Module):def __init__(self, INPUT_SIZE, HIDDEN_SIZE, OUTPUR_SIZE, NUM_LAYERS):super(SimpleLSTM, self).__init__()self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, batch_first=True)self.fc = nn.Linear(HIDDEN_SIZE, OUTPUR_SIZE)self.sigmoid = nn.Sigmoid()def forward(self, X):X, hidden = self.lstm(X, None)X = self.fc(X[:, -1, :])X = self.sigmoid(X)return Xmodel = SimpleLSTM(INPUT_SIZE, HIDDEN_SIZE, OUTPUR_SIZE, NUM_LAYERS)

        通常这样的结构完全够用了,但是为了方便后续能够对网络进行灵活的调整,还是将隐藏状态h_0、h_n及单元状态c_0、c_n开放出来,构建更加完整的网络结构,同时对输入数据进行批标准化处理(取消了对输出数据的非线性激活,待后续研究后再考虑是否增加该层),后续训练及验证都使用以下网络结构:

import torch
import torch.nn as nnclass SimpleLSTM(nn.Module):def __init__(self, INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS):super(SimpleLSTM, self).__init__()self.D = 1if BIDIRECT:self.D = 2self.h_0 = torch.randn(self.D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)self.c_0 = torch.randn(self.D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)self.h_n = torch.zeros(self.D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)self.c_n = torch.zeros(self.D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, batch_first=BATCH1ST, dropout=DROPOUT, bidirectional=BIDIRECT)self.fc = nn.Linear(HIDDEN_SIZE * self.D, OUTPUT_SIZE)self.bn = nn.BatchNorm1d(num_features=WINDOW_SIZE)def forward(self, input_, h_0, c_0):input_ = self.bn(input_)x, (h_n, c_n) = self.lstm(input_, (h_0, c_0))x = self.fc(x)return x, (h_n, c_n)model = SimpleLSTM(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS)

2.2 参数组合选择

# === 定义固定超参(其中SHUFFLE、BIDIRECT可不开放出来,直接使用默认参数)
STOCK_TSCODE = '000001.SZ'  # 平安YH# DATA parameters
TRAIN_SCALE = 0.8	# 训练集和测试/验证集比例=8:2
WINDOW_SIZE = 1000	# DATALOADER parameters
SHUFFLE = False	# 对于有序的数据,每个Epoch不需要打乱
DROP_LAST = True	# 对于含有初始隐藏状态h_0及初始单元状态c_0的网络,将多于的数据舍弃以便能完成训练(对于数据形式不规整的数据集,可以构建不含有h_0和c_0的网络结构,或者修改TRAIN_SCALE比例亦或将测试/验证集的数据量固定,使得全部数据得以测试、验证)# train patameters
EPOCHS = 1	# 只训练一个Epoch
LEARNING_RATE = 1e-2
MOMENTUM = 0.9# lstm parameters
INPUT_SIZE = 1
HIDDEN_SIZE = 128
OUTPUT_SIZE = 1BATCH1ST = True	# 将batch_size参数置于首位
BIDIRECT = False	# 不使用双向网络
# === 定义可变超参 LSTM parameters(可以通过修改以下参数来调整网络的训练效果,注释中提供2~3中参数可供选择)
OPTIM_NAME = ''  # 'ADAM'、'SGD'
BATCH_SIZE = 64	# 64、320
NUM_LAYERS = 2	# 1、2
DROPOUT = 0.2	# 0、0.2、0.4

*当`SHUFFLE=True`时会将时序的数据打乱,起到反作用;

*当`BIDIRECT=True`时会消耗大量内存(用cpu跑训练不建议设置为True)且对于某些类型的网络并不需要使用双向网络模型,故虽然开放该参数但仅给一个拟合结果以作对比(见图2.1),其余需自行训练;

图2.1 BIDIRECT=True的训练模型拟合验证效果

        根据上述可变超参,对于同一种网络结构一共可以组成4种属性共24种不同的组合,如下表2-1所示:

表2-1 参数组合

lstm modelnum_layersbatch_sizedropoutoptim_name
lstm(input, (h_0, c_0))1、264、3200、0.2、0.4Adam、SGD

        后续所有验证结果图片名称都是以各参数名称及数值来命名的,范例如下:

lstm_model_BS64_EP1_NL1_DO0_OPADAM_DLTrue_BDFalse
                        |        |        |        |               |              |           |
                        |   Epoch    |   dropout        |              |   bidirectional
                batch_size         |                 optimizer       |
                                 num_layers                        drop_last

三、验证结果展示

        下表3-1给出24种参数组合的验证结果,这24种验证结果使用的是都只经过一次训练的模型(即这些模型只经过一个Epoch的训练),不排除经过多次训练后的模型得出的验证效果会和下表呈现出不同的情况:

表3-1 24种参数组合验证效果

num_layersbatch_sizedropoutoptim_name拟合效果
1640ADAM
1640.2ADAM
1640.4ADAM
13200ADAM×
13200.2ADAM×
13200.4ADAM×
2640ADAM
2640.2ADAM
2640.4ADAM
23200ADAM×
23200.2ADAM×
23200.4ADAM×
1640SGD×
1640.2SGD×
1640.4SGD×
13200SGD×
13200.2SGD×
13200.4SGD×
2640SGD×
2640.2SGD×
2640.4SGD×
23200SGD×
23200.2SGD×
23200.4SGD×

3.1 优化器ADAM&SGD的验证效果情况

        从上表可见,ADAM比SGD的效果更好,这也很符合优化器进化的顺序。在12种优化器是SGD的验证结果中,选出效果最好的一种,如图3.1所示,蓝色点是收盘价,橙色线是预测值,相差很远,这样的结果也很难运用在后面的策略中。但即便这样也不能完全说SGD不好,因为只跑了一个回合,跑多一些回合效果可能会变好甚至特定时候会超过多回合的ADAM,这得需要各位自己试验了。至此后续仅针对ADAM优化器进行讨论。

图3.1 优化器为SGD的训练模型的验证效果

3.2 Batch_Size=64&320的验证效果情况

        从表3-1可见,即便优化器是ADAM,batch_size=320的效果也没有batch_size=64的效果好,如图3.2所示,欠拟合的情况很明显,当然batch_size的选择还要根据机器的性能来设置,或许batch_size=32的效果比64更好。

图3.2 batch_size为320的训练模型的验证效果

3.3 num_layers=1、2和dropout=0、0.2、0.4的验证效果情况

        剩下的参数组合见图3.3,共三行两列6张图,左列num_layers=1,右列=2;自上而下dropout=0、0.2、0.4,当lstm只有一层(左列),波动很大,DO0预测值和实际值贴合很近(过拟合),DO0.2欠拟合情况较为明显;而当lstm有两层的时候,预测的较大值和较小值均在实际值范围内,算是预测较准确的,但是局部依旧存在过拟合的情况,还达不到制定策略的要求。

图3.3 多种参数组合验证效果对比

        至此lstm网络的训练基本完成,但过拟合的问题尚未解决且并不能通过控制DROPOUT系数来完全避免,后续可以尝试构建更深层的网络结构或尝试使用GRU来看是否有所改善。

        相关代码见:demo_lstm_model.py,同目录下的`data`文件夹保存了全部24种参数组合的训练模型(一个Epoch),`lstm(input, (h_0, c_0))`文件夹下保存了“ADAM”和“SGD”共24张验证结果图片。

参考文献

1.基于深度学习的股票预测(完整版,有代码)

2.【python量化】基于backtrader的深度学习模型量化回测框架

3.AI金融:利用LSTM预测股票每日最高价

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890000.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

睡岗和玩手机数据集,4653张原始图,支持YOLO,VOC XML,COCO JSON格式的标注

睡岗和玩手机数据集,4653张原始图,支持YOLO,VOC XML,COCO JSON格式的标注 数据集分割 训练组70% 3257图片 有效集20% 931图片 测试集10% 465图片 预处理 没有采用任何预处…

Pandas 索引

在 Pandas 中,索引(Index)是 DataFrame 和 Series 的核心组成部分,用于标识和访问数据。索引提供了快速、灵活和强大的数据检索方法。以下是关于 Pandas 索引的一些关键点: 1. 创建索引 当创建一个 DataFrame 或 Seri…

labml.ai Deep Learning Paper Implementations (带注释的 PyTorch 版论文实现)

labml.ai Deep Learning Paper Implementations {带注释的 PyTorch 版论文实现} 1. labml.ai2. labml.ai Deep Learning Paper Implementations3. Sampling Techniques for Language Models (语言模型的采样技术)4. Multi-Headed Attention (MHA)References 1. labml.ai https…

使用 Marp 将 Markdown 导出为 PPT 后不可编辑的原因说明及解决方案

Marp 是一个流行的 Markdown 演示文稿工具,能够将 Markdown 文件转换为 PPTX 格式。然而,用户在使用 Marp 导出 PPT 时,可能会遇到以下问题: 导出 PPT 不可直接编辑的原因 根据 Marp GitHub 讨论,Marp 导出的 PPTX 文…

构建一个rust生产应用读书笔记四(实战2)

此门课程学习采用actix-web框架完成一个生产级别的rust应用,在 actix-web 中,Extractors 是一个非常重要的概念,它们用于从传入的 HTTP 请求中提取特定的信息片段。actix-web 提供了多种内置的提取器,以满足常见的使用场景。说白了…

优选生产报工系统:关键选择要素

【优选生产报工系统:数据分析、产品管理与基础数据登录的关键选择要素】 在快速变化的制造业环境中,生产报工系统的重要性不言而喻。它不仅仅是一种记录工时和监控生产进度的工具,更是一种能够实现数据驱动决策、优化产品管理和确保基础数据…

使用Python打造高效的PDF文件管理应用(合并以及分割)

在日常工作和学习中,我们经常需要处理大量PDF文件。手动合并、分割PDF不仅耗时,还容易出错。今天,我们将使用Python的wxPython和PyMuPDF库,开发一个强大且易用的PDF文件管理工具。 C:\pythoncode\new\mergeAndsplitPdf.py 所有代…

【C语言程序设计——入门】C语言程序开发环境(头歌实践教学平台习题)【合集】

目录&#x1f60b; <第1关&#xff1a;程序改错> 任务描述 相关知识 编程要求 测试说明 我的通关代码: 测试结果&#xff1a; <第2关&#xff1a;scanf 函数> 任务描述 相关知识 编程要求 测试说明 我的通关代码: 测试结果&#xff1a; <第1关&a…

皮肤伤口分割数据集labelme格式248张5类别

数据集格式&#xff1a;labelme格式(不包含mask文件&#xff0c;仅仅包含jpg图片和对应的json文件) 图片数量(jpg文件个数)&#xff1a;284 标注数量(json文件个数)&#xff1a;284 标注类别数&#xff1a;5 标注类别名称:["bruises","burns","cu…

JVM系列之内存区域

每日禅语 有一位年轻和尚&#xff0c;一心求道&#xff0c;多年苦修参禅&#xff0c;但一直没有开悟。有一天&#xff0c;他打听到深山中有一古寺&#xff0c;住持和尚修炼圆通&#xff0c;是得道高僧。于是&#xff0c;年轻和尚打点行装&#xff0c;跋山涉水&#xff0c;千辛万…

大腾智能CAD:国产云原生三维设计新选择

在快速发展的工业设计领域&#xff0c;CAD软件已成为不可或缺的核心工具。它通过强大的建模、分析、优化等功能&#xff0c;不仅显著提升了设计效率与精度&#xff0c;还促进了设计思维的创新与拓展&#xff0c;为产品从概念构想到实体制造的全过程提供了强有力的技术支持。然而…

leetcode 3195.包含所有1的最小矩形面积I

1.题目要求: 2.解题步骤: class Solution { public:int minimumArea(vector<vector<int>>& grid) {//设置二维数组deque<deque<int>> row_distance;for(int i 0;i < grid.size();i){//遍历数组&#xff0c;把每行头部1的小标和尾部1的下标代…

搭建Tomcat(三)---重写service方法

目录 引入 一、在Java中创建一个新的空项目&#xff08;初步搭建&#xff09; 问题&#xff1a; 要求在tomcat软件包下的MyTomcat类中编写main文件&#xff0c;实现在MyTomcat中扫描myweb软件包中的所有Java文件&#xff0c;并返回“WebServlet(url"myFirst")”中…

Linux介绍与安装CentOS 7操作系统

什么是操作系统 操作系统&#xff0c;英⽂名称 Operating System&#xff0c;简称 OS&#xff0c;是计算机系统中必不 可少的基础系统软件&#xff0c;它是 应⽤程序运⾏以及⽤户操作必备的基础环境 ⽀撑&#xff0c;是计算机系统的核⼼。 操作系统的作⽤是管理和控制计算机系…

【Linux】深入理解进程信号机制:信号的产生、捕获与阻塞

&#x1f3ac; 个人主页&#xff1a;谁在夜里看海. &#x1f4d6; 个人专栏&#xff1a;《C系列》《Linux系列》《算法系列》 ⛰️ 时间不语&#xff0c;却回答了所有问题 目录 &#x1f4da;前言 &#x1f4da;一、信号的本质 &#x1f4d6;1.异步通信 &#x1f4d6;2.信…

【西门子PLC.博途】——面向对象编程及输入输出映射FC块

当我们做面向对象编程的时候&#xff0c;需要用到输入输出的映射。这样建立的变量就能够被复用&#xff0c;从而最大化利用了我们建立的udt对象。 下面就来讲讲映射是什么。 从本质上来说&#xff0c;映射就是拿实际物理对象对应程序虚拟对象&#xff0c;假设程序对象是I0.0&…

MySQL索引的理解

MySQL与磁盘的交互 根据冯诺依曼结构体系&#xff0c;我们知道我们任何上层的应用想要去访问磁盘就必须要通过内存来访问&#xff0c;MySQL作为一款储存数据的服务&#xff0c;肯定是很多时间要用来访问磁盘。而大量访问磁盘一定会影响运行效率的在innoDB的存储引擎下为了减少…

分布式全文检索引擎ElasticSearch-数据的写入存储底层原理

一、数据写入的核心流程 当向 ES 索引写入数据时&#xff0c;整体流程如下&#xff1a; 1、客户端发送写入请求 客户端向 ES 集群的任意节点&#xff08;称为协调节点&#xff0c;Coordinating Node&#xff09;发送一个写入请求&#xff0c;比如 index&#xff08;插入或更…

Maven 生命周期

文章目录 Maven 生命周期- Clean 生命周期- Build 生命周期- Site 生命周期 Maven 生命周期 Maven 有以下三个标准的生命周期&#xff1a; Clean 生命周期&#xff1a; clean&#xff1a;删除目标目录中的编译输出文件。这通常是在构建之前执行的&#xff0c;以确保项目从一个…

Android Studio AI助手---Gemini

从金丝雀频道下载最新版 Android Studio&#xff0c;以利用所有这些新功能&#xff0c;并继续阅读以了解新增内容。 Gemini 现在可以编写、重构和记录 Android 代码 Gemini 不仅仅是提供指导。它可以编辑您的代码&#xff0c;帮助您快速从原型转向实现&#xff0c;实现常见的…