使用 PyTorch 构建 LSTM 股票价格预测模型

目录

      • 引言
      • 准备工作
      • 1. 训练模型(`train.py`)
      • 2. 模型定义(`model.py`)
      • 3. 测试模型和可视化(`test.py`)
      • 使用说明
      • 模型调整
      • 结论

引言

在金融领域,股票价格预测是一个重要且具有挑战性的任务。随着深度学习的发展,长短期记忆网络(LSTM)因其在处理时间序列数据方面的出色表现而受到关注。本篇博客将指导你如何使用PyTorch构建一个LSTM模型来预测股票价格,我们将逐步介绍数据预处理、模型训练和结果可视化的完整流程。

准备工作

  1. 安装依赖
    确保你已经安装了以下 Python 库:

    pip install pandas numpy torch matplotlib scikit-learn
    
  2. 下载数据
    使用 yfinance 库下载你感兴趣的股票的历史数据,并保存为 CSV 文件。我们这里使用 Apple(AAPL)过去五年的数据,文件命名为 AAPL_5y_data.csv。以下是一个下载数据的代码示例:

    import yfinance as yf# 下载Apple股票过去5年的数据
    data = yf.download('AAPL', start='2019-01-01', end='2024-01-01')
    data.to_csv('AAPL_5y_data.csv')
    

1. 训练模型(train.py

在这个脚本中,我们将读取 CSV 文件,归一化数据,并使用 LSTM 模型进行训练。

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类# 设置随机种子
torch.manual_seed(42)# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]def create_dataset(data, time_step=1):X, y = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]X.append(a)y.append(data[i + time_step, 0])  # 预测收盘价return np.array(X), np.array(y)# 创建数据集
time_step = 50  # 时间步长
X_train, y_train = create_dataset(train_data, time_step)# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float().view(-1, 1)# 初始化模型、损失函数和优化器
model = LSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 训练模型
num_epochs = 300
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'lstm_model.pth')
print("模型已保存为 'lstm_model.pth'")

2. 模型定义(model.py

在这个文件中定义 LSTM 模型结构。

import torch
import torch.nn as nnclass LSTM(nn.Module):def __init__(self):super(LSTM, self).__init__()self.lstm = nn.LSTM(input_size=5, hidden_size=100, num_layers=2, batch_first=True)self.fc = nn.Linear(100, 1)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])  # 取最后时间步的输出return out

3. 测试模型和可视化(test.py

在这个脚本中,我们将加载训练好的模型,并使用测试数据进行预测和可视化。

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from model import LSTM  # 导入LSTM类# 设置字体为SimHei,用于显示中文
plt.rcParams['font.family'] = 'SimHei'# 读取CSV文件
file_path = 'AAPL_5y_data.csv'  # 替换为你的CSV文件路径
data = pd.read_csv(file_path)# 确保日期列是 datetime 类型
data['Date'] = pd.to_datetime(data['Date'])
data.set_index('Date', inplace=True)# 选择多特征:'Close', 'Open', 'High', 'Low', 'Volume'
features = data[['Close', 'Open', 'High', 'Low', 'Volume']].values# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(features)# 准备训练和测试数据
train_size = int(len(scaled_data) * 0.8)
train_data = scaled_data[:train_size]
test_data = scaled_data[train_size:]def create_dataset(data, time_step=1):X, y = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]X.append(a)y.append(data[i + time_step, 0])  # 预测收盘价return np.array(X), np.array(y)# 创建测试数据集
time_step = 50  # 时间步长
X_test, y_test = create_dataset(test_data, time_step)# 转换为PyTorch张量
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float().view(-1, 1)# 加载模型
model = LSTM()
model.load_state_dict(torch.load('lstm_model.pth'))
model.eval()# 测试模型
with torch.no_grad():test_outputs = model(X_test)# test_outputs 是预测的收盘价,将其重新归一化为原始价格test_outputs = scaler.inverse_transform(np.concatenate((test_outputs.numpy(), np.zeros((test_outputs.shape[0], 4))), axis=1))[:, 0]  # 反归一化收盘价y_test_inverse = scaler.inverse_transform(np.concatenate((y_test.numpy(), np.zeros((y_test.shape[0], 4))), axis=1))[:, 0]# 可视化结果
plt.figure(figsize=(14, 7))
plt.plot(data.index[-len(y_test):], y_test_inverse, label='真实价格', color='blue')
plt.plot(data.index[-len(test_outputs):], test_outputs, label='预测价格', color='red')
plt.title('股票价格预测')
plt.xlabel('日期')
plt.ylabel('价格')
plt.legend()
plt.show()

使用说明

  1. 保存脚本

    • 将训练脚本代码保存为 train.py
    • 将模型定义代码保存为 model.py
    • 将测试脚本代码保存为 test.py
  2. 运行训练

    • 在命令行中运行训练脚本:
      python train.py
      
    • 训练完成后,模型将保存为 lstm_model.pth
  3. 运行测试和可视化

    • 在命令行中运行测试脚本:

      python test.py
      
    • 这将加载已训练的模型,并可视化预测结果。
      在这里插入图片描述
      这只是一个演示,模型的预测效果还有待进一步优化。

