使用LSTM预测股票收盘价

在金融数据预测中,LSTM(长短期记忆网络)凭借其在时间序列数据建模中的优势,成为了分析股票价格趋势的热门选择。本篇博客将以完整的代码实现为例,展示如何利用LSTM网络对股票收盘价进行预测,并从数据处理到模型训练进行全面解析。


一、数据预处理与可视化

1. 导入并整理数据

首先,我们从CSV文件中加载了股票数据,并确保其按照日期递增排序,便于时间序列分析。以下是数据的基本信息:

filepath = './rlData.csv'
data = pd.read_csv(filepath)
data = data.sort_values('Date')
print(data.head())
print(data.shape)

2. 可视化股票价格

为了更直观地理解数据走势,我们对收盘价(Close)进行了可视化:

plt.figure(figsize=(15, 9))
plt.plot(data['Close'])
plt.xticks(range(0, data.shape[0], 20), data['Date'].loc[::20], rotation=45)
plt.title("Stock Price Trend", fontsize=18, fontweight='bold')
plt.xlabel('Date', fontsize=18)
plt.ylabel('Close Price (USD)', fontsize=18)
plt.savefig('StockPrice.jpg')
plt.show()

通过这一步骤,我们能够观察到股票价格的波动规律,为后续建模提供参考。


二、特征工程与数据集制作

1. 数据归一化

LSTM对输入数据的范围较为敏感,因此我们使用MinMaxScaler将数据归一化到[-1, 1]之间:

scaler = MinMaxScaler(feature_range=(-1, 1))
price['Close'] = scaler.fit_transform(price['Close'].values.reshape(-1, 1))

2. 时间序列数据集构建

为了利用前lookback天的数据预测未来一天的收盘价,我们编写了以下函数对数据进行切分:

def split_data(stock, lookback):data_raw = stock.to_numpy()data = []for index in range(len(data_raw) - lookback):data.append(data_raw[index: index + lookback])data = np.array(data)test_set_size = int(np.round(0.2 * data.shape[0]))train_set_size = data.shape[0] - (test_set_size)x_train = data[:train_set_size, :-1, :]y_train = data[:train_set_size, -1, :]x_test = data[train_set_size:, :-1, :]y_test = data[train_set_size:, -1, :]return [x_train, y_train, x_test, y_test]

三、LSTM模型构建与训练

1. 模型定义

LSTM模型由多层循环网络和全连接层组成。模型的输入维度、隐藏层大小和输出维度均可以根据需求调整:

class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dimself.num_layers = num_layersself.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).requires_grad_()c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).requires_grad_()out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))out = self.fc(out[:, -1, :])return out

2. 模型训练

在训练过程中,我们使用均方误差(MSE)作为损失函数,并通过Adam优化器进行优化。训练代码如下:

for t in range(num_epochs):y_train_pred = model(x_train)loss = criterion(y_train_pred, y_train_lstm)optimiser.zero_grad()loss.backward()optimiser.step()

四、模型结果与可视化

1. 训练结果分析

模型训练后,我们对预测值与真实值进行了可视化对比:

sns.lineplot(x=original.index, y=original[0], label="Actual Value")
sns.lineplot(x=predict.index, y=predict[0], label="Training Prediction")

2. 测试集性能

利用均方根误差(RMSE)评价模型的预测能力:

trainScore = math.sqrt(mean_squared_error(y_train[:, 0], y_train_pred[:, 0]))
testScore = math.sqrt(mean_squared_error(y_test[:, 0], y_test_pred[:, 0]))
print("Train RMSE: %.2f" % trainScore)
print("Test RMSE: %.2f" % testScore)

五、总结与思考

本项目通过LSTM网络对股票价格进行了预测,其主要特点包括:

  1. 时间序列处理能力:LSTM对长序列数据的依赖性表现良好。
  2. 数据可视化:通过Matplotlib与Plotly直观呈现预测效果。

未来优化方向

  1. 引入更多特征(如成交量、开盘价)。
  2. 优化超参数(如学习率、隐藏层维度)。
  3. 探索其他深度学习模型(如Transformer)。

完整代码提供了详细的实现步骤,适用于初学者学习LSTM模型以及时间序列分析。希望本文对您有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65849.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

模拟SpringIOCAOP

