金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

逻辑回归: 逻辑回归常被用于二分类问题, 比如涨跌预测。你可以将涨跌标记为类别, 然后使用逻辑回归进行训练。

决策树和随机森林: 决策树和随机森林是用于分类问题的强大模型。它们能够处理非线性关系, 并且对于特征的重要性有较好的解释。

实例数据

本实例截取了 “湖北宜化(000422)” 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Scikit-Learn 中的决策树分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 “涨跌幅(Change) >= 0”, 则用 1 表示, 如果 “涨跌幅(Change) < 0”, 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 “短期均线(MA5)” 与 “长期均线(MA10)” 的关系, 如果 “MA5 > MA10”, 则用 1 表示, 如果 “MA5 = MA10”, 则用 0 表示, 如果 “MA5 < MA10”, 则用 -1 表示。

  7. 节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 “圣诞节(Christmas)” 和 “平安夜(Christmas Eve)” 做示例)。

导入 Pandas 相关模块

Pandas 是基于 NumPy 的一种工具, 该工具是为解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型, 提供了高效地操作大型数据集所需的工具。

Pandas 提供了大量能使我们快速便捷地处理数据的函数和方法。你很快就会发现, 它是使 Python 成为强大而高效的数据分析环境的重要因素之一。

import pandas as pd

导入 Scikit-Learn 相关模块

Scikit-Learn (以前称为 scikits.learn, 也称为 sklearn) 是针对 Python 编程语言的免费软件机器学习库。

它具有各种分类, 回归和聚类算法, 包括支持向量机, 随机森林, 梯度提升, K均值 和 DBSCAN, 并且旨在与 Python 数值科学库 NumPy 和 SciPy 联合使用。

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

使用 Pandas 读取 CSV 数据

调用 Pandas 的 .read_csv 方法读取 CSV 数据:

其中 header 参数指定 CSV 文件的表头行, 这里的 header=0 表示表头行在 1 行, 如果 header=None 则表示数据没有列索引, Pandas 则会自动加上索引。

其中 sep 参数指定 CSV 文件的分隔符, 默认情况下都是以 “,” 作为分隔符, 这里的 sep=“,” 表示指定 CSV 文件的分隔符为 “,”。

还有 dtype 参数指定 CSV 某些特定列以特定的数据类型进行读取, 例如 dtype={“Close”:float, “Volume”:int} 表示 “Close” 列以 浮点(float) 类型读取, “Volume” 列以 整数(integer) 类型读取。

PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")

输出 DataFrame 数据框:

print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
print(PDF)

输出:

[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csvDate     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10
0   2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85
1   2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85
2   2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81
3   2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78
4   2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78
..         ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...
94  2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08
95  2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03
96  2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92
97  2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81
98  2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80[99 rows x 12 columns]

转换 Pandas 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

# 转换 Pandas 中 DateFrame 数据类型。
PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
PDF["Open"] =          PDF["Open"].astype("float64")
PDF["High"] =          PDF["High"].astype("float64")
PDF["Low"] =           PDF["Low"].astype("float64")
PDF["Close"] =         PDF["Close"].astype("float64")
PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
PDF["Change"] =        PDF["Change"].astype("float64")
PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
PDF["Volume"] =        PDF["Volume"].astype("int64")
PDF["MA5"] =           PDF["MA5"].astype("float64")
PDF["MA10"] =          PDF["MA10"].astype("float64")# 输出 Pandas 中 DataFrame 字段和数据类型。
print("[Message] Changed Pandas DataFrame Data Type:")
print(PDF.dtypes)

输出:

[Message] Changed Pandas DataFrame Data Type:
Date             datetime64[ns]
Code                     object
Open                    float64
High                    float64
Low                     float64
Close                   float64
Pre_Close               float64
Change                  float64
Turnover_Rate           float64
Volume                    int64
MA5                     float64
MA10                    float64
dtype: object

在 Pandas 的 DataFrame 中计算数据

编写 “判断股票短期均线和长期均线关系” 函数:

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.

在 Pandas 的 DataFrame 中直接计算或调用自定义函数:

# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
# ..................................................
# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
PDF["Festival"] = None
for Idx in PDF.index:if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
# ..................................................
# 计算数据: 判断股票涨跌。
PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
# ..................................................
# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)# 输出计算好的 DataFrame 数据框。
print("[Message] Calculated DataFrame:")
print(PDF)

输出:

[Message] Calculated DataFrame:Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)   Festival  Rise_Fall  MA_Relationship
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3       None          0                1
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2       None          1                1
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1       None          1                1
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0       None          0                1
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4  Christmas          1                1
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...        ...              ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2       None          0                1
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1       None          0                1
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0       None          1                1
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4       None          1                1
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3       None          1                1[99 rows x 16 columns]

