TensorFlow项目练手(三)——基于GRU股票走势预测任务

项目介绍

项目基于GRU算法通过20天的股票序列来预测第21天的数据,有些项目也可以用LSTM算法,两者主要差别如下:

  • LSTM算法:目前使用最多的时间序列算法,是一种特殊的RNN(循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
  • GRU算法:是一种特殊的RNN。和LSTM一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

一、准备数据

1、获取数据

  1. 通过命令行安装yfinance
  2. 通过api获取股票数据
  3. 保存到csv中方便使用
import pandas_datareader.data as web
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
plt.rcParams['font.sans-serif']='SimHei' #图表显示中文import yfinance as yf
yf.pdr_override() #需要调用这个函数# 1、获取股票数据
#上海的股票代码+.SS;深圳的股票代码+.SZ :
stock = web.get_data_yahoo("601318.SS", start="2022-01-01", end="2023-07-17")
# 保存到csv中
pd.DataFrame(data=stock).to_csv('./stock.csv')# 2、获取csv中的数据
features = pd.read_csv('stock.csv')
features = features.drop('Adj Close',axis=1)
features.head()

在这里插入图片描述

2、数据可视化

通过绘图的方式查看当前的数据情况

# 3、绘图看看收盘价数据情况
close=features["Close"]
# 计算20天和100天移动平均线:
short_rolling_close = close.rolling(window=20).mean()
long_rolling_close = close.rolling(window=100).mean()
# 绘制
fig, ax = plt.subplots(figsize=(16,9))   #画面大小,可以修改
ax.plot(close.index, close, label='中国平安')   #以收盘价为索引值绘图
ax.plot(short_rolling_close.index, short_rolling_close, label='20天均线')
ax.plot(long_rolling_close.index, long_rolling_close, label='100天均线')
#x轴、y轴及图例:
ax.set_xlabel('日期')
ax.set_ylabel('收盘价 (人民币)')
ax.legend()      #图例
plt.show()      #绘图

在这里插入图片描述

3、数据预处理

取出当前的收盘价,删除无用的日期元素

# 4、取出label值
labels = features['Close']
time = features['Date']
features = features.drop('Date',axis=1)
features.head()

在这里插入图片描述

进行数据的归一化

# 5、数据预处理
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features

在这里插入图片描述

4、构建数据序列

由于RNN的算法要求我们要有一定的序列,来预测出下一个值,所以我们按照20天的数据作为一个序列

# 6、定义序列,[下标1-20天预测第21天的收盘价]
from collections import dequex = []
y = []seq_len = 20
deq = deque(maxlen=seq_len)
for i in input_features:deq.append(list(i))if len(deq) == seq_len:x.append(list(deq))x = x[:-1] # 取少一个序列,因为最后个序列没有答案
y = features['Close'].values[seq_len: ] #从第二十一天开始(下标为20)
time = time.values[seq_len: ] #从第二十一天开始(下标为20)x, y, time = np.array(x), np.array(y), np.array(time)
print(x.shape)
print(y.shape)
print(time.shape)

在这里插入图片描述

二、构建模型

1、搭建GRU模型

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layersfrom keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import Adam# 7、搭建模型
model = tf.keras.Sequential()
model.add(layers.GRU(8,input_shape=(20,5), activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(16, activation='relu', return_sequences=True,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.GRU(32, activation='relu', return_sequences=False,kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.01)))
model.add(layers.Dense(1))
model.summary()

在这里插入图片描述

2、优化器和损失函数

# 优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.MeanAbsoluteError(), # 标签和预测之间绝对差异的平均metrics = tf.keras.losses.MeanSquaredLogarithmicError()) # 计算标签和预测

3、开始训练

25%的比例作为验证集,75%的比例作为训练集

# 开始训练
model.fit(x,y,validation_split=0.25,epochs=200,batch_size=128)

在这里插入图片描述

4、模型预测

# 预测
y_pred = model.predict(x)
fig = plt.figure(figsize=(10,5))
axes = fig.add_subplot(111)
axes.plot(time,y,'b-',label='actual')
# 预测值,红色散点
axes.plot(time,y_pred,'r--',label='predict')
axes.set_xticks(time[::50])
axes.set_xticklabels(time[::50],rotation=45)plt.legend()
plt.show()

在这里插入图片描述

5、回归指标评估

from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
from math import sqrt#回归评价指标
# calculate MSE 均方误差
mse=mean_squared_error(y,y_pred)
# calculate RMSE 均方根误差
rmse = sqrt(mean_squared_error(y, y_pred))
#calculate MAE 平均绝对误差
mae=mean_absolute_error(y,y_pred)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/21734.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python小白学习:超级详细的字典介绍(字典的定义、存储、修改、遍历元素和嵌套)

目录 一、字典简介1.1 创建字典1.2 访问字典中的值1.3 添加键值对1.4 修改字典中的值实例 1.5 删除键值对1.6 由多个类似对象组成的字典1.7 使用get()访问值1.8 练习题 二、遍历字典2.1 遍历所有键值对实例 2.2 遍历字典中的所有键2.3 按照特定顺序遍历字典中的所有键2.4 遍历字…

04 Ubuntu中的中文输入法的安装

在Ubuntu22.04这种版本相对较高的系统中安装中文输入法,一般推荐使用fctix5,相比于其他的输入法,这款输入法的推荐词要好得多,而且不会像ibus一样莫名其妙地失灵。 首先感谢文章《滑动验证页面》,我是根据这篇文章的教…

通用指令(汇编)

一、数据处理指令1)数学运算数据运算指令的格式数据搬移指令立即数伪指令加法指令带进位的加法指令减法指令带借位的减法指令逆向减法指令乘法指令数据运算指令的扩展 2)逻辑运算按位与指令按位或指令按位异或指令左移指令右移指令位清零指令 3&#xff…