模型调整

如果预测的价格和真实价格差距较大,可能是由于以下几个原因:

  1. 数据规模不足

    • 如果训练数据不足,模型可能无法学到市场的长期趋势。
    • 改进:使用更多的历史数据,尽量包括多年的数据。可以尝试增加数据的时间跨度。
  2. 数据预处理问题

    • 数据没有正确归一化,或归一化范围过窄。
    • 改进:检查 MinMaxScaler 的应用。你可以尝试不同的归一化范围,例如 (0, 1)(-1, 1),也可以使用其他标准化方法(例如 StandardScaler)。
  3. 模型复杂度不足

    • 模型的层数或隐藏单元数量可能不足以捕捉数据的复杂性。
    • 改进:增加 LSTM 的隐藏层数量或隐藏单元数量。你还可以考虑添加其他类型的层,例如卷积层(CNN)或全连接层,以提高模型的表达能力。
  4. 超参数调整

    • 学习率、批大小和时间步长等超参数可能需要调整以优化模型性能。
    • 改进:尝试不同的学习率(例如,0.001、0.0001 等)、不同的批大小(如 16、32、64)和时间步长(如 30、60)。
  5. 更改损失函数

    • 在某些情况下,使用不同的损失函数可能有助于模型的收敛。
    • 改进:可以尝试使用其他损失函数,例如 Huber 损失函数(nn.SmoothL1Loss)或自定义损失函数,以更好地适应数据。

结论

通过使用 PyTorch 构建 LSTM 模型,我们成功地实现了股票价格的预测。在这个过程中,我们学习了如何处理时间序列数据,构建和训练深度学习模型,以及如何评估和可视化预测结果。尽管模型的性能可能需要进一步的优化和调整,但这个示例为未来的工作奠定了基础。

希望这篇博客能够帮助你在股票价格预测方面取得更好的成果。欢迎分享你的成果和经验,或者提出你的问题!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/56233.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

即时通讯增加Redis渠道

情况说明 在本地和服务器分别启动im服务,当本地发送消息时,会发现服务器上并没有收到消息 初版im只支持单机版,不支持分布式的情况。此次针对该情况对项目进行优化,文档中贴出的代码非完整代码,可自行查看参考资料[2] 代码结构调…

Docker安装ocserv教程(效果极佳)

本章教程,介绍如何在Debain系统上安装ocserv。安装方式是使用Docker方式部署。 一、安装Docker curl -sSL https://file.ewbang.com/docker/debian/install_docker.sh -o install_docker.sh && bash install_docker.sh二、拉取镜像 docker pull tommylau/ocserv

Jsoup在Java中:解析京东网站数据

对于电商网站如京东来说,其页面上的数据包含了丰富的商业洞察。对于开发者而言,能够从这些网站中提取有价值的信息,进行分析和应用,无疑是一项重要的技能。本文将介绍如何使用Java中的Jsoup库来解析京东网站的数据。 Jsoup简介 …

Linux部署redis保姆级教程