在 Pandas 的 DataFrame 中将字符串类型的特征列转换为数值 (One-Hot Encoding)

pd.get_dummies() 是 Pandas 库中用于独热编码 (One-Hot Encoding) 的函数。它的作用是将分类 (离散) 变量的每个不同取值都拓展为一个新的二进制特征 (0 或 1), 从而方便机器学习模型处理。

# 函数签名:
pd.get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, columns=None, sparse=False, drop_first=False, dtype=None)# 参数说明:
# - data: 要进行独热编码的 DataFrame 或 Series。
# - prefix: 生成的独热编码列的前缀。
# - prefix_sep: 生成的独热编码列的前缀和原始列名之间的分隔符。
# - dummy_na: 是否为原始数据中的缺失值生成独热编码列。
# - columns: 要进行独热编码的列的名称, 如果指定, 则只对这些列进行操作。
# - drop_first: 是否删除第一个独热编码列, 以避免共线性问题。

转换 Festival 特征列为数值:

# 将字符串类型的特征列转换为数值 (独热编码)。
PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)# 输出转换后的 DataFrame 数据框。
print("[Message] DataFrame After One-Hot Encoding:")
print(PDF)

输出:

[Message] DataFrame After One-Hot Encoding:Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)  Rise_Fall  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3          0                1                   0                       0
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2          1                1                   0                       0
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1          1                1                   0                       0
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0          0                1                   0                       0
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4          1                1                   1                       0
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...              ...                 ...                     ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2          0                1                   0                       0
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1          0                1                   0                       0
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0          1                1                   0                       0
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4          1                1                   0                       0
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3          1                1                   0                       0[99 rows x 17 columns]

提取 标签(Label)列 和 特征(Feature)列

提取 标签(Label) 列:

# 提取 标签(Label) 列。
Y = PDF["Rise_Fall"]

提取 特征(Feature) 列:

# 提取 特征(Feature) 列。
X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)

划分训练集和测试集(train_test_split) 以及 特征标准化(StandardScaler)

划分训练集和测试集(train_test_split):

# 数据集划分训练集和测试集(train_test_split)。
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)

特征标准化(StandardScaler):

在机器学习中, fit_transform 和 transform 是用于数据预处理的常见方法, 它们的作用略有不同:

fit_transform: 该方法将同时拟合和转换数据。

  • 它会根据输入的数据计算所需的转换参数 (例如均值、标准差等), 然后将数据应用这些参数进行转换。

  • 在训练阶段, 通常使用 fit_transform 来对训练集进行拟合和转换。

  • 拟合过程会根据训练集数据计算并保存所需的转换参数, 然后将训练集数据应用这些参数进行转换。

  • 这样做的目的是确保在后续对测试集或新数据进行转换时使用相同的转换参数。

transform: 该方法仅对数据进行转换, 不进行拟合过程。

  • 它根据之前使用 fit_transform 得到的转换参数, 将这些参数应用于新的数据, 使其按照相同的转换方式进行处理。

  • 在测试阶段或对新数据应用模型时, 通常使用 transform 方法对测试集或新数据进行转换。

简而言之, fit_transform 方法用于拟合转换器并将数据进行转换, 而 transform 方法仅用于将数据按照已经拟合的转换器进行转换。

在代码中的具体应用上, 通常将 fit_transform 用于训练集的拟合和转换, 将 transform 用于测试集或新数据的转换, 以保证数据的一致性和正确的预处理操作。

# 特征标准化(StandardScaler)。
Obj_Scaler = StandardScaler()
X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
X_Test_Scaled = Obj_Scaler.transform(X_Test)

训练 决策树分类器(DecisionTreeClassifier) 模型

创建 决策树分类器(DecisionTreeClassifier):

# 创建 决策树分类器(DecisionTreeClassifier)。
DTC = DecisionTreeClassifier(random_state=42)

训练 决策树分类器(DecisionTreeClassifier) 模型:

# 训练 决策树分类器(DecisionTreeClassifier) 模型。
DTC.fit(X_Train_Scaled, Y_Train)# Value of Return:
# +----------------------------------------+
# |▼        DecisionTreeClassifier         |
# +----------------------------------------+
# | DecisionTreeClassifier(random_state=42)|
# +----------------------------------------+

使用 决策树分类器(DecisionTreeClassifier) 模型预测数据