一、IOC容器 Ioc负责创建,管理实例,向使用者提供实例,ioc就像一个工厂一样,称之为Bean工厂 1.1 Bean工厂的作用 先分析一下Bean工厂应具备的行为 1、需要一个获取实例的方法,根据一个参数获取对应的实例 getBean(…

预编译SQL

预编译SQL 预编译SQL是指在数据库应用程序中,SQL语句在执行之前已经通过某种机制(如预编译器)进行了解析、优化和准备,使得实际执行时可以直接使用优化后的执行计划,而不需要每次都重新解析和编译。这么说可能有一些抽…

Centos9 + Docker 安装 MySQL8.4.0 + 定时备份数据库到本地

Centos9 + Docker 安装 MySQL8.4.0 + 定时备份数据库到本地 创建目录,创建配置文件启动容器命令定时备份MySQL执行脚本Linux每日定时任务命令文件内参数其他时间参数AT一次性定时任务创建目录,创建配置文件 $ mkdir -p /opt/mysql/conf$ vim /opt/mysql/conf/my.cnf[mysql] #…

软件测试预备知识⑥—搭建Web服务器

在软件测试的广阔领域中,搭建Web服务器是一项极为关键的技能。它不仅有助于模拟真实的应用环境,方便我们对Web应用进行全面且深入的测试,还能让测试人员更好地掌控测试场景,提升测试效率与质量。接下来,让我们一同深入…

计算机视觉算法实战——打电话行为检测

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​​​​​​​ ​​​​​​​​​​​​​​​ ​​​​​​ ​ 1. 引言✨✨ 随着智能手机的普及,打电话行为检测成为了计算机视…

事务的隔离级别和MDL

文章目录 说明不同隔离级别可能发生的现象关键现象解释MDL(元数据锁,Metadata Lock)MDL 的作用MDL 的工作原理MDL 锁的常见场景如何避免 MDL 阻塞 说明 本文章由大模型对话整理而来,如果有错误之处,请在评论区留言指正…

Linux第二课:LinuxC高级 学习记录day01

0、大纲 0.1、Linux 软件安装,用户管理,进程管理,shell 命令,硬链接和软连接,解压和压缩,功能性语句,结构性语句,分文件,make工具,shell脚本 0.2、C高级 …

单片机存储与计算机存储:从微小到庞大的数据世界

单片机存储与计算机存储:从微小到庞大的数据世界 在现代电子设备中,存储是至关重要的组成部分。无论是小巧的单片机,还是功能强大的计算机,存储都扮演着不可或缺的角色。然而,单片机和计算机的存储架构却有着天壤之别…

ISP流程--去马赛克详解

前言 本期我们将深入讨论ISP流程中的去马赛克处理。我们熟知,彩色图像由一个个像元组成,每个像元又由红、绿、蓝(RGB)三通道构成。而相机传感器只能感知光的强度,无法直接感知光谱信息,即只有亮暗而没有颜色…

阿里云-通义灵码:在 PyCharm 中的强大助力(下)

目录 六.通义灵码在 PyCharm 中的优势与不足 1.优势 (1).提高开发效率 (2).提升代码质量 (3).易于使用 (4).不断学习和改进 2.不足 (1).依赖网络 (2).准确性有待提高 (3).局限性 七.未来发展展望 1.提高准确性和可靠性 2.与其他工具的集成 3.智能化程度的提升 八…

开源项目stable-diffusion-webui部署及生成照片

参考链接 https://www.freedidi.com/13133.html 基础环境部署 python 官网链接 Python Release Python 3.10.6 | Python.org 下载 Python 3.10.6 版本安装包 下载好后双击 点击安装,这里需要选择一下,把环境变量加上。(这里是默认安装到C盘…

【芯片封测学习专栏 -- 单 Die 与 多Die(Chiplet)介绍】

请阅读【嵌入式开发学习必备专栏 Cache | MMU | AMBA BUS | CoreSight | Trace32 | CoreLink | ARM GCC | CSH】 文章目录 Overview单个Die(Monolithic Die)多个Die(Chiplet Architecture or Heterogeneous SoC)如何判断一个SoC是…

Windows 安装 Docker 和 Docker Compose

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall ︱vue3-element-admin︱youlai-boot︱vue-uniapp-template 🌺 仓库主页: GitCode︱ Gitee ︱ Github 💖 欢迎点赞 👍 收藏 ⭐评论 …

java_将数据存入elasticsearch进行高效搜索

使用技术简介: (1) 使用Nginx实现反向代理,使前端可以调用多个微服务 (2) 使用nacos将多个服务管理关联起来 (3) 将数据存入elasticsearch进行高效搜索 (4) 使用消息队列rabbitmq进行消息的传递 (5) 使用 openfeign 进行多个服务之间的api调用 参…

Github Copilot学习笔记

(一)Prompt Engineering 利用AI工具生成prompt设计好的prompt结构使用MarkDown语法,按Role, Skills, Constrains, Background, Requirements和Demo这几个维度描述需求。然后收输入提示词:作为 [Role], 拥有 [Skills], 严格遵守 […

android分区和root

线刷包内容: 线刷包是一个完整的android镜像,不但包括android、linux和用户数据,还包括recovery等。当然此图中没有recovery,但是我们可以自己刷入一个。 主要分区 system.img 系统分区,包括linux下主要的二进制程序。 boot.img…

RabbitMQ基础(简单易懂)

RabbitMQ高级篇请看: RabbitMQ高级篇-CSDN博客 目录 什么是RabbitMQ? MQ 的核心概念 1. RabbitMQ 的核心组件 2. Exchange 的类型 3. 数据流向说明 如何安装RabbitQueue? WorkQueue(工作队列): Fa…

大数据环境搭建进度

1.使用虚拟机的系统:centos7.xLinux 2.资源不足,使用云服务器: 1. 3.使用远程登录进行操作 用xshell 4.任务 1.虚拟机装好 2.设置IP地址 3.可以联网 4.设置远程登录访问 5.创建module和software目录,修改两…

线程安全问题介绍

文章目录 **什么是线程安全?****为什么会出现线程安全问题?****线程安全问题的常见场景****如何解决线程安全问题?**1. **使用锁**2. **使用线程安全的数据结构**3. **原子操作**4. **使用volatile关键字**5. **线程本地存储**6. **避免死锁*…

pytorch小记(七):pytorch中的保存/加载模型操作

pytorch小记(七):pytorch中的保存/加载模型操作 1. 加载模型参数 (state_dict)1.1 保存模型参数1.2 加载模型参数1.3 常见变种1.3.1 指定加载设备1.3.2 非严格加载(跳过部分层)1.3.3 打印加载的参数 2. 加载整个模型2.…