循环神经网络(RNN)实现股票预测

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
  • 四、数据预处理
    • 1.归一化
    • 2.设置测试集训练集
  • 五、构建模型
  • 六、激活模型
  • 七、训练模型
  • 八、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('./datasets/SH600519.csv')  # 读取股票文件data
Unnamed: 0dateopenclosehighlowvolumecode
0742010-04-2688.70287.38189.07287.362107036.13600519
1752010-04-2787.35584.84187.35584.68158234.48600519
2762010-04-2884.23584.31885.12883.59726287.43600519
3772010-04-2984.59285.67186.31584.59234501.20600519
4782010-04-3083.87182.34083.87181.52385566.70600519
242124952020-04-201221.0001227.3001231.5001216.80024239.00600519
242224962020-04-211221.0201200.0001223.9901193.00029224.00600519
242324972020-04-221206.0001244.5001249.5001202.22044035.00600519
242424982020-04-231250.0001252.2601265.6801247.77026899.00600519
242524992020-04-241248.0001250.5601259.8901235.18019122.00600519

2426 rows × 8 columns

training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

四、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置测试集训练集

x_train = []
y_train = []x_test = []
y_test = []"""
使用前60天的开盘价作为输入特征x_train第61天的开盘价作为输入标签y_trainfor循环共构建2426-300-60=2066组训练数据。共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):x_train.append(training_set[i - 60:i, 0])y_train.append(training_set[i, 0])for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

五、构建模型

model = tf.keras.Sequential([SimpleRNN(80, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。Dropout(0.2),                         #防止过拟合SimpleRNN(80),Dropout(0.2),Dense(1)
])

六、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 损失函数用均方误差

七、训练模型

history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1)                  #测试的epoch间隔数model.summary()
Epoch 1/20
33/33 [==============================] - 6s 123ms/step - loss: 0.1809 - val_loss: 0.0310
Epoch 2/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0257 - val_loss: 0.0721
Epoch 3/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0165 - val_loss: 0.0059
Epoch 4/20
33/33 [==============================] - 3s 85ms/step - loss: 0.0097 - val_loss: 0.0111
Epoch 5/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0099 - val_loss: 0.0139
Epoch 6/20
33/33 [==============================] - 3s 105ms/step - loss: 0.0067 - val_loss: 0.0167
Epoch 7/20
33/33 [==============================] - 3s 86ms/step - loss: 0.0067 - val_loss: 0.0095
Epoch 8/20
33/33 [==============================] - 3s 91ms/step - loss: 0.0063 - val_loss: 0.0218
Epoch 9/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0052 - val_loss: 0.0109
Epoch 10/20
33/33 [==============================] - 3s 99ms/step - loss: 0.0043 - val_loss: 0.0120
Epoch 11/20
33/33 [==============================] - 3s 92ms/step - loss: 0.0044 - val_loss: 0.0167
Epoch 12/20
33/33 [==============================] - 3s 89ms/step - loss: 0.0039 - val_loss: 0.0032
Epoch 13/20
33/33 [==============================] - 3s 88ms/step - loss: 0.0041 - val_loss: 0.0052
Epoch 14/20
33/33 [==============================] - 3s 93ms/step - loss: 0.0035 - val_loss: 0.0179
Epoch 15/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0033 - val_loss: 0.0124
Epoch 16/20
33/33 [==============================] - 3s 95ms/step - loss: 0.0035 - val_loss: 0.0149
Epoch 17/20
33/33 [==============================] - 4s 111ms/step - loss: 0.0028 - val_loss: 0.0111
Epoch 18/20
33/33 [==============================] - 4s 110ms/step - loss: 0.0029 - val_loss: 0.0061
Epoch 19/20
33/33 [==============================] - 3s 104ms/step - loss: 0.0027 - val_loss: 0.0110
Epoch 20/20
33/33 [==============================] - 3s 90ms/step - loss: 0.0028 - val_loss: 0.0037
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn (SimpleRNN)       (None, 60, 80)            6560      
_________________________________________________________________
dropout (Dropout)            (None, 60, 80)            0         
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 80)                12880     
_________________________________________________________________
dropout_1 (Dropout)          (None, 80)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 81        
=================================================================
Total params: 19,521
Trainable params: 19,521
Non-trainable params: 0
_________________________________________________________________

八、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(0,1)反归一化到原始范围# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction by K同学啊')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 1833.92534
均方根误差: 42.82435
平均绝对误差: 36.23424
R2: 0.72347

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/162204.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Rust】快速教程——一直在单行显示打印、输入、文件读写

前言 恨不过是七情六欲的一种,再强大的恨也没法独占整颗心,总有其它情感隐藏在心底深处,说不定在什么时候就会掀起滔天巨浪。——《死人经》 图中是Starship扔掉下面的燃料罐,再扔掉头顶的翅膀后,再翻转过来着陆火星的…

[C++ 从入门到精通] 13.派生类、调用顺序、继承方式、函数遮蔽

📢博客主页:https://loewen.blog.csdn.net📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢本文由 丶布布原创,首发于 CSDN,转载注明出处🙉📢现…

【Unity细节】Default clip could not be found in attached animations list.(动画机报错)

👨‍💻个人主页:元宇宙-秩沅 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 秩沅 原创 😶‍🌫️收录于专栏:unity细节和bug 😶‍🌫️优质专栏 ⭐【…

生产制造业如何谋求数字化转型?需要哪些信息化系统做支撑?

生产制造业的数字化转型是将数字系统和各种技术整合到传统制造流程中的过程,这将导致行业格局的重大变革。工业4.0的到来为制造业开创了一个新时代,制造商可以简化生产线,提高整体效率。同时,这一技术革命使他们能够收集到大量的数…

Altium Designer学习笔记9

忽视了一个最大的问题,就是元器件的封装,不应该是根据AD系统的封装走,而应该是根据立创商城上的规格书,确认每个封装的大小,画出封装图,然后才是布局和走线。 1、确认电容的封装采用0805,贴片电…

【css】Google第三方登录按钮样式修改

文章目录 场景前置准备修改样式官方属性修改样式CSS修改样式按钮的高度height和border-radiusLogo和文字布局 场景 需要用到谷歌的第三方登录,登录按钮有自己的样式。根据官方文档:概览 | Authentication | Google for Developers,提供两种第…

局域网协议:地址解析协议(ARP,Address Resolution Protocol)

地址解析协议(ARP,Address Resolution Protocol)是一种用于在IP网络中将IP地址映射到物理MAC地址的协议。在IP网络中,IP是用于寻址,真正将数据包从一个设备发送到另外一个设备,用于通信的是物理MAC地址。 …

40、Flink 的Apache Kafka connector(kafka sink的介绍及使用示例)-2

Flink 系列文章 1、Flink 部署、概念介绍、source、transformation、sink使用示例、四大基石介绍和示例等系列综合文章链接 13、Flink 的table api与sql的基本概念、通用api介绍及入门示例 14、Flink 的table api与sql之数据类型: 内置数据类型以及它们的属性 15、Flink 的ta…

geemap学习笔记012:如何搜索Earth Engine Python脚本

前言 本节主要是介绍如何查询Earth Engine中已经集成好的Python脚本案例。 1 导入库 !pip install geemap #安装geemap库 import ee import geemap2 搜索Earth Engine Python脚本 很简单,只需要一行代码。 geemap.ee_search()使用方法 后记 大家如果有问题需…

如何提高图片转excel的效果?(软件选择篇)

在日常的工作中,我们常常会遇到一些财务报表类的图片需要转换成可编辑的excel,但是,受各种条件的限制,常常只能通过手工录入这种原始的方式来实现,随着人工智能、深度学习以及网络技术的发展,这种原始的录入…

SpringBoot集成七牛云OSS详细介绍

📑前言 本文主要SpringBoot集成七牛云OSS详细介绍的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是青衿🥇 ☁️博客首页:CSDN主页放风讲故事 🌄每日一句&a…

【Java工具篇】Java反编译工具Bytecode Viewer

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【C++高阶(四)】红黑树深度剖析--手撕红黑树!

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:C从入门到精通⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 红黑树 1. 前言2. 红黑树的概念以及性质3. 红黑…

计算机网络之数据链路层

一、概述 1.1概述 物理层发出去的信号需要通过数据链路层才知道是否到达目的地;才知道比特流的分界线 链路(Link):从一个结点到相邻结点的一段物理线路,中间没有任何其他交换结点数据链路(Data Link):把实现通信协议的硬件和软件…

电商API接口|电商数据接入|拼多多平台根据商品ID查商品详情SKU和商品价格参数

随着科技的不断进步,API开发领域也逐渐呈现出蓬勃发展的势头。今天我将向大家介绍API接口,电商API接口具备独特的特点,使得数据获取变得更加高效便捷。 快速获取API数据——优化数据访问速度 传统的数据获取方式可能需要经过多个中介环节&…

华大基因认知障碍基因检测服务,助力认知障碍疾病防控

认知障碍是一种严重的神经系统疾病,对人类的脑健康产生了重大影响。据报告显示,在我国65岁以上的人群中,存在轻度认知障碍的患者约为3,800万,而中重度痴呆患者则约为1,500万,患病人口数量庞大。这种疾病不仅会对患者的…

免费多域名SSL证书

顾名思义,免费多域名SSL证书就是一种能够为多个域名或子域提供HTTPS安全保护的证书。这意味着,如果您有三个域名——例如example.com、example.cn和company.com,您可以使用一个免费的多域名SSL证书为所有这些域名提供安全保障,而无…

TransFusionNet:JetsonTX2下肝肿瘤和血管分割的语义和空间特征融合框架

TransFusionNet: Semantic and Spatial Features Fusion Framework for Liver Tumor and Vessel Segmentation Under JetsonTX2 TransFusionNet:JetsonTX2下肝肿瘤和血管分割的语义和空间特征融合框架背景贡献实验方法Transformer-Based Semantic Feature Extractio…

AI写代码 可以代替人工吗?

近年AI技术非常火热,有人就说,用AI写代码程序员不就都得下岗吗?对此我的回答是否定的,因为AI虽然已经有了编写代码的能力,但它现在的水平大多还仅限于根据业务需求搭建框架,而具体的功能实现还尚且稚嫩&…

【愚公系列】保姆级教程带你实现HarmonyOS手语猜一猜元服务

🚀前言 最近HarmonyOS NEXT大火,这个纯血鸿蒙吸引力了大家的关注。虽然现在还没面向个人开发者开放,但我们可以基于最新的API9及开发工具来尝试开发鸿蒙新的应用形态——元服务。来体验下未来在HarmonyOS NEXT上实现的应用开发。 HarmonyOS…