机器学习在股票价格预测中有重要的应用。在这个机器学习项目中,我们将讨论如何预测股票的收益。这是一个非常复杂的任务,充满了不确定性。我们将会把这个项目分成两部分进行开发:
首先,我们将学习如何使用 LSTM 神经网络预测股票价格。
然后,我们将使用 Plotly Dash 构建一个用于股票分析的仪表板。
股票价格预测项目仪表板
股票价格预测项目
数据集
为了构建股票价格预测模型,我们将使用“印度国家证券交易所(NSE)TATA GLOBAL”数据集。这是来自印度国家证券交易所的 Tata 全球饮料有限公司的 Tata 饮料数据集:
为了构建股票分析的仪表板,我们将使用另一个包含多个股票(如苹果、微软、脸书)的数据集:
源代码
下载地址:链接: 源代码 及 Tata 饮料数据集 多个股票(如苹果、微软、脸书)的数据集
使用 LSTM 预测股票价格
- 导入:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.pylab import rcParams
rcParams['figure.figsize']=20,10
from keras.models import Sequential
from keras.layers import LSTM,Dropout,Dense
from sklearn.preprocessing import MinMaxScaler
- 读取数据集:
df=pd.read_csv("NSE-TATA.csv")
df.head()
读取股票数据
- 从数据框中分析收盘价:
df["Date"]=pd.to_datetime(df.Date,format="%Y-%m-%d")
df.index=df['Date']
plt.figure(figsize=(16,8))
plt.plot(df["Close"],label='Close Price history')
分析股票价格
- 按日期时间排序并筛选“Date”和“Close”列:
data=df.sort_index(ascending=True,axis=0)
new_dataset=pd.DataFrame(index=range(0,len(df)),columns=['Date','Close'])
for i in range(0,len(data)):new_dataset["Date"][i]=data['Date'][i]new_dataset["Close"][i]=data["Close"][i]
- 对新的筛选数据集进行归一化:
scaler=MinMaxScaler(feature_range=(0,1))
final_dataset=new_dataset.values
train_data=final_dataset[0:987,:]
valid_data=final_dataset[987:,:]
new_dataset.index=new_dataset.Date
new_dataset.drop("Date",axis=1,inplace=True)
scaler=MinMaxScaler(feature_range=(0,1))
scaled_data=scaler.fit_transform(final_dataset)
x_train_data,y_train_data=[],[]
for i in range(60,len(train_data)):x_train_data.append(scaled_data[i-60:i,0])y_train_data.append(scaled_data[i,0])x_train_data,y_train_data=np.array(x_train_data),np.array(y_train_data)
x_train_data=np.reshape(x_train_data,(x_train_data.shape[0],x_train_data.shape[1],1))
- 构建和训练 LSTM 模型:
lstm_model=Sequential()
lstm_model.add(LSTM(units=50,return_sequences=True,input_shape=(x_train_data.shape[1],1)))
lstm_model.add(LSTM(units=50))
lstm_model.add(Dense(1))
inputs_data=new_dataset[len(new_dataset)-len(valid_data)-60:].values
inputs_data=inputs_data.reshape(-1,1)
inputs_data=scaler.transform(inputs_data)
lstm_model.compile(loss='mean_squared_error',optimizer='adam')
lstm_model.fit(x_train_data,y_train_data,epochs=1,batch_size=1,verbose=2)
- 从数据集中抽取样本,利用 LSTM 模型进行股票价格预测:
X_test=[]
for i in range(60,inputs_data.shape[0]):X_test.append(inputs_data[i-60:i,0])
X_test=np.array(X_test)
X_test=np.reshape(X_test,(X_test.shape[0],X_test.shape[1],1))
predicted_closing_price=lstm_model.predict(X_test)
predicted_closing_price=scaler.inverse_transform(predicted_closing_price)
- 保存 LSTM 模型:
lstm_model.save("saved_model.h5")
- 真实股票成本与预测股票成本对比可视化:
train_data=new_dataset[:987]
valid_data=new_dataset[987:]
valid_data['Predictions']=predicted_closing_price
plt.plot(train_data["Close"])
plt.plot(valid_data[['Close',"Predictions"]])
可以看到,LSTM 模型预测的股票价格与实际股票价格相当接近。
使用 Plotly Dash 构建仪表板
在本节中,我们将构建一个仪表板用于分析股票。Dash 是一个 Python 框架,它在 Flask 和 React.js 之上提供了一层抽象,用于构建分析型 Web 应用程序。
在继续之前,你需要安装 Dash。在终端运行以下命令。
pip3 install dash
pip3 install dash-html-components
pip3 install dash-core-components
现在创建一个新的 Python 文件 stock_app.py
并粘贴以下脚本:
import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
import plotly.graph_objs as go
from dash.dependencies import Input, Output
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
import numpy as npapp = dash.Dash()
server = app.server
scaler = MinMaxScaler(feature_range=(0,1))
df_nse = pd.read_csv("./NSE-TATA.csv")
df_nse["Date"] = pd.to_datetime(df_nse.Date, format="%Y-%m-%d")
df_nse.index = df_nse['Date']
data = df_nse.sort_index(ascending=True, axis=0)
new_data = pd.DataFrame(index=range(0, len(df_nse)), columns=['Date', 'Close'])
for i in range(0, len(data)):new_data["Date"][i] = data['Date'][i]new_data["Close"][i] = data["Close"][i]
new_data.index = new_data.Date
new_data.drop("Date", axis=1, inplace=True)
dataset = new_data.values
train = dataset[0:987, :]
valid = dataset[987:, :]
scaler = MinMaxScaler(feature_range=(0,1))
scaled_data = scaler.fit_transform(dataset)
x_train, y_train = [], []
for i in range(60, len(train)):x_train.append(scaled_data[i-60:i, 0])y_train.append(scaled_data[i, 0])x_train, y_train = np.array(x_train), np.array(y_train)
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
model = load_model("saved_model.h5")
inputs = new_data[len(new_data)-len(valid)-60:].values
inputs = inputs.reshape(-1, 1)
inputs = scaler.transform(inputs)
X_test = []
for i in range(60, inputs.shape[0]):X_test.append(inputs[i-60:i, 0])
X_test = np.array(X_test)
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
closing_price = model.predict(X_test)
closing_price = scaler.inverse_transform(closing_price)
train = new_data[:987]
valid = new_data[987:]
valid['Predictions'] = closing_price
df = pd.read_csv("./stock_data.csv")app.layout = html.Div([html.H1("股票价格分析仪表板", style={"textAlign": "center"}),dcc.Tabs(id="tabs", children=[dcc.Tab(label='NSE-TATAGLOBAL 股票数据', children=[html.Div([html.H2("实际收盘价", style={"textAlign": "center"}),dcc.Graph(id="Actual Data",figure={"data": [go.Scatter(x=train.index,y=valid["Close"],mode='markers')],"layout": go.Layout(title='散点图',xaxis={'title': '日期'},yaxis={'title': '收盘价'})}),html.H2("LSTM 预测收盘价", style={"textAlign": "center"}),dcc.Graph(id="Predicted Data",figure={"data": [go.Scatter(x=valid.index,y=valid["Predictions"],mode='markers')],"layout": go.Layout(title='散点图',xaxis={'title': '日期'},yaxis={'title': '收盘价'})}) ]) ]),dcc.Tab(label='脸书股票数据', children=[html.Div([html.H1("脸书最高价与最低价对比", style={'textAlign': 'center'}),dcc.Dropdown(id='my-dropdown',options=[{'label': '特斯拉', 'value': 'TSLA'},{'label': '苹果', 'value': 'AAPL'}, {'label': '脸书', 'value': 'FB'}, {'label': '微软', 'value': 'MSFT'}], multi=True, value=['FB'],style={"display": "block", "margin-left": "auto", "margin-right": "auto", "width": "60%"}),dcc.Graph(id='highlow'),html.H1("脸书市场交易量", style={'textAlign': 'center'}),dcc.Dropdown(id='my-dropdown2',options=[{'label': '特斯拉', 'value': 'TSLA'},{'label': '苹果', 'value': 'AAPL'}, {'label': '脸书', 'value': 'FB'},{'label': '微软', 'value': 'MSFT'}], multi=True, value=['FB'],style={"display": "block", "margin-left": "auto", "margin-right": "auto", "width": "60%"}),dcc.Graph(id='volume')], className="container"),])])
])@app.callback(Output('highlow', 'figure'),[Input('my-dropdown', 'value')])
def update_graph(selected_dropdown):dropdown = {"TSLA": "特斯拉", "AAPL": "苹果", "FB": "脸书", "MSFT": "微软"}trace1 = []trace2 = []for stock in selected_dropdown:trace1.append(go.Scatter(x=df[df["Stock"] == stock]["Date"],y=df[df["Stock"] == stock]["High"],mode='lines', opacity=0.7, name=f'高 {dropdown[stock]}', textposition='bottom center'))trace2.append(go.Scatter(x=df[df["Stock"] == stock]["Date"],y=df[df["Stock"] == stock]["Low"],mode='lines', opacity=0.6,name=f'低 {dropdown[stock]}', textposition='bottom center'))traces = [trace1, trace2]data = [val for sublist in traces for val in sublist]figure = {'data': data,'layout': go.Layout(colorway=["#5E0DAC", '#FF4F00', '#375CB1', '#FF7400', '#FFF400', '#FF0056'],height=600,title=f"随时间变化的高低价:{', '.join(str(dropdown[i]) for i in selected_dropdown)}",xaxis={"title": "日期",'rangeselector': {'buttons': list([{'count': 1, 'label': '1M', 'step': 'month', 'stepmode': 'backward'},{'count': 6, 'label': '6M', 'step': 'month', 'stepmode': 'backward'},{'step': 'all'}])},'rangeslider': {'visible': True}, 'type': 'date'},yaxis={"title": "价格(美元)"}}}return figure@app.callback(Output('volume', 'figure'),[Input('my-dropdown2', 'value')])
def update_graph(selected_dropdown_value):dropdown = {"TSLA": "特斯拉", "AAPL": "苹果", "FB": "脸书", "MSFT": "微软"}trace1 = []for stock in selected_dropdown_value:trace1.append(go.Scatter(x=df[df["Stock"] == stock]["Date"],y=df[df["Stock"] == stock]["Volume"],mode='lines', opacity=0.7,name=f'交易量 {dropdown[stock]}', textposition='bottom center'))traces = [trace1]data = [val for sublist in traces for val in sublist]figure = {'data': data, 'layout': go.Layout(colorway=["#5E0DAC", '#FF4F00', '#375CB1', '#FF7400', '#FFF400', '#FF0056'],height=600,title=f"随时间变化的市场交易量:{', '.join(str(dropdown[i]) for i in selected_dropdown_value)}",xaxis={"title": "日期",'rangeselector': {'buttons': list([{'count': 1, 'label': '1M', 'step': 'month', 'stepmode': 'backward'},{'count': 6, 'label': '6M','step': 'month', 'stepmode': 'backward'},{'step': 'all'}])},'rangeslider': {'visible': True}, 'type': 'date'},yaxis={"title": "交易量"}}}return figureif __name__ == '__main__':app.run_server(debug=True)
现在运行此文件并打开浏览器中的应用:
python3 stock_app.py
股票价格预测项目仪表板
摘要
股票价格预测是一个适合机器学习初学者的项目;在本教程中,我们学习了如何开发股票价格预测模型以及如何构建用于股票分析的交互式仪表板。我们实现了基于 LSTM 模型的股市预测。另一方面,我们使用了 Python 的 Plotly Dash 框架来构建仪表板。
参考文献及资料链接
参考资料 | 链接 |
---|---|
股票价格预测基础 | https://example.com/ml-basics |
LSTM 神经网络教程 | https://example.com/lstm-tutorial |
TensorFlow 官方文档 | https://tensorflow.org/docs |
Keras 官方文档 | https://keras.io/zh/ |
Scikit-learn 文档 | https://scikit-learn.org/stable/ |
NSE TATA GLOBAL 数据集 | https://example.com/tata-global-dataset |
股票数据集 | https://example.com/stocks-dataset |
运行 Flask 扩展 | https://flask.palletsprojects.com/en/2.3.x/extensions/ |
Plotly 官方网站 | https://plotly.com/python/ |
Plotly 冲浪式图表 (Dash) 官方文档 | https://dash.plotly.com/ |
Pandas 官方文档 | https://pandas.pydata.org/pandas-docs/stable/ |
Numpy 官方文档 | https://numpy.org/doc/stable/ |
LSTM 股票预测实践 | https://medium.com/@example_lstm_pred |
Dash 股票分析仪表板案例 | https://blog.plotly.com/dash-stock-examples/ |
源代码与数据集介绍
股票价格预测项目
在这个机器学习项目中,我们将开发一个基于神经网络的股票预测模型,用于预测股票收益。
学习如何开发股票价格预测模型,并构建一个用于股票分析的交互式仪表板。我们使用 LSTM 模型实现股票市场预测,并使用 Plotly Dash Python 框架构建仪表板。
类别:机器学习、深度学习
编程语言:Python
工具与库:Plotly Dash、LSTM
IDE:Jupyter
前端:Plotly Dash(用于可视化)
后端:无
先决条件:Python、机器学习、深度学习、神经网络
目标受众:教育、开发人员、数据工程师、数据科学家
股票价格数据
该数据集包含关于塔塔全球饮料有限公司(Tata Global Beverages Limited)的股票价格记录。数据集中还包含按日期排列的股票价格,包括开盘价、收盘价、最高价和最低价,以及当天的交易量和成交额。
对于想要尝试数据可视化、数据分析以及多种形式的数据处理技术的人来说,这是一个极好的数据库。
示例数据:
NSE 塔塔全球饮料有限公司
数据格式:
- Date:日期
- Open:开盘价
- High:最高价
- Low:最低价
- Last:最新价
- Close:收盘价
- Total Trade Quantity:总交易量
- Turnover (Lacs):成交额(单位:十万卢比)
股票价格数据
该历史数据集包含关于苹果(Apple)、微软(Microsoft)、脸书(Facebook)等多家公司股票价格的记录。数据集中还包含按日期排列的股票价格,包括开盘价、收盘价、最高价和最低价,以及当天的交易量。
对于想要尝试数据可视化、数据分析以及多种形式的数据处理技术的人来说,这是一个极好的数据库。
示例数据:
股票数据集
数据格式:
- Date:日期
- Open:开盘价
- High:最高价
- Low:最低价
- Close:收盘价
- Volume:交易量
- OpenInt:未平仓合约(适用于期货和期权)
- Stock:股票名称或代码