前端实现打印1 - 使用 iframe 实现 并 分页打印

目录 打印代码对话框预览打印预览 打印代码 <!-- 打印 --> <template><el-dialogtitle"打印":visible.sync"dialogVisible"width"50%"top"7vh"append-to-bodyclose"handleClose"><div ref"print…

STM32 DHT11

DHT11 DHT11数字温湿度传感器是一款含有已校准数字信号输出的温湿度复合传感器。 使用单总线通信 该传感器包括一个电容式感湿元件和一个NTC测温元件&#xff0c;并于一个高性能8位单片机相连&#xff08;模数转换&#xff09;。 DHT11引脚说明 开漏模式下没有输出高电平的能…

Embedding入门介绍以及为什么Embedding在大语言模型中很重要

Embeddings技术简介及其历史概要 在机器学习和自然语言处理中&#xff0c;embedding是指将高维度的数据&#xff08;例如文字、图片、音频&#xff09;映射到低维度空间的过程。embedding向量通常是一个由实数构成的向量&#xff0c;它将输入的数据表示成一个连续的数值空间中…

Java版本spring cloud + spring boot企业电子招投标系统源代码 tbms

​ 功能模块&#xff1a; 待办消息&#xff0c;招标公告&#xff0c;中标公告&#xff0c;信息发布 描述&#xff1a; 全过程数字化采购管理&#xff0c;打造从供应商管理到采购招投标、采购合同、采购执行的全过程数字化管理。通供应商门户具备内外协同的能力&#xff0c;为…

Kubernetes概述

Kubernetes概述 使用kubeadm快速部署一个k8s集群 Kubernetes高可用集群二进制部署&#xff08;一&#xff09;主机准备和负载均衡器安装 Kubernetes高可用集群二进制部署&#xff08;二&#xff09;ETCD集群部署 Kubernetes高可用集群二进制部署&#xff08;三&#xff09;部署…

企业新片场排名如何优化

企业新片场排名如何优化 要如何去做关键SEO&#xff1f;第一个我们要做的就是做好 SEO 关键词的选词&#xff0c;一般就是会有第一个常用的选词方法&#xff0c;第一是以常用的提问词去做&#xff0c;不实像是情人节买什么礼物&#xff0c;母亲节买什么礼物&#xff0c; 618 有…

