用循环神经网络预测股价

循环神经网络可以用来对时间序列进行预测,之前我们在介绍循环神经网络RNN,LSTM和GRU的时候都用到了正弦函数预测的例子,其实这个例子就是一个时间序列。而在众多的时间序列例子中,最普遍的就是股价的预测了,股价序列是一种很明显的时间序列,价格随时间变化,每天都有一个收盘价。本文就打算使用简单循环神经网络RNN和长短期记忆网络LSTM来对股价进行一下预测。

我们打算利用前N天的股票收盘价来预测下一日的股票收盘价,所以首先需要获取股票数据,这里我使用akshare接口来获取数据,个人觉得比tushare好用。

虽然变量名取了df_hs300,但我没有用沪深300指数,我选择了浙大网新这个股票,毕竟是自己学校下面的企业,支持一下:)

df_hs300 = ak.stock_zh_a_hist(symbol="600797", period="daily", start_date="20210101", end_date=datetime.datetime.today().strftime("%Y%m%d"), adjust="")

获取了从2021年1月1日到当前的股票数据,我们可以输出这个数据看一下:

而我这里只需要收盘价以及日期两个字段,并把收盘价进行归一化处理,更便于训练:

close_list = df_hs300['收盘'].values
date_list = df_hs300['日期'].values
close_list_norm=[price/max(close_list) for price in close_list]

可以打印出来看一下

%matplotlib inline
import matplotlib.pyplot as pltplt.plot(close_list_norm)
plt.title('hs_300')
plt.xlabel('date')
plt.ylabel('colse price')
plt.show()

下面,根据这个数据集定义一个Dataset和DataLoader,我选择用前10天的收盘价来预测下一个交易日的收盘价,所以时间步选择了10,并用前700个数据作为训练数据集。

from torch.utils.data import Dataset, DataLoader  class StockDataset(Dataset):def __init__(self, data_list, time_step = 10, transform=None):self.data = data_listself.features = []self.targets = []for i in range(len(self.data)-time_step):feature = [x for x in self.data[i:i+time_step]]y = self.data[i+time_step]#feature = torch.Tensor(feature)#feature = feature.unsqueeze(1)y = torch.tensor(y)  y = y.reshape(-1)self.features.append(feature)self.targets.append(y)self.features = torch.tensor(self.features)self.features = self.features.reshape(-1, time_step, 1)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.targets[idx]transform = transforms.Compose([transforms.ToTensor()])
dataset = StockDataset(close_list_norm[:700], time_step=10, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=True)

我先用简单循环神经网络来训练一下该数据集,下面定义了这个RNN模型以及一些初始化参数:

