RNN股票预测(Pytorch版)

任务:基于zgpa_train.csv数据,建立RNN模型,预测股价
1.完成数据预处理,将序列数据转化为可用于RNN输入的数据
2.对新数据zgpa_test.csv进行预测,可视化结果
3.存储预测结果,并观察局部预测结果
备注:模型结构:单层RNN,输出有5个神经元,每次使用前8个数据预测第9个数据
参考视频:吹爆!3小时搞懂!【RNN循环神经网络+时间序列LSTM深度学习模型】学不会UP主下跪!
up主用的Keras,自己用Pytorch尝试了一下,代码如下:

import pandas as pd
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
data = pd.read_csv('zgpa_train.csv')
# loc 通过行索引 “Index” 中的具体值来取行数据
# 取出开盘价
price = data.loc[:,'close']# 归一化
price_norm = price/max(price)
# 开盘价折线图
# fig1 = plt.figure(figsize=(10, 6))
# plt.plot(price)
# plt.title('close price')
# plt.xlabel('time')
# plt.ylabel('price')
# plt.show()# 提取数据 每次使用前8个数据来预测第九个数据
def extract_data(data, time_step):x = []y = []for i in range(len(data)- time_step):x.append([a for a in data[i:i+time_step]])y.append(data[i + time_step])x = np.array(x)x = x.reshape(x.shape[0], x.shape[1], 1)x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)return x, y
time_step = 8
x, y = extract_data(price_norm,time_step)
# print(x)
# print(y)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(RNN,self).__init__()self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)# print(out)out = self.fc(out[:, -1, :])out = out.squeeze(1)return out
# 定义模型参数
input_size = 1 # 输入特征的维度
hidden_size = 64 # 隐藏层的维度
output_size = 1 # 输出特征的维度
num_layers = 1 # RNN的层数# 创建模型
model = RNN(input_size, hidden_size, output_size, num_layers)# 定义损失函数和优化器
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 200
for epoch in range(epochs):optimizer.zero_grad()# outputs = model(x.unsqueeze(2))outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
# 进行预测 数据很少这里就不先保存模型再预测了
model.eval()
with torch.no_grad():y_train_predict = model(x) * max(price)
y_train = [i * max(price) for i in y]
# print(y_train_predict)
y_train_predict = y_train_predict.cpu().numpy()
y_train = np.array(y_train)
fig2 = plt.figure(figsize=(10, 6))
plt.plot(y_train_predict, label='Predicted', color='blue')
plt.plot(y_train, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 测试集
data_test = pd.read_csv('zgpa_test.csv')
price_test = data_test.loc[:,'close']
price_test_norm = price_test/max(price)
x_test,y_test = extract_data(price_test_norm,time_step)
with torch.no_grad():y_test_predict = model(x_test) * max(price)
y_test = [i * max(price) for i in y_test]
# print(y_train_predict)
y_test_predict = y_test_predict.cpu().numpy()
y_test = np.array(y_test)
fig3 = plt.figure(figsize=(10, 6))
plt.plot(y_test_predict, label='Predicted', color='blue')
plt.plot(y_test, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values (Test Set)')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 存储数据
result_y_test = np.array(y_test).reshape(-1, 1) # 若干行,1列
result_y_test_predict = y_test_predict.reshape(-1, 1)
print(result_y_test.shape, result_y_test_predict.shape)
result = np.concatenate((result_y_test, result_y_test_predict), axis=1)
print(result.shape)
result = pd.DataFrame(result, columns=['real_price_test', 'predict_price_test'])
result.to_csv('zgpa_predict_test.csv')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/54093.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB 可视化基础:绘图命令与应用

目录 1. 绘制子图1.1基本绘图命令1.2. 使用 subplot 函数1.3. 绘图类型 2.MATLAB 可视化进阶(以下代码均居于以上代码的数据定义上实现)2.1. 极坐标图2.3. 隐函数的绘制 3.总结 在数据分析和科学计算中,数据可视化是理解和解释结果的关键工具。今天,我将…

.bixi勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复

导言 随着网络技术的飞速发展,网络安全威胁日益加剧,各种勒索病毒层出不穷。其中,.bixi勒索病毒(也称为Bixi Ransomware)作为一种新兴的网络安全威胁,对个人用户和企业数据安全构成了严重威胁。本文91数据…

PHP7 json_encode() 浮点小数溢出错误

原因已找到, 该现象只出现在PHP 7.1版本上 建议使用默认值 serialize_precision -1 即可 事情是这样的,项目里发现一个奇怪的现象,json_encode一个带浮点价格的数据, 出现溢出, 比如: echo json_encode(277.2); // 输出结果为: 277.199999999999989这明显是不能接受的, 数据…

【C++】基础知识 笔记

目录 1.1 基本结构: 1.2 注释 单行注释 多行注释 1.3 变量 1.4 常量 C定义常量两种方式 1.5 关键字(标识符) 标识符起名规则 1.1 基本结构: #include "iostream" using namespace std; //以上两行是预处理指令…

Text2vec -文本转向量

文章目录 一、关于 Text2vec1、Text2vec 是什么2、Features3、Demo4、News5、Evaluation英文匹配数据集的评测结果:中文匹配数据集的评测结果: 6、Release Models 二、Install三、使用1、文本向量表征1.2 Usage (HuggingFace Transformers)1.3 Usage (se…

标准库标头 <barrier>(C++20)学习

此头文件是线程支持库的一部分。 类模板 std::barrier 提供一种线程协调机制,阻塞已知大小的线程组直至该组中的所有线程到达该屏障。不同于 std::latch,屏障是可重用的:一旦到达的线程组被解除阻塞,即可重用同一屏障。与 std::l…

NISP 一级 | 5.5 账户口令安全

关注这个证书的其他相关笔记:NISP 一级 —— 考证笔记合集-CSDN博客 0x01:账户口令安全威胁 当用户在使用各种应用时,需通过账户和口令来验证身份从而访问某些资源,因此,账号口令的安全性非常重要。当前攻击者窃取用户…

深度学习之微积分预备知识点

极限(Limit) 定义:表示某一点处函数趋近于某一特定值的过程,一般记为 极限是一种变化状态的描述,核心思想是无限靠近而永远不能到达 公式: 表示 x 趋向 a 时 f(x) 的极限。 知识点口诀解释极限的存在左…

C语言 | Leetcode C语言题解之第412题Fizz Buzz

题目&#xff1a; 题解&#xff1a; /*** Note: The returned array must be malloced, assume caller calls free().*/ char ** fizzBuzz(int n, int* returnSize) {/*定义字符串数组*/char **answer (char**)malloc(sizeof(char*)*n);for(int i 1;i<n;i){/*分配单个字符串…

React学习day06-异步操作、ReactRouter的概念及简单使用

13、续 &#xff08;8&#xff09;异步状态操作 1&#xff09;在子仓库中 ①创建仓库 ②解构需要的方法 ③安装axios ④封装并导出请求 ⑤在reducer中为newsList赋值 ⑥获取并导出reducer函数 2&#xff09;在入口文件index.js中&#xff0c;注入 3&#xff09;在App.js中&a…

Ansible自动化部署kubernetes集群

机器环境介绍 1.1. 机器信息介绍 IP hostname application CPU Memory 192.168.204.129 k8s-master01 etcd&#xff0c;kube-apiserver&#xff0c;kube-controller-manager&#xff0c;kube-scheduler,kubelet,kube-proxy,containerd 2C 4G 192.168.204.130 k8s-w…

根据NVeloDocx Word模板引擎生成Word(六-结束)

前面几篇已经把E6开发平台配套的Word模版隐藏NVeloDocx的基础用法介绍了一遍&#xff0c;这些基础用法基本上可以完全覆盖实际业务的绝大部分需求。所以我们这一篇就介绍一些边边角角的内容&#xff0c;给本系列来一个首尾。 本篇的主要内容有&#xff1a; 1、汇总计算&#…

Java实现建造者模式和源码中的应用

&#x1f3af; 设计模式专栏&#xff0c;持续更新中&#xff0c; 欢迎订阅&#xff1a;JAVA实现设计模式 &#x1f6e0;️ 希望小伙伴们一键三连&#xff0c;有问题私信都会回复&#xff0c;或者在评论区直接发言 Java实现建造者模式&#xff08;Builder Pattern&#xff09; 文…

ubuntu安装mysql 8.0忘记root初始密码,如何重新修改密码

1、停止mysql服务 $ service mysql stop 2、修改my.cnf文件 # 修改my.cnf文件&#xff0c;在文件新增 skip-grant-tables&#xff0c;在启动mysql时不启动grant-tables&#xff0c;授权表 $ sudo vim /etc/mysql/my.cnf [mysqld] skip-grant-tables 3、启动mysql服务 servic…

【四】k8s部署 TDengine集群

k8s部署 TDengine集群 目录 k8s部署 TDengine集群 一、在 Kubernetes 上部署 TDengine 集群 第一步&#xff1a;创建命名空间 第二步&#xff1a;从yaml创建有状态服务 StatefulSet 第三步&#xff1a;配置 Service 服务 二、集群测试 一、在 Kubernetes 上部署 TDengine…

实习期间git的分枝管理以及最常用的命令

各位找工作实习的友友在工作之前一定要把git的相关知识掌握呀&#xff0c;我实现期间被leader说过关于git规范的相关问题了 目前已更新系列&#xff1a; 当前&#xff1a;:实习期间git的分枝管理以及最常用的命令 Redis高级-----持久化AOF、RDB原理 Redis高级---面试总结5种…

Android SPN/PLMN 显示逻辑简介

功能描述 当设备驻网后(运营商网络),会在状态栏、锁屏界面、下拉控制中心显示运营商的名称。 此名称来源有两种: 1、SPN(Service Provider Name) 2、PLMN (Public Land Mobile Name) 功能AOSP默认逻辑SPN提供SIM卡的运营商名称预置在SIM EF中,SIM卡发行运营商名称…

微软九月补丁星期二发现了 79 个漏洞

微软将在2024 年 9 月补丁星期二修复 79 个漏洞。 微软有证据表明&#xff0c;发布的四个漏洞被野外利用和/或公开披露&#xff1b;所有四个漏洞均已在CISA KEV上列出。微软还在修补四个关键的远程代码执行 (RCE) 漏洞。 不同寻常的是&#xff0c;微软本月尚未修补任何浏览器…

redis详细解析和配置选择

Redis是一个开源的、使用ANSI C语言编写的、基于内存亦可持久化的日志型Key-Value非关系型数据库。它以其高性能、丰富的数据结构和灵活的数据模型而广受欢迎&#xff0c;被广泛应用于缓存、消息队列、实时数据处理等多种场景。以下是对Redis的详细解析和配置选择的详细阐述。 …

AI替代插画师跟设计师?不用焦虑!

一个固定的工作流&#xff0c; 一个训练好的lora模型 输入一段提示词 二三十秒的时间&#xff0c;就能生成一张精致美观有韵味的中秋国风插画 这张不喜欢&#xff0c;改下提示词重新生成一张不一样的。还是二十几秒 同样的插画&#xff0c;你用手绘&#xff0c;从起稿到上…