推荐前端开发者提升效率的工具

是否掌握新的技术很大程度决定着你是否被淘汰。 虽然应用程序试图将网站替代&#xff0c;但前端 Web 开发业务仍在快速变化和增长&#xff0c;前端开发人员的功能并没有消失。以下介绍一款前端开发者提升效率的工具。 目录 一、低代码工具前景 二、如何理解低代码工具 三、前端…

【怎么提高性能和解决高并发】

怎么解决高并发 解决高并发的整体流程大概是&#xff1a; 先进行性能评估、再进行性能测试、然后找到程序可以承受的临界点、最后针对出问题的地方&#xff0c;进行优化。当然硬件设置对高并发的影响也很重要&#xff0c;如果达到硬件天花板&#xff0c;那么再怎么优化程序都…

java gc分析

使用工具转换&#xff1a;https://ctbots.com/#/ 通用GC分析 jstat -gc -t pid堆内存分析 jstat -gccapacity -t pid年轻代GC分析 jstat -gcnew -t pid年轻代内存分析 jstat -gcnewcapacity -t pid老年代GC分析 jstat -gcold -t pid老年代内存分析 jstat -gcoldcapacity…

Ubuntu18.04 安装opencv 4.8.0教程(亲测可用)

1. 安装准备 安装前需要下载一些必须的依赖项。 不同版本opencv依赖会有不同&#xff0c;具体见官网opencv安装 sudo apt-get install build-essential sudo apt-get install cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev sudo apt-…

AI大模型之花,绽放在鸿蒙沃土

随着生成式AI日益火爆&#xff0c;大语言模型能力引发了越来越多对于智慧语音助手的期待。 我们相信&#xff0c;AI大模型能力加持下的智慧语音助手一定会很快落地&#xff0c;这个预判不仅来自对AI大模型的观察&#xff0c;更来自对鸿蒙的了解。鸿蒙一定会很快升级大模型能力&…

拥抱创新:用Kotlin开发高效Android应用

拥抱创新&#xff1a;用Kotlin开发高效Android应用 引言 在当今数字时代&#xff0c;移动应用已经成为人们生活中不可或缺的一部分。无论是社交媒体、电子商务还是健康管理&#xff0c;移动应用已经深刻地影响了我们的生活方式。随着移动设备的普及和功能的增强&#xff0c;A…

Android getDrawable()和getColor()

Android getDrawable() 1.过时代码 虽然过时&#xff0c;但是不妨碍使用 context.getResources().getDrawable(R.drawable.xxx) 2.建议代码 context.getDrawable(R.drawable.xxx) 有API限制 3.最新代码 ContextCompat.getDrawable(getContext(), R.drawable.xxx); 有A…

安达发|模具制造业对APS软件需求大幅增长

近年来&#xff0c;中国模具工业以每年15%左右的增速速度快速发展。然而&#xff0c;对于大型、精密、复杂及长寿命模具的需求增长将远超过每年15%的增幅。为应对这一挑战&#xff0c;模具制造业对APS软件的需求大幅度增长&#xff0c;助力行业提速发展。 据统计&#xff0c;中…

linuxARM裸机学习笔记(3)----主频和时钟配置实验

引言&#xff1a;本文主要学习当前linux该如何去配置时钟频率&#xff0c;这也是重中之重。 系统时钟来源&#xff1a; 32.768KHz 晶振是 I.MX6U 的 RTC 时钟源&#xff0c; 24MHz 晶振是 I.MX6U 内核 和其它外设的时钟源 1. 7路PLL时钟源【都是从24MHZ的晶振PLL而来…

一个3年Android的找工作记录

作者&#xff1a;Petterp 这是我最近 1个月 的找工作记录&#xff0c;希望这些经历对你会有所帮助。 有时机会就像一阵风&#xff0c;如果没有握住&#xff0c;那下一阵风什么时候吹来&#xff0c;往往是个运气问题。 写在开始 先说背景: 自考本&#xff0c;3年经验&#xff0…

回归预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络多输入单输出回归预测

回归预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络多输入单输出回归预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MATLAB实现SO-CNN-LS…