# 在测试集上进行预测。
Y_Pred = DTC.predict(X_Test_Scaled)# 合并预测结果。
Result = X_Test.copy()
Result["Actually"] = Y_Test
Result["Prediction"] = Y_Predprint("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
print(Result)

输出:

[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:High   Low  Turnover_Rate    Volume  Weekday(Idx)  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve  Actually  Prediction
62  6.32  6.18       0.008991   8072900             1               -1                   0                       0         0           1
40  7.54  7.32       0.040463  36330200             3                1                   0                       0         1           1
95  8.68  8.32       0.048444  42343900             1                1                   0                       0         0           1
18  8.39  7.94       0.064590  57993100             0                1                   0                       0         1           1
97  8.28  8.08       0.025855  22599800             4                1                   0                       0         1           1
84  6.77  6.09       0.029963  26190200             2               -1                   0                       0         1           0
64  6.56  6.25       0.012584  11298800             4                1                   0                       0         0           1
42  7.13  7.02       0.014938  13412400             1                1                   0                       0         0           0
10  7.75  7.57       0.028054  25188400             3               -1                   0                       0         1           0
0   7.95  7.76       0.015498  13915200             3                1                   0                       0         0           1
31  7.54  7.38       0.014173  12725000             2               -1                   0                       0         0           0
76  7.09  6.86       0.028974  25325600             2                1                   0                       0         1           1
47  7.48  7.08       0.057658  51769300             1                1                   0                       0         1           1
26  7.64  7.51       0.020919  18782900             2                1                   0                       0         1           1
44  7.38  7.10       0.022821  20490200             4                1                   0                       0         1           0
4   8.05  7.93       0.021132  18974000             4                1                   1                       0         1           1
22  7.39  7.23       0.012308  11050700             1               -1                   0                       0         1           1
12  7.66  7.52       0.025902  23256600             1               -1                   0                       0         0           1
88  8.56  8.14       0.031764  27764200             3                1                   0                       0         0           1
73  6.95  6.18       0.021233  18559600             0                1                   0                       0         0           1

使用 accuracy_score 评估模型性能

# 评估模型性能。
Accuracy = accuracy_score(Y_Test, Y_Pred)
print("Accuracy:", Accuracy)
print("\n")# 输出分类报告。
print("Classification Report:")
print(classification_report(Y_Test, Y_Pred))

输出:

Accuracy: 0.5Classification Report:precision    recall  f1-score   support0       0.40      0.22      0.29         91       0.53      0.73      0.62        11accuracy                           0.50        20macro avg       0.47      0.47      0.45        20
weighted avg       0.47      0.50      0.47        20

完整代码

#!/usr/bin/python3
# Create By GF 2024-01-04# 在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。
# 为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。
# 然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。
# 最后, 我们训练模型、进行预测, 并评估模型性能。
# 请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。import datetime
# --------------------------------------------------
import pandas as pd
# --------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.if __name__ == "__main__":PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")print(PDF)# 转换 Pandas 中 DateFrame 数据类型。PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")PDF["Open"] =          PDF["Open"].astype("float64")PDF["High"] =          PDF["High"].astype("float64")PDF["Low"] =           PDF["Low"].astype("float64")PDF["Close"] =         PDF["Close"].astype("float64")PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")PDF["Change"] =        PDF["Change"].astype("float64")PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")PDF["Volume"] =        PDF["Volume"].astype("int64")PDF["MA5"] =           PDF["MA5"].astype("float64")PDF["MA10"] =          PDF["MA10"].astype("float64")# 输出 Pandas 中 DataFrame 字段和数据类型。print("[Message] Changed Pandas DataFrame Data Type:")print(PDF.dtypes)# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())# ..................................................# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。PDF["Festival"] = Nonefor Idx in PDF.index:if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。# ..................................................# 计算数据: 判断股票涨跌。PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))# ..................................................# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)# 输出计算好的 DataFrame 数据框。print("[Message] Calculated DataFrame:")print(PDF)# 将字符串类型的特征列转换为数值 (独热编码)。PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)# 输出转换后的 DataFrame 数据框。print("[Message] DataFrame After One-Hot Encoding:")print(PDF)# 提取 标签(Label) 列。Y = PDF["Rise_Fall"]# 提取 特征(Feature) 列。X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)# 数据集划分训练集和测试集(train_test_split)。X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)# 特征标准化(StandardScaler)。Obj_Scaler = StandardScaler()X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)X_Test_Scaled = Obj_Scaler.transform(X_Test)# 创建 决策树分类器(DecisionTreeClassifier)。DTC = DecisionTreeClassifier(random_state=42)# 训练 决策树分类器(DecisionTreeClassifier) 模型。DTC.fit(X_Train_Scaled, Y_Train)# Value of Return:# +----------------------------------------+# |▼        DecisionTreeClassifier         |# +----------------------------------------+# | DecisionTreeClassifier(random_state=42)|# +----------------------------------------+# 在测试集上进行预测。Y_Pred = DTC.predict(X_Test_Scaled)# 合并预测结果。Result = X_Test.copy()Result["Actually"] = Y_TestResult["Prediction"] = Y_Predprint("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")print(Result)# 评估模型性能。Accuracy = accuracy_score(Y_Test, Y_Pred)print("Accuracy:", Accuracy)print("\n")# 输出分类报告。print("Classification Report:")print(classification_report(Y_Test, Y_Pred))

其它

在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。

为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。

然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。

最后, 我们训练模型、进行预测, 并评估模型性能。

请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

总结

以上就是关于 金融数据 Scikit-Learn决策树(DecisionTreeClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/796854.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jetpack Bluetooth——更优雅地使用蓝牙

Jetpack Bluetooth——更优雅地使用蓝牙 蓝牙是安卓开发中非常常用的操作&#xff0c;但安卓经过这么多年的迭代&#xff0c;蓝牙的相关接口都经过了很多修改需要适配&#xff0c;还有的接口需要实现一堆函数。。。整套操作虽说不算复杂&#xff0c;但难免感觉不太舒服。 之前…

专题【双指针】【学习题】刷题日记

题目列表 11. 盛最多水的容器 42. 接雨水 15. 三数之和 16. 最接近的三数之和 18. 四数之和 26. 删除有序数组中的重复项 27. 移除元素 75. 颜色分类 167. 两数之和 II - 输入有序数组 2024.04.06 11. 盛最多水的容器 题目 给定一个长度为 n 的整数数组 height 。有 n 条垂…

阿里云服务器 篇二:搭建静态网站

文章目录 系列文章获取静态网站模板应用静态网站模板解压zip文件SCP命令上传文件其他上传文件的方法 系列文章 阿里云服务器 篇一&#xff1a;申请和初始化 阿里云服务器 篇二&#xff1a;搭建静态网站 获取静态网站模板 站长素材&#xff1a;网站中包括大量的免费模板&…

上传视频的核心代码

/*** 上传学习视频信息*/Log(title "上传学习视频信息", businessType BusinessType.INSERT)PostMapping("/uploadVideo")public AjaxResult add(HttpServletRequest request) {return toAjax(videoInfoService.insertVideoInfo(request));}/*** 上传学习…

PHP实现网站微信扫码关注公众号后自动注册登陆实现方法及代码【关注收藏】

在网站注册登陆这环节&#xff0c;增加微信扫码注册登陆&#xff0c;普通的方法需要开通微信开发者平台&#xff0c;生成二维码扫码后才能获取用户的uinonid或openid&#xff0c;实现注册登陆&#xff0c;但这样比较麻烦还要企业认证交费开发者平台&#xff0c;而且没有和公众号…

如何控制Docker容器退出后的自动重启行为?

在Docker中&#xff0c;可以通过以下两种方式来控制容器退出后的自动重启行为&#xff1a; 使用docker run命令时&#xff0c;通过设置--restart参数来指定容器退出后的重启策略。可以使用以下值之一&#xff1a; no: 默认值&#xff0c;容器退出后不会自动重启。always: 容器…

什么是EL表达式?怎么使用?

文章目录 一、什么是EL表达式1、命令格式&#xff1a;${作用域对象别名.共享数据} 二、EL表达式与作用域对象别名1、JSP文件可以使用的作用域对象2、EL表达式提供作用域对象别名3、EL表达式将引用对象属性写入到响应体4、EL表达式简化版 三、EL表达式与运算表达式四、EL表达式提…

【SQL】1890. 2020年最后一次登录(简单写法;窗口函数写法)

前述 sql 中 between 的边界问题 ---- between 边界&#xff1a;闭区间&#xff0c;not between 边界&#xff1a;开区间 在 sql 中&#xff0c; between 边界&#xff1a;闭区间not between 边界&#xff1a;开区间 题目描述 leetcode题目&#xff1a;1890. 2020年最后一…

【leetcode面试经典150题】16.接雨水(C++)

【leetcode面试经典150题】专栏系列将为准备暑期实习生以及秋招的同学们提高在面试时的经典面试算法题的思路和想法。本专栏将以一题多解和精简算法思路为主&#xff0c;题解使用C语言。&#xff08;若有使用其他语言的同学也可了解题解思路&#xff0c;本质上语法内容一致&…

aardio教程五) 写Python风格的aardio代码(字符串篇)

前言 熟悉一个新的语言最麻烦的就是需要了解一些库的使用&#xff0c;特别是基础库的使用。 所以我想给aardio封装一个Python风格的库&#xff0c;Python里的基础库是什么方法名&#xff0c;aardio里也封装同样的方法名。 这样就不需要单独去了解aardio里一些方法的使用细节…

Lanelets_ 高效的自动驾驶地图表达方式

Lanelets: 高效的自动驾驶地图表达方式 附赠自动驾驶学习资料和量产经验&#xff1a;链接 LaneLets是自动驾驶领域高精度地图的一种高效表达方式&#xff0c;它以彼此相互连接的LaneLets来描述自动驾驶可行驶区域&#xff0c;不仅可以表达车道几何&#xff0c;也可以完整表述车…

.NET9 PreView2+.AOT ILC 的重大变化

RyuJIT 增强功能 1. 环路优化 (循环优化) 这种优化实际上是一种 for 循环叠加态的优化&#xff0c;for 循环叠加计算的过程中&#xff0c;会对其中部分变量进行感应。比如循环中放置 0 扩展 (第一个索引为 0)&#xff0c;这种优化灵感来源于 LLVM 标量演化。下面看例子&#…

每天一个数据分析题(二百五十四)

在大数据时代背景下&#xff0c;我们使用的数据主要包含两种类别&#xff0c;一种称为结构化数据&#xff0c;另一种称为非结构化数据。请问以下哪个选项属于非结构化数据&#xff1f; A. 利润表 B. 短视频 C. 产品库存表 D. 产品进货表 题目来源于CDA模拟题库 点击此处获…

LeetCode 每日一题 2024/4/1-2024/4/7

记录了初步解题思路 以及本地实现代码&#xff1b;并不一定为最优 也希望大家能一起探讨 一起进步 目录 4/1 2810. 故障键盘4/2 894. 所有可能的真二叉树4/3 1379. 找出克隆二叉树中的相同节点4/4 2192. 有向无环图中一个节点的所有祖先4/5 1026. 节点与其祖先之间的最大差值4/…

8种专坑运维的 SQL 写法,性能降低100倍,您不来看看?

1、LIMIT 语句 分页查询是最常用的场景之一&#xff0c;但也通常也是最容易出问题的地方。比如对于下面简单的语句&#xff0c;一般 DBA 想到的办法是在 type&#xff0c;name&#xff0c; create_time 字段上加组合索引。这样条件排序都能有效的利用到索引&#xff0c;性能迅…

AIGC实战——ProGAN(Progressive Growing Generative Adversarial Network)

AIGC实战——ProGAN 0. 前言1. ProGAN2. 渐进式训练3. 其他技术3.1 小批标准差3.2 均等学习率3.3 逐像素归一化 4. 图像生成小结系列链接 0. 前言 我们已经学习了使用生成对抗网络 (Generative Adversarial Network, GAN) 解决各种图像生成任务。GAN 的模型架构和训练过程具有…

真实的招生办对话邮件及美国高校官网更新的反 AI 政策

这两年 ChatGPT 的热度水涨船高&#xff0c;其编写功能强大&#xff0c;且具备强大的信息整合效果&#xff0c;所以呈现的内容在一定程度上具备可读性。 那么&#xff0c;美国留学文书可以用 ChatGPT 写吗&#xff1f;使用是否有风险&#xff1f;外网博主 Kushi Uppu 在这个申…

C++20 semaphore(信号量) 详解

头文件在C20中是并发库技术规范&#xff08;Technical Specification, TS&#xff09;的一部分。信号量是同步原语&#xff0c;帮助控制多线程程序中对共享资源的访问。头文件提供了标准C方式来使用信号量。 使用环境 Windows&#xff1a;VS中打开项目属性&#xff0c;修改C语…

基于卷积神经网络的天气识别系统(pytorch框架)【python源码+UI界面+前端界面+功能源码详解】

功能演示&#xff1a; 天气识别系统&#xff0c;vgg16&#xff0c;mobilenet卷积神经网络&#xff08;pytorch框架&#xff09;_哔哩哔哩_bilibili &#xff08;一&#xff09;简介 基于卷积神经网络的天气识别系统是在pytorch框架下实现的&#xff0c;系统中有两个模型可选…

vue+elementUI实现表格组件的封装

效果图&#xff1a; 在父组件使用表格组件 <table-listref"table":stripe"true":loading"loading":set-table-h"slotProps.setMainCardBodyH":table-data"tableData":columns"columns.tableList || []":ra…