time_step = 10
batch_size = 1
#设计网络(单隐藏层Rnn)
input_size,hidden_size,output_size=1,20,1
#Rnn初始隐藏单元hidden_prev初始化
hidden_prev=torch.zeros(1,batch_size,hidden_size).cuda()
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.rnn=nn.RNN(input_size=input_size,    #输入特征维度,当前特征为股价,维度为1hidden_size=hidden_size,  #隐藏层神经元个数,或者也叫输出的维度num_layers=1,batch_first=True)self.linear=nn.Linear(hidden_size,output_size)def forward(self,X,hidden_prev):out,ht=self.rnn(X,hidden_prev)batch_size, seq,  hidden_size = out.shapeout = self.linear(out[:, -1, :])  # 其实就是取出输出的序列长度中的最后一个去进行线性运算,得到输出return out

定义一个训练方法:

model=Net()
model=model.cuda()
criterion=nn.MSELoss()
learning_rate,epochs=0.01,500
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
for epoch in range(epochs):losses = []for X,y in dataloader:X = X.cuda()y = y.cuda()y=y.to(torch.float32)X=X.to(torch.float32)#print("X.shape: ",X.shape)#print("y.shape: ",y.shape)optimizer.zero_grad()yy=model(X,hidden_prev)yy=yy.cuda()#print("yy.shape: ",yy.shape)#print(yy)#print(y)loss = criterion(y, yy)model.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())epoch_loss=sum(losses)/len(losses)if epoch%50==0:   #保留验证集损失最小的模型参数print("epoch:{},loss:{:.8f}".format(epoch+1,epoch_loss))
torch.save(model, "model2.pt")
# 输出:
epoch:1,loss:0.00685027
epoch:51,loss:0.00065118
epoch:101,loss:0.00120512
epoch:151,loss:0.00215360
epoch:201,loss:0.00149827
epoch:251,loss:0.00173493
epoch:301,loss:0.00188238
epoch:351,loss:0.00167589
epoch:401,loss:0.00165730
epoch:451,loss:0.00160637

我们把训练后的模型在验证数据集上测试一下,首先定义验证数据集,训练数据集选择所有数据的前700个数据,验证数据集就选择700个以后的数据作为验证数据集。

transform = transforms.Compose([transforms.ToTensor()])
val_dataset = StockDataset(close_list_norm[700:], time_step=10, transform=transform)
val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

定义验证方法

model = torch.load('model2.pt')Val_y,Val_predict=[],[]
#将归一化后的数据还原
Val_max_price=max(close_list) 
for X,y in val_dataloader:with torch.no_grad():X = X.cuda()y=y.to(torch.float32)X=X.to(torch.float32)print("X: ",X)predict=model(X,hidden_prev)y=y.cpu()predict=predict.cpu()print("y: ",y)print("predict: ",predict)# 把股价还原为归一化之前的股价Val_y.append(y[0][0]*Val_max_price) Val_predict.append(predict[0][0]*Val_max_price)fig=plt.figure(figsize=(8,5),dpi=80)
# 红色表示真实值,绿色表示预测值
plt.plot(Val_y,linestyle='--',color='r')
plt.plot(Val_predict,color='g')
plt.title('stock price')
plt.xlabel('time')
plt.ylabel('price')
plt.show()

我们可以看到,总体趋势是一致的,但是真实值和预测值之间的差距确实有点大,那么我们接下来看一下LSTM网络的模型表现如何:

class Net_LSTM(nn.Module):def __init__(self):super(Net_LSTM,self).__init__()self.lstm=nn.LSTM(input_size=input_size,    #输入特征维度,当前特征为股价,维度为1hidden_size=hidden_size,  #隐藏层神经元个数,或者也叫输出的维度num_layers=1,batch_first=True)self.linear=nn.Linear(hidden_size,output_size)def forward(self,X):out,ht=self.lstm(X)      batch_size, seq,  hidden_size = out.shapeout = self.linear(out[:, -1, :])  # 其实就是取出输出的序列长度中的最后一个去进行线性运算,得到输出return out

因为训练和验证的函数和用RNN训练和验证的函数基本是一致的,我就不赘述了,我们来看看利用LSTM进行训练后的模型,在验证集上的表现如何:

可以看到,这个效果比起用简单循环神经网络RNN好上了很多,可见LSTM的效果确实比简单RNN要提高了不少。这只是一个例子而已,不建议根据这个结果去进行投资,因为预测结果在细节上和原始数据还是有不少差别的,而且只能验证下一个交易日的情况,如果预测时间稍微拉长,效果就会急剧下降,并当前预测都是在前期的数据集的基础上进行预测,如果有突发事件的发生,模型是捕捉不到的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/16964.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

链表练习题

返回倒数第K个节点 快慢指针 让快指针先走k步,再使得快指针与慢指针同时走一步,这样没有开额外空间,空间复杂度较低。 代码实现如下: struct ListNode {int val;struct ListNode* next;}; int kthToLast(struct ListNode* head…

第 52 期:MySQL 半同步复制频繁报错

社区王牌专栏《一问一实验:AI 版》全新改版归来,得到了新老读者们的关注。其中不乏对 ChatDBA 感兴趣的读者前来咨询,表达了想试用体验 ChatDBA 的意愿,对此我们表示感谢 🤟。 目前,ChatDBA 还在最后的准备…

el-table实现合并特定列的所有行

el-table实现合并特定列的所有行 示例: 在这里插入图片描述 const objectSpanMethod ({ row, column, rowIndex, columnIndex }) > {if (columnIndex 5 || columnIndex 7) {// 就是只保留第一行,其他直接不要,然后行数是列表长度if …

2024年03月 Python(一级)真题解析#中国电子学会#全国青少年软件编程等级考试

Python等级考试(1~6级)全部真题・点这里 一、单选题(共25题,共50分) 第1题 下列哪个命令,可以将2024转换成’2024’ 呢?( ) A:str(2024) B:int(2024) C:float(2024) D:bool(2024) 答案:A 本题考察的是str() 语句,将数字转换成字符串用到的是str() 语句。 …

Java:IO

首 java.io中有百万计的类,如何找到自己需要的部分? 流 IO涉及到一个“流”stream的概念,可以简单理解成数据从一个源头到一个目的地。明白数据从哪来,要到哪里去,数据流中是字节还是字符之后,才能找到自…

由于找不到d3dx9_39.dll,无法继续执行代码的5种解决方法

在现代科技发展的时代,电脑已经成为我们生活中不可或缺的一部分。然而,由于各种原因,我们可能会遇到一些电脑问题,其中之一就是“d3dx9_39.dll丢失”。这个问题可能会导致我们在运行某些游戏或应用程序时遇到错误提示,…

新品 | Forge® 1GigE IP67工业相机助力智能农业、食品和饮料行业

近日,51camera的合作伙伴Teledyne FLIR IIS推出Forge 1GigE IP67,它是Forge系列的最新工业相机,旨在在恶劣的工业环境中运行,同时确保高效的生产能力。Forge 1GigE IP67致力于为工厂自动化提供先进成像系统的最新产品。 Forge 1GigE IP67相机…

MyBatis多数据源配置与使用,基于ThreadLocal+AOP

导读 MyBatis多数据源配置与使用其一其二1. 引依赖2. 配置文件3. 编写测试代码4. 自定义DynamicDataSource类5. DataSourceConfig配置类6. AOP与ThreadLocal结合7. 引入AOP依赖8. DataSourceContextHolder9. 自定义注解UseDB10. 创建切面类UseDBAspect11. 修改DynamicDataSourc…

PTA 计算矩阵两个对角线之和

计算一个nn矩阵两个对角线之和。 输入格式: 第一行输入一个整数n(0<n≤10)&#xff0c;第二行至第n1行&#xff0c;每行输入n个整数&#xff0c;每行第一个数前没有空格&#xff0c;每行的每个数之间各有一个空格。 输出格式: 两条对角线元素和&#xff0c;输出格式见样例…

Android存储系统成长记

用心坚持输出易读、有趣、有深度、高质量、体系化的技术文章 本文概要 您一定使用过Context的getFileStreamPath方法或者Environment的getExternalStoragePublicDirectory方法&#xff0c;甚至还有别的方法把数据存储到文件中&#xff0c;这些都是存储系统提供的服务&#x…

PTA 判断两个矩阵相等

Peter得到两个n行m列矩阵&#xff0c;她想知道两个矩阵是否相等&#xff0c;请你用“Yes”&#xff0c;“No”回答她&#xff08;两个矩阵相等指的是两个矩阵对应元素都相等&#xff09;。 输入格式: 第一行输入整数n和m&#xff0c;表示两个矩阵的行与列&#xff0c;用空格隔…

修改元组元素

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 场景模拟&#xff1a;伊米咖啡馆&#xff0c;由于麝香猫咖啡需求量较大&#xff0c;库存不足&#xff0c;店长想把它换成拿铁咖啡。 实例08 将麝香猫…

chrome浏览器驱动下载

跑自动化的时候&#xff0c;需要打开谷歌浏览器&#xff0c;这个时候提示浏览器驱动找不到咋办呢&#xff1f; 1、网上搜索找到了这篇文章&#xff1a;https://www.cnblogs.com/laoluoits/p/17710501.html&#xff1b;按照文章介绍&#xff0c; 首先找到&#xff1a;CNPM Bin…

D - Permutation Subsequence(AtCoder Beginner Contest 352)

题目链接: D - Permutation Subsequence (atcoder.jp) 题目大意&#xff1a; 分析&#xff1a; 相对于是记录一下每个数的位置 然后再长度为k的区间进行移动 然后看最大的pos和最小的pos的最小值是多少 有点类似于滑动窗口 用到了java里面的 TreeSet和Map TreeSet存的是数…

解决 Spring Boot 应用启动失败的问题:Unexpected end of file from server

解决 Spring Boot 应用启动失败的问题&#xff1a;Unexpected end of file from server 博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的…

Spring AOP失效的场景事务失效的场景

场景一&#xff1a;使用this调用被增强的方法 下面是一个类里面的一个增强方法 Service public class MyService implements CommandLineRunner {private MyService myService;public void performTask(int x) {System.out.println("Executing performTask method&quo…

爬虫学习--15.进程与线程(2)

线程锁 当多个线程几乎同时修改某一个共享数据的时候&#xff0c;需要进行同步控制 某个线程要更改共享数据时&#xff0c;先将其锁定&#xff0c;此时资源的状态为"锁定",其他线程不能改变&#xff0c;只到该线程释放资源&#xff0c;将资源的状态变成"非锁定…

Linux如何设置共享文件夹

打开虚拟机->菜单->虚拟机设置->选项->共享文件夹->总是启用。点击添加按钮->弹出添加向导->点击浏览按钮&#xff0c;从windows中选择一个文件夹&#xff0c;确定即可。

[Windows] GIF动画、动图制作神器 ScreenToGif(免费)

ScreenToGif 是开源免费的 Gif 动画录制工具&#xff0c;小巧原生单文件&#xff0c;功能很实用。它有录制屏幕、录制摄像头、录制画板、图像编辑器等功能&#xff0c;可以将屏幕任何区域及操作过程录制成 GIF 格式的动态图像。保存前还可对 GIF 图像编辑优化&#xff0c;支持自…

研二学妹面试字节,竟倒在了ThreadLocal上,这是不要应届生还是不要女生啊?

一、写在开头 今天和一个之前研二的学妹聊天&#xff0c;聊及她上周面试字节的情况&#xff0c;着实感受到了Java后端现在找工作的压力啊&#xff0c;记得在18&#xff0c;19年的时候&#xff0c;研究生计算机专业的学生&#xff0c;背背八股文找个Java开发工作毫无问题&#x…