一、版本说明 Redis版本号(本文的版本号是6.2.12)的第二位如果是偶数,代表稳定版本,如果是奇数,代表非稳定版本。 所有历史版本下载地址:Index of /releases/ 二、基于压缩包安装(推荐) 2.1安装依赖 2.1.1安装gcc: yum -y install gcc 2.1.2验证gcc是否安装成功:(…

Linux--多路转接之epoll

上一篇:Linux–多路转接之select epoll epoll 是 Linux 下多路复用 I/O 接口 select/poll 的增强版本,它能显著提高程序在大量并发连接中只有少量活跃的情况下的系统 CPU 利用率。它是 Linux 下多路复用 API 的一个选择,相比 select 和 poll&#xff0c…

DevExpress WPF v24.1新版亮点:PDF查看器、富文本编辑器功能升级

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 DevExpress WPF控件日…

1971. 寻找图中是否存在路径

有一个具有 n 个顶点的 双向 图,其中每个顶点标记从 0 到 n - 1(包含 0 和 n - 1)。图中的边用一个二维整数数组 edges 表示,其中 edges[i] [ui, vi] 表示顶点 ui 和顶点 vi 之间的双向边。 每个顶点对由 最多一条 边连接&#x…

Vue3 学习笔记(一)Vue3 介绍及环境部署

一、Vue.js 简介 1、Vue.js 是什么? Vue.js(读音 /vjuː/, 类似于 view) 是一套构建用户界面的渐进式框架。Vue 只关注视图层, 采用自底向上增量开发的设计。Vue 的目标是通过尽可能简单的 API 实现响应的数据绑定和组合的视图组件…

性能工具之JMeter 通过Java API生成 BeanShell PreProcessor 脚本

文章目录 一、前言二、实现代码三、代码示例四、最后 一、前言 对于上一篇文章(性能工具之 HAR 格式化转换JMeter JMX 脚本文件)还是有点问题。大家在使用的情况需要注意。 如果多个接口相同 path 路径且不同参数进行查询如: 上面接口如果…

【前端】如何制作一个自己的网页(15)

有关后代选择器的具体解释&#xff1a; 后代选择器 后代选择器使用时&#xff0c;需要以空格将多个选择器间隔开。 比如&#xff0c;这里p span&#xff0c;表示只设置p元素内&#xff0c;span元素的样式。 <style> /* 使用后代选择器设置样式 */ p span { …

java--多态(详解)

目录 一、概念二、多态实现的条件三、向上转型和向下转型3.1 向上转型3.2 向下转型 四、重写和重载五、理解多态5.1练习&#xff1a;5.2避免在构造方法中调用重写的方法&#xff1a; 欢迎来到权权的博客~欢迎大家对我的博客提出指导这是我的博客主页&#xff1a;点击 一、概念…

Java毕业设计 基于SpringBoot发卡平台

Java毕业设计 基于SpringBoot发卡平台 这篇博文将介绍一个基于SpringBoot发卡平台&#xff0c;适合用于Java毕业设计。 功能介绍 首页 图片轮播 商品介绍 商品详情 提交订单 文章教程 文章详情 查询订单  查看订单卡密 客服   后台管理 登录 个人信息 修改密码 管…

Selenium爬虫技术:如何模拟鼠标悬停抓取动态内容

介绍 在当今数据驱动的世界中&#xff0c;抓取动态网页内容变得越来越重要&#xff0c;尤其是像抖音这样的社交平台&#xff0c;动态加载的评论等内容需要通过特定的方式来获取。传统的静态爬虫方法难以处理这些由JavaScript生成的动态内容&#xff0c;Selenium爬虫技术则是一…

字典如何与选择器一起使用

背景&#xff1a;开发过程中会遇到某些字段需要做成下拉框。如下图&#xff1a; 组件 | Element里有select选择器这个组件可以实现下拉框的效果 我们可能会想到创一个辅助表来存储这些下拉数据像这样 这样虽然能实现&#xff0c;但是在实际开发中是不合理的&#xff0c;如果有…

个税自然人扣缴客户端数据的备份与恢复(在那个文件夹)

一&#xff0c;软件能够正常打开&#xff0c;软件中的备份与恢复功能 1&#xff0c;备份 您按照下面的方法备份一下哦~ 进入要备份的自然人软件&#xff0c;点击左侧系统设置→→系统管理→→备份恢复&#xff1b; 在备份设置里&#xff0c;点击“备份到选择路径”&#xff0c;…

WebGL编程指南 - 颜色与纹理续

设置纹理坐标&#xff08;initVertexBuffers()&#xff09; 从缓冲区到 attribute 变量的流程&#xff1a; // 顶点坐标 function initVertexBuffers(gl) {// 数据准备let verticesTexCoords new Float32Array([// 顶点坐标&#xff0c;纹理坐标-0.5, 0.5, 0.0, 1.0, -0.5, …

图像异常检测评估指标-分类性能

图像异常检测评估指标-分类性能 1. 混淆矩阵 混淆矩阵包括4个用于衡量分类算法性能的基本数值 四个字母代表的含义是&#xff1a;P&#xff08;Positive&#xff09;代表算法将样本预测为正类&#xff0c;N&#xff08;Negative&#xff09;代表算法将样本预测为负类&#xf…

ST7789读取ID错误新思路(以STC32G为例)

1.前言 前两天刚把ST7789写入搞定&#xff0c;这两天想折腾一下读取。最开始是读ID&#xff0c;先是用厂家送的程序&#xff0c;程序里面用的是模拟I8080协议&#xff0c;一切正常。后来我用STC32G的内置LCM模块&#xff0c;发现读取不出来。更神奇的是ID读不出来&#xff0c;…

[项目详解][boost搜索引擎#2] 建立index | 安装分词工具cppjieba | 实现倒排索引

目录 编写建立索引的模块 Index 1. 设计节点 2.基本结构 3.(难点) 构建索引 1. 构建正排索引&#xff08;BuildForwardIndex&#xff09; 2.❗构建倒排索引 3.1 cppjieba分词工具的安装和使用 3.2 引入cppjieba到项目中 倒排索引代码 本篇文章&#xff0c;我们将继续项…

【C++指南】类和对象(四):类的默认成员函数——全面剖析 : 拷贝构造函数

引言 拷贝构造函数是C中一个重要的特性&#xff0c;它允许一个对象通过另一个已创建好的同类型对象来初始化。 了解拷贝构造函数的概念、作用、特点、规则、默认行为以及如何自定义实现&#xff0c;对于编写健壮和高效的C程序至关重要。 C类和对象系列文章&#xff0c;可点击下…