Financial Data: PySpark-3.0.3 Decision Tree (DecisionTreeClassifier) Example

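Logistic regression: logistic regression is commonly used for binary classification problems such as predicting whether a price rises or falls. You can encode rise and fall as classes and then train a logistic regression model.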

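Decision trees and random forests: decision trees and random forests are powerful models for classification problems. They can handle non-linear relationships and give fairly interpretable feature importances.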

In PySpark-3.x.x, a decision tree classifier is built mainly with DecisionTreeClassifier from the pyspark.ml module.

The following simple example shows how to build and train a decision tree model with PySpark-3.x.x.

Sample Data

This example uses the data of "Hubei Yihua (湖北宜化, 000422)" from 2015-08-06 to 2015-12-31.

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

Approach

This is only a simple example intended to get familiar with Spark's decision tree classifier; it is not investment advice of any kind.

Goal:

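Using the values of the current trading day, predict whether the stock closes up or down that day: if the change (Change) is >= 0, the day is labelled 1; if Change < 0, it is labelled 0 (a binary label).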

Features:

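  1. Daily high price

  2. Daily low price

  3. Daily turnover rate

  4. Daily trading volume

  5. Day of the week (to capture any weekday effect on the price)

  6. The relationship between the short-term moving average (MA5) and the long-term moving average (MA10) on that day: 1 if "MA5 > MA10", 0 if "MA5 = MA10", and -1 if "MA5 < MA10".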
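The derivation of the label and the six features from one CSV record can be sketched in plain Python, without Spark. The helper names `ma_relationship` and `derive_features` are hypothetical and not part of the original code; the three-way MA5/MA10 comparison uses the `(a > b) - (a < b)` idiom, which evaluates to 1, 0 or -1. The 2015-11-05 record is included because its MA5 and MA10 are both 7.24, which is exactly the case that should map to 0.

```python
import datetime

def ma_relationship(ma5: float, ma10: float) -> int:
    # (a > b) - (a < b) evaluates to 1, 0 or -1, matching the MA5/MA10 rule above.
    return (ma5 > ma10) - (ma5 < ma10)

def derive_features(record: dict) -> tuple:
    """Return (label, features) for one CSV record whose values are strings."""
    change = float(record["Change"])
    date = datetime.date.fromisoformat(record["Date"])
    label = 1 if change >= 0.0 else 0  # a change of exactly 0 counts as a rise
    features = [
        float(record["High"]),
        float(record["Low"]),
        float(record["Turnover_Rate"]),
        int(record["Volume"]),
        date.weekday(),  # 0 = Monday ... 6 = Sunday
        ma_relationship(float(record["MA5"]), float(record["MA10"])),
    ]
    return label, features

# Two records copied from the CSV above (only the columns needed here).
row_1231 = {"Date": "2015-12-31", "High": "7.95", "Low": "7.76", "Change": "-0.020177",
            "Turnover_Rate": "0.015498", "Volume": "13915200", "MA5": "7.86", "MA10": "7.85"}
row_1105 = {"Date": "2015-11-05", "High": "7.54", "Low": "7.32", "Change": "0.010855",
            "Turnover_Rate": "0.040463", "Volume": "36330200", "MA5": "7.24", "MA10": "7.24"}

print(derive_features(row_1231))
print(derive_features(row_1105))
```

In the Spark version below, the same logic is spread over separate row-level functions that are applied with RDD map calls.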

Import the pyspark.sql modules

Spark SQL is the Spark module for structured data processing. It provides a programming abstraction called DataFrame, which evolved from SchemaRDD.

Unlike SchemaRDD, which inherited directly from RDD, DataFrame implements most RDD functionality itself.

import datetime
import pprint
import pyspark
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType

Import the pyspark.ml modules

On top of its core data abstraction, the RDD, Spark provides four major components, and machine learning is one of them.

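More precisely, Spark actually ships two machine learning packages, MLlib and ML. The former is built on the RDD data structure and is now in maintenance mode; the latter is built on DataFrames, supports more algorithms, and is the focus of future development.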

Therefore, prefer the ML package in practice.

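As for how Spark's ML library relates to Sklearn (scikit-learn), the other major machine learning library in Python: Spark ML supports most machine learning algorithms and interfaces. Although it is far less comprehensive than Sklearn, it is designed for distributed training on big data.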

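Sklearn, by contrast, is a single-machine library that supports almost all mainstream machine learning algorithms. It covers sample datasets, feature selection, model selection and validation, and both base and ensemble learners, making it a one-stop machine learning solution; however, it supports only parallel execution, not distributed computing.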

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

Create a SparkSession object

In Spark 2.0 and later, spark-shell (and likewise the pyspark shell) automatically creates a SparkSession object named spark at startup.

When you need to create one manually, build it with the builder of its companion object (in PySpark, the SparkSession.builder attribute).

spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()

Build the DataFrame manually with Spark (Optional)

When the data set is small, you can build the DataFrame by hand this way.

Build the Row objects (using the first 3 rows as an example):

Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85")
Row(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85")
Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")

Put the Row objects into a list (using the first 3 rows as an example):

Data_Rows = [
    Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85"),
    Row(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85"),
    Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81"),
]

Create the DataFrame (using the first 3 rows as an example):

SDF = spark.createDataFrame(Data_Rows)

Display the DataFrame (using the first 3 rows as an example):

print("[Message] Built Spark DataFrame: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()

Output:

+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|    Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498| 1.39152E7|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84|  0.01148|     0.018662| 1.67559E7|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886| 1.42638E7|7.90|7.81|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+

Read the CSV data with Spark

Use the .read property of SparkSession (which returns a DataFrameReader) to read the CSV data:

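Here .option sets a reader option: the first argument is the key and the second is the value, so .option("header", "true") corresponds to the setting header = "true" (treat the first line as column names).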

SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:/HBYH_000422_20150806_20151231.csv")

Display the DataFrame:

print("[Message] Read CSV File: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()

Output:

[Message] Read CSV File: D:\HBYH_000422_20150806_20151231.csv
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|  Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498|13915200|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84| 0.011480|     0.018662|16755900|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886|14263800|7.90|7.81|
|2015-12-28|'000422|8.03|8.08|7.70| 7.71|     8.03|-0.039851|     0.030821|27672800|7.91|7.78|
|2015-12-25|'000422|8.03|8.05|7.93| 8.03|     7.99| 0.005006|     0.021132|18974000|7.93|7.78|
|2015-12-24|'000422|7.93|8.16|7.87| 7.99|     7.92| 0.008838|     0.026487|23781900|7.85|7.72|
|2015-12-23|'000422|7.97|8.11|7.88| 7.92|     7.89| 0.003802|     0.042360|38033600|7.80|7.69|
|2015-12-22|'000422|7.86|7.93|7.76| 7.89|     7.83| 0.007663|     0.026929|24178700|7.73|7.68|
|2015-12-21|'000422|7.59|7.89|7.56| 7.83|     7.63| 0.026212|     0.030777|27633600|7.66|7.67|
|2015-12-18|'000422|7.71|7.74|7.57| 7.63|     7.74|-0.014212|     0.024764|22234900|7.62|7.71|
|2015-12-17|'000422|7.58|7.75|7.57| 7.74|     7.55| 0.025166|     0.028054|25188400|7.59|7.77|
|2015-12-16|'000422|7.57|7.62|7.53| 7.55|     7.55| 0.000000|     0.020718|18601600|7.58|7.79|
|2015-12-15|'000422|7.63|7.66|7.52| 7.55|     7.62|-0.009186|     0.025902|23256600|7.64|7.78|
|2015-12-14|'000422|7.40|7.64|7.36| 7.62|     7.51| 0.014647|     0.021005|18860100|7.68|7.76|
|2015-12-11|'000422|7.65|7.70|7.41| 7.51|     7.67|-0.020860|     0.020477|18385900|7.80|7.73|
|2015-12-10|'000422|7.78|7.87|7.65| 7.67|     7.83|-0.020434|     0.019972|17931900|7.95|7.69|
|2015-12-09|'000422|7.76|8.00|7.75| 7.83|     7.77| 0.007722|     0.025137|22569700|8.00|7.68|
|2015-12-08|'000422|8.08|8.18|7.76| 7.77|     8.24|-0.057039|     0.036696|32948200|7.92|7.66|
|2015-12-07|'000422|8.12|8.39|7.94| 8.24|     8.23| 0.001215|     0.064590|57993100|7.84|7.64|
|2015-12-04|'000422|7.85|8.48|7.80| 8.23|     7.92| 0.039141|     0.100106|89881900|7.65|7.58|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
only showing top 20 rows

Convert the data types of the Spark DataFrame columns

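Because the CSV reader loads every column as a string unless a schema is supplied or inferred, the columns are usually cast to proper types first to avoid type errors in later computations.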

# Convert the column data types of the Spark DataFrame.
SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))
SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))
SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))
SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))
SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))
SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))
SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))
SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))
SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))
SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))
# Print the fields and data types of the Spark DataFrame.
print("[Message] Changed Spark DataFrame Data Type:")
SDF.printSchema()

Output:

[Message] Changed Spark DataFrame Data Type:
root
 |-- Date: date (nullable = true)
 |-- Code: string (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Close: double (nullable = true)
 |-- Pre_Close: double (nullable = true)
 |-- Change: double (nullable = true)
 |-- Turnover_Rate: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- MA5: double (nullable = true)
 |-- MA10: double (nullable = true)
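The same string-to-type conversion can be sketched in plain Python for a couple of CSV lines. The `SCHEMA` mapping and the `convert` helper are illustrative assumptions, not part of Spark. The sketch also checks that the largest volume in the sample data (89881900, on 2015-12-04) lies within the 32-bit signed range of Spark's IntegerType, so IntegerType is wide enough for the Volume column here.

```python
import csv
import datetime
import io

# Hypothetical column -> converter mapping mirroring the casts applied in Spark.
SCHEMA = {
    "Date": datetime.date.fromisoformat,
    "Open": float, "High": float, "Low": float, "Close": float,
    "Pre_Close": float, "Change": float, "Turnover_Rate": float,
    "Volume": int, "MA5": float, "MA10": float,
}

CSV_TEXT = """Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
"""

def convert(record: dict) -> dict:
    # Columns without a converter (here only Code) stay strings, like StringType in Spark.
    return {k: SCHEMA.get(k, str)(v) for k, v in record.items()}

rows = [convert(r) for r in csv.DictReader(io.StringIO(CSV_TEXT))]
INT32_MAX = 2**31 - 1  # upper bound of Spark's IntegerType (32-bit signed)
print(rows[0]["Date"], type(rows[0]["Close"]).__name__, rows[1]["Volume"] <= INT32_MAX)
```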

Convert between the Spark DataFrame and a Spark RDD to compute new columns

Write a function that "adds a field and its value to a pyspark.sql Row object":

def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row:
    """
    [Require] import pyspark
    [Example]
        >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)
        >>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())
        >>> print(NewRow)
        Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.1, Weekday=4)
    """
    # Convert Obj "pyspark.sql.types.Row" to Dict.
    # ----------------------------------------------
    Row_Dict = SrcRow.asDict()
    # Add a New Key in the Dictionary With the New Column Name and Value.
    # ----------------------------------------------
    Row_Dict[FldName] = FldVal
    # Convert Dict to Obj "pyspark.sql.types.Row".
    # ----------------------------------------------
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    # ==============================================
    return NewRow

Write a function that "determines whether the stock rose or fell":

def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int:
    if (ChgRate >= 0.0): return 1
    if (ChgRate <  0.0): return 0
    # ==============================================
    # End of Function.

Write a function that "determines the relationship between the short-term and long-term moving averages":

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:
    # Strict comparisons, so that equal averages (MA5 == MA10) return 0.
    if (Short_MA >  Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <  Long_MA): return -1
    # ==============================================
    # End of Function.

Write a function that "returns the day of the week (in Chinese)":

def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.date) -> str:
    """
    [Require] import datetime
    [Explain] In Python 3, the .weekday() method of datetime.date (and datetime.datetime) objects returns a number from 0 to 6 (0 = Monday, 6 = Sunday).
    """
    # Chinese names for Monday ... Sunday.
    Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    # ==============================================
    return Weekday_Str_Chinese[SrcDtm.weekday()]
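If you would rather stay in the DataFrame API instead of going through the RDD, pyspark.sql.functions also provides `dayofweek`, but it numbers days differently from Python: `dayofweek` returns 1 for Sunday through 7 for Saturday, whereas `datetime.date.weekday()` returns 0 for Monday through 6 for Sunday. The plain-Python sketch below converts a date to the Spark convention; the helper name `to_spark_dayofweek` is hypothetical.

```python
import datetime

def to_spark_dayofweek(d: datetime.date) -> int:
    # isoweekday(): Monday = 1 ... Sunday = 7; Spark's dayofweek: Sunday = 1 ... Saturday = 7.
    return d.isoweekday() % 7 + 1

for d in [datetime.date(2015, 12, 27), datetime.date(2015, 12, 28), datetime.date(2015, 12, 31)]:
    # Print the date, Python's weekday() index and the Spark-style day number.
    print(d, d.weekday(), to_spark_dayofweek(d))
```

Either numbering works as a tree feature; what matters is using one convention consistently for training and prediction.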

In Spark, convert the DataFrame to an RDD and apply the custom functions:

# Convert the DataFrame to an RDD in Spark.
CalcRDD = SDF.rdd
# --------------------------------------------------
# Custom function: extract the weekday index (0 = Monday).
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
# ..................................................
# Custom function: return the day of the week (in Chinese).
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
# ..................................................
# Custom function: determine whether the stock rose or fell.
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
# ..................................................
# Custom function: determine the relationship between the short-term and long-term moving averages.
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))
# Show the first 5 rows of the computed RDD.
print("[Message] Calculated RDD Top 5 Rows:")
pprint.pprint(CalcRDD.take(5))

Output:

[Message] Calculated RDD Top 5 Rows:
[Row(Date=datetime.date(2015, 12, 31), Code="'000422", Open=7.93, High=7.95, Low=7.76, Close=7.77, Pre_Close=7.93, Change=-0.020177, Turnover_Rate=0.015498, Volume=13915200, MA5=7.86, MA10=7.85, Weekday(Idx)=3, Weekday(CN)='周四', Rise_Fall=0, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 30), Code="'000422", Open=7.86, High=7.93, Low=7.75, Close=7.93, Pre_Close=7.84, Change=0.01148, Turnover_Rate=0.018662, Volume=16755900, MA5=7.9, MA10=7.85, Weekday(Idx)=2, Weekday(CN)='周三', Rise_Fall=1, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 29), Code="'000422", Open=7.72, High=7.85, Low=7.69, Close=7.84, Pre_Close=7.71, Change=0.016861, Turnover_Rate=0.015886, Volume=14263800, MA5=7.9, MA10=7.81, Weekday(Idx)=1, Weekday(CN)='周二', Rise_Fall=1, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 28), Code="'000422", Open=8.03, High=8.08, Low=7.7, Close=7.71, Pre_Close=8.03, Change=-0.039851, Turnover_Rate=0.030821, Volume=27672800, MA5=7.91, MA10=7.78, Weekday(Idx)=0, Weekday(CN)='周一', Rise_Fall=0, MA_Relationship=1),
 Row(Date=datetime.date(2015, 12, 25), Code="'000422", Open=8.03, High=8.05, Low=7.93, Close=8.03, Pre_Close=7.99, Change=0.005006, Turnover_Rate=0.021132, Volume=18974000, MA5=7.93, MA10=7.78, Weekday(Idx)=4, Weekday(CN)='周五', Rise_Fall=1, MA_Relationship=1)]

After the computation, convert the Spark RDD back to a Spark DataFrame:

# Convert the RDD back to a DataFrame in Spark.
NewSDF = CalcRDD.toDF()
print("[Message] Convert RDD to DataFrame and Select Key Columns for Display:")
NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()

Output:

[Message] Convert RDD to DataFrame and Select Key Columns for Display:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+

String indexing (StringIndexer) demo (demo only)

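StringIndexer (string-to-index transformation) is an estimator that encodes a string column into a column of label indices. The indices lie in [0, numLabels) and are ordered by label frequency: the most frequent label gets index 0, the next most frequent gets 1, and so on.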

# Use StringIndexer to convert the Weekday(CN) column.
MyStringIndexer = StringIndexer(inputCol="Weekday(CN)", outputCol="StrIdx")
# Fit the indexer, then transform the data.
IndexedSDF = MyStringIndexer.fit(NewSDF).transform(NewSDF)
# Select the Date, Weekday(Idx), Weekday(CN) and StrIdx columns to show the effect of StringIndexer.
print("[Message] The Effect of StringIndexer:")
IndexedSDF.select(["Date", "Weekday(Idx)", "Weekday(CN)", "StrIdx"]).show()

Output:

[Message] The Effect of StringIndexer:
+----------+------------+-----------+------+
|      Date|Weekday(Idx)|Weekday(CN)|StrIdx|
+----------+------------+-----------+------+
|2015-12-31|           3|       周四|   3.0|
|2015-12-30|           2|       周三|   1.0|
|2015-12-29|           1|       周二|   2.0|
|2015-12-28|           0|       周一|   0.0|
|2015-12-25|           4|       周五|   4.0|
|2015-12-24|           3|       周四|   3.0|
|2015-12-23|           2|       周三|   1.0|
|2015-12-22|           1|       周二|   2.0|
|2015-12-21|           0|       周一|   0.0|
|2015-12-18|           4|       周五|   4.0|
|2015-12-17|           3|       周四|   3.0|
|2015-12-16|           2|       周三|   1.0|
|2015-12-15|           1|       周二|   2.0|
|2015-12-14|           0|       周一|   0.0|
|2015-12-11|           4|       周五|   4.0|
|2015-12-10|           3|       周四|   3.0|
|2015-12-09|           2|       周三|   1.0|
|2015-12-08|           1|       周二|   2.0|
|2015-12-07|           0|       周一|   0.0|
|2015-12-04|           4|       周五|   4.0|
+----------+------------+-----------+------+
only showing top 20 rows
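The StrIdx values above can be reproduced with a plain-Python sketch of the ordering rule. It assumes the default `stringOrderType="frequencyDesc"` and the Spark 3.x behaviour that labels with equal frequency are further sorted alphabetically. In this data set Monday to Thursday each occur on 20 trading days and Friday on 19, so the tie among the first four is broken by the code-point order of 周一, 周三, 周二, 周四. `string_index` is a hypothetical helper, not Spark's implementation.

```python
from collections import Counter

def string_index(values):
    """Map each distinct string to an index: most frequent first, ties in alphabetical order."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    return {label: float(i) for i, label in enumerate(ordered)}

# Weekday frequencies in this data set: Monday-Thursday 20 trading days each, Friday 19.
weekdays = ["周一"] * 20 + ["周二"] * 20 + ["周三"] * 20 + ["周四"] * 20 + ["周五"] * 19
print(string_index(weekdays))
```

Because the index order depends on frequencies and string order rather than on the calendar, StrIdx does not preserve the natural weekday order, which is one reason the model below uses Weekday(Idx) instead.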

Extract the Label column and the Features (feature vector) column

The Features column is created with VectorAssembler, which combines multiple feature columns into a single feature vector.

Extract the Label column:

# Copy the Rise_Fall column as the Label column.
NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))

Create the Features (feature vector) column:

# VectorAssembler combines multiple feature columns into a single feature vector.
FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")
# Transform the data (optional: if the model is trained with a Pipeline, skip this step; the assembled data then cannot be previewed here).
AssembledSDF = MyAssembler.transform(NewSDF)

Preview the output:

print("[Message] Assembled Label and Features for DecisionTreeClassifier:")
AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()

Preview:

[Message] Assembled Label and Features for DecisionTreeClassifier:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|Label|            Features|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|    0|[7.95,7.76,0.0154...|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|    1|[7.93,7.75,0.0186...|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|    1|[7.85,7.69,0.0158...|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|    0|[8.08,7.7,0.03082...|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|    1|[8.05,7.93,0.0211...|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|    1|[8.16,7.87,0.0264...|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|    1|[8.11,7.88,0.0423...|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|    1|[7.93,7.76,0.0269...|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|    1|[7.89,7.56,0.0307...|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|    0|[7.74,7.57,0.0247...|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|    1|[7.75,7.57,0.0280...|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|    1|[7.62,7.53,0.0207...|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|    0|[7.66,7.52,0.0259...|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|    1|[7.64,7.36,0.0210...|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|    0|[7.7,7.41,0.02047...|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|    0|[7.87,7.65,0.0199...|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|    1|[8.0,7.75,0.02513...|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|    0|[8.18,7.76,0.0366...|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|    1|[8.39,7.94,0.0645...|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|    1|[8.48,7.8,0.10010...|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
only showing top 20 rows
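In the Features column above, position i of each vector holds the value of the i-th name in FeaColsName: for numeric input columns, VectorAssembler concatenates the values in inputCols order. This is what lets you map a split such as "feature 4" in the tree's toDebugString back to Weekday(Idx). A plain-Python sketch follows; `assemble` is a hypothetical helper, and Spark stores the result as an ML vector (dense or sparse), not as a list.

```python
FEATURE_COLS = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]

def assemble(row: dict, cols: list) -> list:
    # Concatenate the numeric columns in the given order, cast to float as in an ML vector.
    return [float(row[c]) for c in cols]

# The 2015-12-31 row from the preview above.
row = {"High": 7.95, "Low": 7.76, "Turnover_Rate": 0.015498, "Volume": 13915200,
       "Weekday(Idx)": 3, "MA_Relationship": 1}
features = assemble(row, FEATURE_COLS)
# Map a vector position back to its column name, e.g. to interpret "feature 4" in a tree.
print(FEATURE_COLS[4], features[4])
```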

Train the DecisionTreeClassifier model

Split the data set into a training set and a test set:

(TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)
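randomSplit normalises the weights and assigns each row to a split by a random draw, so the split sizes only approximate 80/20 (in the results below, the test set contains 17 of the 99 rows). The plain-Python sketch below illustrates the idea with a per-row uniform draw against cumulative weight bounds. `random_split` is a hypothetical helper; Spark samples per partition with its own random generator, so the same seed will not select the same rows as this sketch.

```python
import random

def random_split(rows, weights, seed):
    """Assign each row to one split by a uniform draw against cumulative weight bounds."""
    total = sum(weights)
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total  # normalise the weights so that they sum to 1
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # The first cumulative bound above u selects the split (last split as a fallback).
        idx = next((i for i, b in enumerate(bounds) if u < b), len(weights) - 1)
        splits[idx].append(row)
    return splits

train, test = random_split(list(range(99)), [0.8, 0.2], seed=42)
print(len(train), len(test))
```

Fixing the seed makes the split reproducible between runs, which is why the original code passes seed=42.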

Create the DecisionTreeClassifier:

DTC = DecisionTreeClassifier(labelCol="Label", featuresCol="Features")
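With no further arguments, DecisionTreeClassifier uses its defaults, including impurity="gini", maxDepth=5 and maxBins=32. As a worked example of the Gini measure, a node with class counts c_k and N samples has impurity 1 - sum_k (c_k / N)^2. The sketch below applies it to class counts [falls, rises] that appear in the rawPrediction column of the prediction results; the helper `gini` is illustrative, not Spark's code.

```python
def gini(counts):
    """Gini impurity 1 - sum(p_k^2) for a node with the given class counts."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)

# Class counts [falls, rises] taken from rawPrediction values in the results table.
for counts in ([0, 11], [7, 22], [6, 3]):
    print(counts, round(gini(counts), 4))
```

A pure leaf such as [0, 11] has impurity 0, while a mixed leaf such as [7, 22] has 308/841 (about 0.366); the tree chooses splits that reduce this impurity the most.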

Create a Pipeline (optional):

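# Create a Pipeline that chains the feature-vector assembly and the decision tree model.
# Note: with the Pipeline, the data passed to fit() must not already contain a Features column, so split NewSDF (not AssembledSDF) in this mode; otherwise VectorAssembler raises "Output column Features already exists."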
MyPipeline = Pipeline(stages=[MyAssembler, DTC])

Train the DecisionTreeClassifier model:

If the Features column was already created when building the feature vector (the data was split from AssembledSDF):

# Train the model (plain mode).
Model = DTC.fit(TrainingData)

If the Features column was not created beforehand (the data was split from NewSDF instead of AssembledSDF):

# Train the model (Pipeline mode).
Model = MyPipeline.fit(TrainingData)

Predict with the DecisionTreeClassifier model

# Make predictions on the test set.
Predictions = Model.transform(TestData)
# Drop columns that are not needed (otherwise the result is too wide and hard to read).
Predictions = Predictions.drop("Open")
Predictions = Predictions.drop("High")
Predictions = Predictions.drop("Low")
Predictions = Predictions.drop("Close")
Predictions = Predictions.drop("Pre_Close")
Predictions = Predictions.drop("Turnover_Rate")
Predictions = Predictions.drop("Volume")
Predictions = Predictions.drop("Weekday(Idx)")
Predictions = Predictions.drop("Weekday(CN)")
print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
Predictions.show()

Output:

[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+
|      Date|   Code|   Change| MA5|MA10|Rise_Fall|MA_Relationship|Label|            Features|rawPrediction|         probability|prediction|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+
|2015-08-10|'000422| 0.034105| 8.2|7.92|        1|              1|    1|[8.58,8.18,0.0412...|    [0.0,1.0]|           [0.0,1.0]|       1.0|
|2015-08-14|'000422| 0.009479|8.43|8.24|        1|              1|    1|[8.65,8.43,0.0411...|    [1.0,0.0]|           [1.0,0.0]|       0.0|
|2015-08-18|'000422|-0.095455|8.39|8.32|        0|              1|    0|[8.86,7.92,0.0561...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-08-25|'000422|-0.099424|7.52|7.96|        0|             -1|    0|[6.77,6.25,0.0294...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-02|'000422|-0.053412|6.73|6.91|        0|             -1|    0|[6.88,6.3,0.02228...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-10|'000422|-0.031161|6.76|6.74|        0|              1|    0|[7.01,6.76,0.0174...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-18|'000422|      0.0|6.39|6.62|        1|             -1|    1|[6.58,6.3,0.01662...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-09-28|'000422| 0.009464|6.48|6.47|        1|              1|    1|[6.42,6.25,0.0088...|    [0.0,4.0]|           [0.0,1.0]|       1.0|
|2015-10-19|'000422|-0.007062|6.94|6.72|        0|              1|    0|[7.13,6.92,0.0312...|    [0.0,2.0]|           [0.0,1.0]|       1.0|
|2015-10-20|'000422| 0.008535|6.98|6.81|        1|              1|    1|[7.09,6.94,0.0244...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-10-21|'000422|-0.062059|6.96|6.85|        0|              1|    0|[7.11,6.61,0.0393...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-10-23|'000422| 0.054412|6.95|6.93|        1|              1|    1|[7.22,6.81,0.0471...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-10-27|'000422| 0.033426|7.04|7.01|        1|              1|    1|[7.48,7.08,0.0576...|   [0.0,11.0]|           [0.0,1.0]|       1.0|
|2015-11-02|'000422|-0.027548|7.23| 7.1|        0|              1|    0|[7.26,7.05,0.0168...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-11-11|'000422| 0.005284|7.54|7.37|        1|              1|    1|[7.64,7.52,0.0261...|   [7.0,22.0]|[0.24137931034482...|       1.0|
|2015-11-20|'000422| 0.002635|7.52|7.53|        1|             -1|    1|[7.71,7.53,0.0282...|    [6.0,3.0]|[0.66666666666666...|       0.0|
|2015-12-02|'000422| 0.009511|7.37|7.49|        1|             -1|    1|[7.48,7.2,0.01596...|    [7.0,1.0]|       [0.875,0.125]|       0.0|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+-------------+--------------------+----------+
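In the table above, each `rawPrediction` holds the per-class counts stored in the tree leaf the row landed in, and `probability` is that vector normalized to sum to 1. For example, `[7.0, 22.0]` gives 7/29 ≈ 0.2414 for class 0. The sketch below reproduces this normalization in plain Python for the leaf vectors that appear in the table. It also computes each leaf's Gini impurity, since `gini` is the default impurity of `DecisionTreeClassifier`. The helpers `normalize` and `gini` are illustrative names, not PySpark APIs. The sketch mimics Spark's behaviour but does not call it.

```python
# Leaf class counts [count of class 0, count of class 1] copied from the
# rawPrediction column of the test-set table above.
leaf_counts = [
    [7.0, 22.0],
    [0.0, 11.0],
    [6.0, 3.0],
    [7.0, 1.0],
]


def normalize(counts):
    """Turn raw leaf counts into class probabilities (what the probability column shows)."""
    total = sum(counts)
    return [c / total for c in counts]


def gini(counts):
    """Gini impurity of a leaf: 1 - sum(p_k^2). 0.0 means the leaf is pure."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


for counts in leaf_counts:
    probs = normalize(counts)
    print(counts, "->", [round(p, 6) for p in probs], "gini =", round(gini(counts), 4))
```

Pure leaves such as `[0.0, 11.0]` always yield probability `[0.0, 1.0]`, no matter how many training rows reached them. That is why several test rows receive identical, overconfident predictions.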

Evaluating Model Performance with BinaryClassificationEvaluator

# Evaluate model performance with BinaryClassificationEvaluator (area under the ROC curve).
MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
auc = MyEvaluator.evaluate(Predictions)
print("Area Under ROC (AUC):", auc)

Output:

Area Under ROC (AUC): 0.3
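An AUC of 0.3 is below the 0.5 expected from random ranking. To see where the number comes from, the sketch below recomputes ROC AUC from the 17 test rows in the prediction table. It uses the rank (Mann-Whitney) formulation, where AUC is the fraction of (positive, negative) pairs in which the positive row scores higher, with ties counted as one half. The sketch assumes the evaluator scores each row by element 1 of `rawPrediction`, which is the raw class-1 leaf count and not the normalized probability. With that score the result matches the 0.3 reported above. `auc_mann_whitney` is a hypothetical helper written for this illustration.

```python
# (Label, rawPrediction) pairs copied row by row from the 17-row test-set table above.
rows = [
    (1, [0.0, 1.0]), (1, [1.0, 0.0]), (0, [0.0, 11.0]), (0, [7.0, 22.0]),
    (0, [7.0, 22.0]), (0, [7.0, 22.0]), (1, [7.0, 22.0]), (1, [0.0, 4.0]),
    (0, [0.0, 2.0]), (1, [7.0, 22.0]), (0, [0.0, 11.0]), (1, [0.0, 11.0]),
    (1, [0.0, 11.0]), (0, [7.0, 22.0]), (1, [7.0, 22.0]), (1, [6.0, 3.0]),
    (1, [7.0, 1.0]),
]


def auc_mann_whitney(pairs):
    """ROC AUC as the share of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [score for label, score in pairs if label == 1]
    neg = [score for label, score in pairs if label == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Assumption: the score is element 1 of rawPrediction (raw count for class 1).
score_pairs = [(label, raw[1]) for label, raw in rows]
auc = auc_mann_whitney(score_pairs)
print("Recomputed AUC:", auc)
```

Many test rows fall into the same leaf and therefore share a score, so tied pairs dominate this small test set. An AUC computed on only 17 rows should be read as very noisy.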

Complete Code

#!/usr/bin/python3
# Created by GF 2024-01-06
# In this example, VectorAssembler merges several feature columns into one feature vector, and DecisionTreeClassifier builds the decision tree model.
# Finally, BinaryClassificationEvaluator evaluates model performance, typically with the area under the ROC curve (AUC) as the metric.
# Adjust the feature columns, label column, and other parameters to your own data and problem. In practice you will likely need more feature engineering, hyperparameter tuning, and model evaluation.

import datetime
import pprint
# --------------------------------------------------
import pyspark
# --------------------------------------------------
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType
# --------------------------------------------------
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline


# Helper: add a field and its value to a spark.sql Row object.
def MapFunc_SparkSQL_Row_Add_Field(SrcRow: pyspark.sql.types.Row, FldName: str, FldVal: object) -> pyspark.sql.types.Row:
    """
    [Require] import pyspark
    [Example]
    >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)
    >>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())
    >>> print(NewRow)
    Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.1, Weekday=4)
    """
    # Convert the "pyspark.sql.types.Row" object to a dict.
    Row_Dict = SrcRow.asDict()
    # Add a new key holding the new column name and its value.
    Row_Dict[FldName] = FldVal
    # Convert the dict back to a "pyspark.sql.types.Row" object.
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    return NewRow


# Helper: judge whether the stock rose or fell (1 = rose or unchanged, 0 = fell).
def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate: float) -> int:
    if (ChgRate >= 0.0):
        return 1
    if (ChgRate < 0.0):
        return 0


# Helper: relationship between the short-term and long-term moving averages
# (1 = short above long, 0 = equal, -1 = short below long).
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA: float, Long_MA: float) -> int:
    # Strict comparisons, so that equal averages really return 0
    # (with ">=" in the first test, the "== -> 0" branch could never be reached).
    if (Short_MA > Long_MA):
        return 1
    if (Short_MA == Long_MA):
        return 0
    if (Short_MA < Long_MA):
        return -1


# Helper: return the weekday name in Chinese.
def DtmFunc_Weekday_Return_String_CN(SrcDtm: datetime.datetime) -> str:
    """
    [Require] import datetime
    [Explain] In Python 3, .weekday() on a datetime.datetime object returns 0 to 6 (0 = Monday, 6 = Sunday).
    """
    # Monday ... Sunday in Chinese.
    Weekday_Str_Chinese: list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    return Weekday_Str_Chinese[SrcDtm.weekday()]


if __name__ == "__main__":

    # Since Spark 2.0, spark-shell automatically creates a SparkSession named spark at startup.
    # When you need to create one yourself, use SparkSession.builder.
    spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()

    # Read the CSV data through SparkSession.read.
    # Each .option sets one reader option as key/value, e.g. .option("header", "true") is like {header = "true"}.
    SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")
    print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
    SDF.show()

    # Cast the Spark DataFrame columns to proper data types.
    SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))
    SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))
    SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))
    SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))
    SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))
    SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))
    SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))
    SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
    SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))
    SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))
    SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))

    # Print the DataFrame field names and data types.
    print("[Message] Changed Spark DataFrame Data Type:")
    SDF.printSchema()

    # Convert the DataFrame to an RDD.
    CalcRDD = SDF.rdd
    # --------------------------------------------------
    # Custom function: extract the weekday index.
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
    # ..................................................
    # Custom function: weekday name in Chinese.
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
    # ..................................................
    # Custom function: did the stock rise or fall.
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
    # ..................................................
    # Custom function: short-term vs long-term moving average relationship.
    CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))

    # Show the first 5 rows of the computed RDD.
    print("[Message] Calculated RDD Top 5 Rows:")
    pprint.pprint(CalcRDD.take(5))

    # Convert the RDD back to a DataFrame.
    NewSDF = CalcRDD.toDF()
    print("[Message] Convert RDD to DataFrame and Filter Out Key Columns for Display:")
    NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()

    # Label column: copy Rise_Fall into a new Label column.
    NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))

    # Features column: VectorAssembler merges several feature columns into one feature vector.
    FeaColsName: list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
    MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")

    # Features column: apply the assembler now (optional; when training through a Pipeline, skip this step,
    # which also means the assembled data cannot be previewed here).
    AssembledSDF = MyAssembler.transform(NewSDF)
    print("[Message] Assembled Label and Features for DecisionTreeClassifier:")
    AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()

    # Split the data set into a training set and a test set.
    (TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)

    # Create the decision tree classifier.
    DTC = DecisionTreeClassifier(labelCol="Label", featuresCol="Features")

    # Create a Pipeline (optional) that chains the feature assembly and the decision tree model.
    # Note: when using the Pipeline, do not apply the assembler beforehand, otherwise Spark raises
    # "Output column Features already exists."
    #MyPipeline = Pipeline(stages=[MyAssembler, DTC])

    # Train the model (plain mode).
    Model = DTC.fit(TrainingData)
    # Train the model (Pipeline mode).
    #Model = MyPipeline.fit(TrainingData)

    # Predict on the test set.
    Predictions = Model.transform(TestData)

    # Drop columns we do not need, so the printed result is less crowded and easier to read.
    Predictions = Predictions.drop("Open", "High", "Low", "Close", "Pre_Close", "Turnover_Rate", "Volume", "Weekday(Idx)", "Weekday(CN)")
    print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
    Predictions.show()

    # Evaluate model performance with BinaryClassificationEvaluator (area under the ROC curve).
    MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
    auc = MyEvaluator.evaluate(Predictions)
    print("Area Under ROC (AUC):", auc)

Other Notes

In this example, VectorAssembler merges several feature columns into a single feature vector, and DecisionTreeClassifier builds the decision tree model.
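To make the per-row feature engineering concrete, here is a plain-Python sketch that derives the label and the six assembled features for the first row of the sample CSV (2015-12-31). It is a stand-in that mirrors the logic of the helper functions and of VectorAssembler. It does not call Spark. The helper names `rise_fall`, `ma_relationship`, and `assemble` are hypothetical. `ma_relationship` uses strict comparisons, so equal averages map to 0.

```python
import datetime

# Same feature order as FeaColsName in the complete code.
FEATURE_COLS = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]


def rise_fall(change):
    """1 = rose or unchanged, 0 = fell (same rule as the Rise_Fall helper)."""
    return 1 if change >= 0.0 else 0


def ma_relationship(short_ma, long_ma):
    """1 = short MA above long MA, 0 = equal, -1 = below."""
    if short_ma > long_ma:
        return 1
    if short_ma == long_ma:
        return 0
    return -1


def assemble(row):
    """Plain-Python stand-in for VectorAssembler: take FEATURE_COLS in order as floats."""
    return [float(row[name]) for name in FEATURE_COLS]


# First data row of HBYH_000422_20150806_20151231.csv (only the columns used here).
row = {
    "Date": datetime.date(2015, 12, 31), "High": 7.95, "Low": 7.76,
    "Change": -0.020177, "Turnover_Rate": 0.015498, "Volume": 13915200,
    "MA5": 7.86, "MA10": 7.85,
}
row["Weekday(Idx)"] = row["Date"].weekday()          # 0 = Monday ... 6 = Sunday
row["Rise_Fall"] = rise_fall(row["Change"])
row["MA_Relationship"] = ma_relationship(row["MA5"], row["MA10"])
row["Label"] = row["Rise_Fall"]                       # Label is a copy of Rise_Fall

print("Label:", row["Label"])
print("Features:", assemble(row))
```

Note that Volume (tens of millions) and Weekday(Idx) (0 to 4) sit on very different scales inside the same vector. Decision trees split on one feature at a time, so they do not need these features rescaled.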

Finally, BinaryClassificationEvaluator evaluates model performance, typically with the area under the ROC curve (AUC) as the evaluation metric.

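Adjust the feature columns, the label column, and the other parameters to your own data and problem. In real applications you will likely need more feature engineering, hyperparameter tuning, and model evaluation.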
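As one example of further evaluation beyond AUC, a confusion matrix over the same 17 test rows shows what the classifier actually predicts. In Spark, `MulticlassClassificationEvaluator` with `metricName="accuracy"` would report accuracy directly. The sketch below recounts the matrix in plain Python from the `Label` and `prediction` columns of the test-set table. The `confusion` helper is a hypothetical name used only for illustration.

```python
# (Label, prediction) pairs copied row by row from the 17-row test-set table.
pairs = [
    (1, 1.0), (1, 0.0), (0, 1.0), (0, 1.0), (0, 1.0), (0, 1.0), (1, 1.0),
    (1, 1.0), (0, 1.0), (1, 1.0), (0, 1.0), (1, 1.0), (1, 1.0), (0, 1.0),
    (1, 1.0), (1, 0.0), (1, 0.0),
]


def confusion(pairs):
    """Return (TP, TN, FP, FN) with class 1 ("rise") as the positive class."""
    tp = sum(1 for y, p in pairs if y == 1 and p == 1.0)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0.0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1.0)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0.0)
    return tp, tn, fp, fn


tp, tn, fp, fn = confusion(pairs)
accuracy = (tp + tn) / len(pairs)
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
print(f"TP={tp} TN={tn} FP={fp} FN={fn}")
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")
```

On these rows the model never predicts a fall correctly (TN = 0) and labels every falling day as a rise. This agrees with the low AUC and suggests the current features carry little signal for this split.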

Summary

That concludes this example of a PySpark-3.0.3 decision tree (DecisionTreeClassifier) on financial data.

More content is available in my code repositories:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public
