金融数据_PySpark-3.0.3随机森林(RandomForestClassifier)实例

金融数据_PySpark-3.0.3随机森林(RandomForestClassifier)实例

随机森林 (Random Forest) 和随机森林回归 (Random Forest Regression) 都是基于集成学习的算法, 但它们在任务和输出方面存在一些关键的区别。

随机森林 (Random Forest)
任务类型: 随机森林主要用于分类任务。在分类任务中, 算法试图将输入数据划分到多个类别中的一个。
输出: 随机森林的输出是一个离散的类别标签, 表示输入数据所属的类别。
算法原理: 随机森林是基于决策树的集成学习算法。它通过训练多个决策树, 每个决策树都在随机抽样的数据子集上进行训练, 然后进行投票或平均来做出最终的分类决策。

随机森林回归 (Random Forest Regression)
任务类型: 随机森林回归主要用于回归任务。在回归任务中, 算法试图预测一个连续的数值输出, 而不是一个离散的类别。
输出: 随机森林回归的输出是一个连续的数值, 表示输入数据的预测结果。
算法原理: 随机森林回归同样基于决策树, 但在回归任务中, 每个决策树的输出是一个实数值。最终的预测结果是多个决策树输出的平均值或加权平均值。

关系和共同点
基于决策树: 随机森林和随机森林回归都是基于决策树的集成学习方法, 通过组合多个决策树来提高模型的性能。
随机性: 两者都引入了随机性, 通过在训练过程中对数据进行随机抽样和对特征进行随机选择来增加模型的多样性。
鲁棒性: 集成多个模型可以提高模型的鲁棒性, 减小过拟合的风险。
易于并行化: 由于每个决策树可以独立训练, 因此随机森林和随机森林回归都易于并行化, 适合在大规模数据集上训练。

总体而言, 关键区别在于任务类型和输出。随机森林用于分类, 输出是离散的类别标签;而随机森林回归用于回归, 输出是连续的数值。

* 当使用 PySpark 进行随机森林的构建时, 你可以使用 RandomForestClassifier 类。* 下面是一个简单的示例, 演示如何在 PySpark 中构建和训练随机森林分类器。

实例数据

本实例截取了 “湖北宜化(000422)” 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Spark 中的随机森林分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 “涨跌幅(Change) >= 0”, 则用 1 表示, 如果 “涨跌幅(Change) < 0”, 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 “短期均线(MA5)” 与 “长期均线(MA10)” 的关系, 如果 “MA5 > MA10”, 则用 1 表示, 如果 “MA5 = MA10”, 则用 0 表示, 如果 “MA5 < MA10”, 则用 -1 表示。

导入 pyspark.sql 相关模块

Spark SQL 是用于结构化数据处理的 Spark 模块。它提供了一种成为 DataFrame 编程抽象, 是由 SchemaRDD 发展而来。

不同于 SchemaRDD 直接继承 RDD, DataFrame 自己实现了 RDD 的绝大多数功能。

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType

导入 pyspark.ml 相关模块

Spark 在核心数据抽象 RDD 的基础上, 支持 4 大组件, 其中机器学习占其一。

进一步的, Spark 中实际上支持两个机器学习模块, MLlib 和 ML, 区别在于前者主要是基于 RDD 数据结构, 当前处于维护状态; 而后者则是 DataFrame 数据结构, 支持更多的算法, 后续将以此为主进行迭代。

所以, 在实际应用中优先使用 ML 子模块。

Spark 的 ML 库与 Python 中的另一大机器学习库 Sklearn 的关系是: Spark 的 ML 库支持大部分机器学习算法和接口功能, 虽远不如 Sklearn 功能全面, 但主要面向分布式训练, 针对大数据。

而 Sklearn 是单点机器学习算法库, 支持几乎所有主流的机器学习算法, 从样例数据, 特征选择, 模型选择和验证, 基础学习算法和集成学习算法, 提供了机器学习一站式解决方案, 但仅支持并行而不支持分布式。

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

创建 SparkSession 对象

Spark 2.0 以上版本的 spark-shell 在启动时会自动创建一个名为 spark 的 SparkSession 对象。

当需要手工创建时, SparkSession 可以由其伴生对象的 builder 方法创建出来。

spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()

使用 Spark 构建 DataFrame 数据 (Optional)

当数据量较小时, 可以使用该方法手工构建 DataFrame 数据。

构建数据行 Row (以前 3 行为例):

Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85")
ROW(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85")
Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")

将构建好的数据行 Row 加入列表 (以前 3 行为例):

Data_Rows = [Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85"),ROW(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85"),Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")
]

生成 DataFrame 数据框 (以前 3 行为例):

SDF = spark.createDataFrame(Data_Rows)

输出 DataFrame 数据框 (以前 3 行为例):

print("[Message] Builded Spark DataFrame: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()

输出:

+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|    Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498| 1.39152E7|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84|  0.01148|     0.018662| 1.67559E7|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886| 1.42638E7|7.90|7.81|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+

使用 Spark 读取 CSV 数据

调用 SparkSession 的 .read 方法读取 CSV 数据:

其中 .option 是读取文件时的选项, 左边是 “键(Key)”, 右边是 “值(Value)”, 例如 .option(“header”, “true”) 与 {header = “true”} 类同。

SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")

输出 DataFrame 数据框:

print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()

输出:

Readed CSV File: D:\HBYH_000422_20150806_20151231.csv
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|      Date|   Code|Open|High| Low|Close|Pre_Close|   Change|Turnover_Rate|  Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77|     7.93|-0.020177|     0.015498|13915200|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93|     7.84| 0.011480|     0.018662|16755900|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84|     7.71| 0.016861|     0.015886|14263800|7.90|7.81|
|2015-12-28|'000422|8.03|8.08|7.70| 7.71|     8.03|-0.039851|     0.030821|27672800|7.91|7.78|
|2015-12-25|'000422|8.03|8.05|7.93| 8.03|     7.99| 0.005006|     0.021132|18974000|7.93|7.78|
|2015-12-24|'000422|7.93|8.16|7.87| 7.99|     7.92| 0.008838|     0.026487|23781900|7.85|7.72|
|2015-12-23|'000422|7.97|8.11|7.88| 7.92|     7.89| 0.003802|     0.042360|38033600|7.80|7.69|
|2015-12-22|'000422|7.86|7.93|7.76| 7.89|     7.83| 0.007663|     0.026929|24178700|7.73|7.68|
|2015-12-21|'000422|7.59|7.89|7.56| 7.83|     7.63| 0.026212|     0.030777|27633600|7.66|7.67|
|2015-12-18|'000422|7.71|7.74|7.57| 7.63|     7.74|-0.014212|     0.024764|22234900|7.62|7.71|
|2015-12-17|'000422|7.58|7.75|7.57| 7.74|     7.55| 0.025166|     0.028054|25188400|7.59|7.77|
|2015-12-16|'000422|7.57|7.62|7.53| 7.55|     7.55| 0.000000|     0.020718|18601600|7.58|7.79|
|2015-12-15|'000422|7.63|7.66|7.52| 7.55|     7.62|-0.009186|     0.025902|23256600|7.64|7.78|
|2015-12-14|'000422|7.40|7.64|7.36| 7.62|     7.51| 0.014647|     0.021005|18860100|7.68|7.76|
|2015-12-11|'000422|7.65|7.70|7.41| 7.51|     7.67|-0.020860|     0.020477|18385900|7.80|7.73|
|2015-12-10|'000422|7.78|7.87|7.65| 7.67|     7.83|-0.020434|     0.019972|17931900|7.95|7.69|
|2015-12-09|'000422|7.76|8.00|7.75| 7.83|     7.77| 0.007722|     0.025137|22569700|8.00|7.68|
|2015-12-08|'000422|8.08|8.18|7.76| 7.77|     8.24|-0.057039|     0.036696|32948200|7.92|7.66|
|2015-12-07|'000422|8.12|8.39|7.94| 8.24|     8.23| 0.001215|     0.064590|57993100|7.84|7.64|
|2015-12-04|'000422|7.85|8.48|7.80| 8.23|     7.92| 0.039141|     0.100106|89881900|7.65|7.58|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
only showing top 20 rows

转换 Spark 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

# 转换 Spark 中 DateFrame 数据类型。
SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))
SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))
SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))
SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))
SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))
SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))
SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))
SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))
SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))
SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))# 输出 Spark 中 DataFrame 字段和数据类型。
print("[Message] Changed Spark DataFrame Data Type:")
SDF.printSchema()

输出:

[Message] Changed Spark DataFrame Data Type:
root|-- Date: date (nullable = true)|-- Code: string (nullable = true)|-- Open: double (nullable = true)|-- High: double (nullable = true)|-- Low: double (nullable = true)|-- Close: double (nullable = true)|-- Pre_Close: double (nullable = true)|-- Change: double (nullable = true)|-- Turnover_Rate: double (nullable = true)|-- Volume: integer (nullable = true)|-- MA5: double (nullable = true)|-- MA10: double (nullable = true)

将 Spark 的 DateFrame 和 Spark RDD 互相转换并计算数据

编写 “向 spark.sql 的 Row 对象添加字段和字段值” 函数:

def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row: """[Require] import pyspark[Example] >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)>>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())>>> print(NewRow)Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10, Weekday=4)"""# Convert Obj "pyspark.sql.types.Row" to Dict. # ----------------------------------------------Row_Dict = SrcRow.asDict()# Add a New Key in the Dictionary With the New Column Name and Value.# ----------------------------------------------Row_Dict[FldName] = FldVal# Convert Dict to Obj "pyspark.sql.types.Row". # ----------------------------------------------NewRow = pyspark.sql.types.Row(**Row_Dict)# ==============================================return NewRow

编写 “判断股票涨跌” 函数:

def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int: if (ChgRate >= 0.0): return 1if (ChgRate <  0.0): return 0# ==============================================# End of Function.

编写 “判断股票短期均线和长期均线关系” 函数:

def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int: if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.

编写 “返回星期几(中文)” 函数:

def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.datetime) -> str:"""[Require] import datetime[Explain] Python3 中 datetime.datetime 对象的 .weekday() 方法返回的是从 0 到 6 的数字 (0 代表周一, 6 代表周日)。"""Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]# ==============================================return Weekday_Str_Chinese[SrcDtm.weekday()]

在 Spark 中将 DataFrame 转换为 Spark RDD 并调用自定义函数:

# 在 Spark 中将 DataFrame 转换为 RDD。
CalcRDD = SDF.rdd# --------------------------------------------------
# 调用自定义函数: 提取星期索引。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
# ..................................................
# 调用自定义函数: 返回星期几(中文)。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
# ..................................................
# 调用自定义函数: 判断股票涨跌。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
# ..................................................
# 判断股票短期均线和长期均线关系。
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))# 显示计算好的 RDD 前 5 行。
print("[Message] Calculated RDD Top 5 Rows:")
pprint.pprint(CalcRDD.take(5))

输出:

[Message] Calculated RDD Top 5 Rows:
[Row(Date=datetime.date(2015, 12, 31), Code="'000422", Open=7.93, High=7.95, Low=7.76, Close=7.77, Pre_Close=7.93, Change=-0.020177, Turnover_Rate=0.015498, Volume=13915200, MA5=7.86, MA10=7.85, Weekday(Idx)=3, Weekday(CN)='周四', Rise_Fall=0, MA_Relationship=1),Row(Date=datetime.date(2015, 12, 30), Code="'000422", Open=7.86, High=7.93, Low=7.75, Close=7.93, Pre_Close=7.84, Change=0.01148, Turnover_Rate=0.018662, Volume=16755900, MA5=7.9, MA10=7.85, Weekday(Idx)=2, Weekday(CN)='周三', Rise_Fall=1, MA_Relationship=1),Row(Date=datetime.date(2015, 12, 29), Code="'000422", Open=7.72, High=7.85, Low=7.69, Close=7.84, Pre_Close=7.71, Change=0.016861, Turnover_Rate=0.015886, Volume=14263800, MA5=7.9, MA10=7.81, Weekday(Idx)=1, Weekday(CN)='周二', Rise_Fall=1, MA_Relationship=1),Row(Date=datetime.date(2015, 12, 28), Code="'000422", Open=8.03, High=8.08, Low=7.7, Close=7.71, Pre_Close=8.03, Change=-0.039851, Turnover_Rate=0.030821, Volume=27672800, MA5=7.91, MA10=7.78, Weekday(Idx)=0, Weekday(CN)='周一', Rise_Fall=0, MA_Relationship=1),Row(Date=datetime.date(2015, 12, 25), Code="'000422", Open=8.03, High=8.05, Low=7.93, Close=8.03, Pre_Close=7.99, Change=0.005006, Turnover_Rate=0.021132, Volume=18974000, MA5=7.93, MA10=7.78, Weekday(Idx)=4, Weekday(CN)='周五', Rise_Fall=1, MA_Relationship=1)]

计算完成后将 Spark RDD 转换回 Spark 的 DataFrame:

# 在 Spark 中将 RDD 转换为 DataFrame。
NewSDF = CalcRDD.toDF()print("[Message] Convert RDD to DataFrame and Filter Out Key Columns for Display:")
NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()

输出:

[Message] Convert RDD to DataFrame and Filter Out Key Columns:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+

字符串索引化 (StringIndexer) 演示 (Only Demo)

StringIndexer (字符串-索引变换) 是一个估计器, 是将字符串列编码为标签索引列。索引位于 [0, numLabels), 按标签频率排序, 频率最高的排 0, 依次类推, 因此最常见的标签获取索引是 0。

# 使用 StringIndexer 转换 Weekday(CN) 列。
MyStringIndexer = StringIndexer(inputCol="Weekday(CN)", outputCol="StrIdx")
# 拟合并转换数据。
IndexedSDF = MyStringIndexer.fit(NewSDF).transform(NewSDF)# 筛选 Date, Weekday(Idx), Weekday(CN), StrIdx 四列, 输出 StringIndexer 效果。
print("[Message] The Effect of StringIndexer:")
IndexedSDF.select(["Date", "Weekday(Idx)", "Weekday(CN)", "StrIdx"]).show()

输出:

[Message] The Effect of StringIndexer:
+----------+------------+-----------+------+
|      Date|Weekday(Idx)|Weekday(CN)|StrIdx|
+----------+------------+-----------+------+
|2015-12-31|           3|       周四|   3.0|
|2015-12-30|           2|       周三|   1.0|
|2015-12-29|           1|       周二|   2.0|
|2015-12-28|           0|       周一|   0.0|
|2015-12-25|           4|       周五|   4.0|
|2015-12-24|           3|       周四|   3.0|
|2015-12-23|           2|       周三|   1.0|
|2015-12-22|           1|       周二|   2.0|
|2015-12-21|           0|       周一|   0.0|
|2015-12-18|           4|       周五|   4.0|
|2015-12-17|           3|       周四|   3.0|
|2015-12-16|           2|       周三|   1.0|
|2015-12-15|           1|       周二|   2.0|
|2015-12-14|           0|       周一|   0.0|
|2015-12-11|           4|       周五|   4.0|
|2015-12-10|           3|       周四|   3.0|
|2015-12-09|           2|       周三|   1.0|
|2015-12-08|           1|       周二|   2.0|
|2015-12-07|           0|       周一|   0.0|
|2015-12-04|           4|       周五|   4.0|
+----------+------------+-----------+------+
only showing top 20 rows

提取 标签(Label)列 和 特征向量(Features)列

在创建特征向量(Features)列时, 将会用到 VectorAssembler 模块, VectorAssembler 将多个特征合并为一个特征向量。

提取 标签(Label) 列:

# 将 Rise_Fall 列复制为 Label 列。
NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))

创建 特征向量(Features) 列:

# VectorAssembler 将多个特征合并为一个特征向量。
FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")# 拟合数据 (可选, 如果在模型训练时使用 Pipeline, 则无需在此步骤拟合数据, 当然也就无法在此步骤预览数据)。
AssembledSDF = MyAssembler.transform(NewSDF)

输出预览:

print("[Message] Assembled Label and Features for RandomForestClassifier:")
AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()

预览:

[Message] Assembled for RandomForestClassifier:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|      Date|   Code|High| Low|Close|   Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|Label|            Features|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85|       周四|        0|              1|    0|[7.95,7.76,0.0154...|
|2015-12-30|'000422|7.93|7.75| 7.93|  0.01148| 7.9|7.85|       周三|        1|              1|    1|[7.93,7.75,0.0186...|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81|       周二|        1|              1|    1|[7.85,7.69,0.0158...|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78|       周一|        0|              1|    0|[8.08,7.7,0.03082...|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78|       周五|        1|              1|    1|[8.05,7.93,0.0211...|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72|       周四|        1|              1|    1|[8.16,7.87,0.0264...|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69|       周三|        1|              1|    1|[8.11,7.88,0.0423...|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68|       周二|        1|              1|    1|[7.93,7.76,0.0269...|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67|       周一|        1|             -1|    1|[7.89,7.56,0.0307...|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71|       周五|        0|             -1|    0|[7.74,7.57,0.0247...|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77|       周四|        1|             -1|    1|[7.75,7.57,0.0280...|
|2015-12-16|'000422|7.62|7.53| 7.55|      0.0|7.58|7.79|       周三|        1|             -1|    1|[7.62,7.53,0.0207...|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78|       周二|        0|             -1|    0|[7.66,7.52,0.0259...|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76|       周一|        1|             -1|    1|[7.64,7.36,0.0210...|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73|       周五|        0|              1|    0|[7.7,7.41,0.02047...|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69|       周四|        0|              1|    0|[7.87,7.65,0.0199...|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68|       周三|        1|              1|    1|[8.0,7.75,0.02513...|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66|       周二|        0|              1|    0|[8.18,7.76,0.0366...|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64|       周一|        1|              1|    1|[8.39,7.94,0.0645...|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58|       周五|        1|              1|    1|[8.48,7.8,0.10010...|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
only showing top 20 rows

训练 随机森林分类器(RandomForestClassifier) 模型

将数据集划分为 “训练集” 和 “测试集”:

(TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)

创建 随机森林分类器(RandomForestClassifier):

RFC = RandomForestClassifier(labelCol="Label", featuresCol="Features", numTrees=10)

创建 Pipeline (可选):

# 创建 Pipeline, 将特征向量转换和随机森林模型组合在一起
# 注意: 如果要使用 Pipeline, 则在创建 特征向量(Features)列 的时候不需要拟合数据, 否则会报 "Output column Features already exists." 的错误。
MyPipeline = Pipeline(stages=[MyAssembler, RFC])

训练 随机森林分类器(RandomForestClassifier) 模型:

如果在创建 特征向量(Features)列 的时候已经拟合数据:

# 训练模型 (普通模式)。
Model = RFC.fit(TrainingData)

如果在创建 特征向量(Features)列 的时候没有拟合数据:

# 训练模型 (Pipeline 模式)。
Model = MyPipeline.fit(TrainingData)

使用 随机森林分类器(RandomForestClassifier) 模型预测数据

# 在测试集上进行预测。
Predictions = Model.transform(TestData)# 删除不需要的列 (以免列数太多, 结果显示拥挤, 不好观察)。
Predictions = Predictions.drop("Open")
Predictions = Predictions.drop("High")
Predictions = Predictions.drop("Low")
Predictions = Predictions.drop("Close")
Predictions = Predictions.drop("Pre_Close")
Predictions = Predictions.drop("Turnover_Rate")
Predictions = Predictions.drop("Volume")
Predictions = Predictions.drop("Weekday(Idx)")
Predictions = Predictions.drop("Weekday(CN)")print("[Message] Prediction Results on The Test Data Set for RandomForestClassifier:")
Predictions.show()

输出:

[Message] Prediction Results on The Test Data Set for RandomForestClassifier:
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+
|      Date|   Code|   Change| MA5|MA10|Rise_Fall|MA_Relationship|Label|            Features|       rawPrediction|         probability|prediction|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+
|2015-08-10|'000422| 0.034105| 8.2|7.92|        1|              1|    1|[8.58,8.18,0.0412...|[3.83333333333333...|[0.38333333333333...|       1.0|
|2015-08-14|'000422| 0.009479|8.43|8.24|        1|              1|    1|[8.65,8.43,0.0411...|[6.33333333333333...|[0.63333333333333...|       0.0|
|2015-08-18|'000422|-0.095455|8.39|8.32|        0|              1|    0|[8.86,7.92,0.0561...|[4.83333333333333...|[0.48333333333333...|       1.0|
|2015-08-25|'000422|-0.099424|7.52|7.96|        0|             -1|    0|[6.77,6.25,0.0294...|[1.24468211527035...|[0.12446821152703...|       1.0|
|2015-09-02|'000422|-0.053412|6.73|6.91|        0|             -1|    0|[6.88,6.3,0.02228...|[2.39316696375519...|[0.23931669637551...|       1.0|
|2015-09-10|'000422|-0.031161|6.76|6.74|        0|              1|    0|[7.01,6.76,0.0174...|[2.40476190476190...|[0.24047619047619...|       1.0|
|2015-09-18|'000422|      0.0|6.39|6.62|        1|             -1|    1|[6.58,6.3,0.01662...|[4.22700534759358...|[0.42270053475935...|       1.0|
|2015-09-28|'000422| 0.009464|6.48|6.47|        1|              1|    1|[6.42,6.25,0.0088...|[3.83333333333333...|[0.38333333333333...|       1.0|
|2015-10-19|'000422|-0.007062|6.94|6.72|        0|              1|    0|[7.13,6.92,0.0312...|[1.44220779220779...|[0.14422077922077...|       1.0|
|2015-10-20|'000422| 0.008535|6.98|6.81|        1|              1|    1|[7.09,6.94,0.0244...|[2.59069264069264...|[0.25906926406926...|       1.0|
|2015-10-21|'000422|-0.062059|6.96|6.85|        0|              1|    0|[7.11,6.61,0.0393...|[3.42857142857142...|[0.34285714285714...|       1.0|
|2015-10-23|'000422| 0.054412|6.95|6.93|        1|              1|    1|[7.22,6.81,0.0471...|[2.47857142857142...|[0.24785714285714...|       1.0|
|2015-10-27|'000422| 0.033426|7.04|7.01|        1|              1|    1|[7.48,7.08,0.0576...|[2.81190476190476...|[0.28119047619047...|       1.0|
|2015-11-02|'000422|-0.027548|7.23| 7.1|        0|              1|    0|[7.26,7.05,0.0168...|[1.62402597402597...|[0.16240259740259...|       1.0|
|2015-11-11|'000422| 0.005284|7.54|7.37|        1|              1|    1|[7.64,7.52,0.0261...|[3.29902597402597...|[0.32990259740259...|       1.0|
|2015-11-20|'000422| 0.002635|7.52|7.53|        1|             -1|    1|[7.71,7.53,0.0282...|[5.74068627450980...|[0.57406862745098...|       0.0|
|2015-12-02|'000422| 0.009511|7.37|7.49|        1|             -1|    1|[7.48,7.2,0.01596...|[7.54901960784313...|[0.75490196078431...|       0.0|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+

使用 BinaryClassificationEvaluator 评估模型性能

# 使用 BinaryClassificationEvaluator 评估模型性能。
MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
auc = MyEvaluator.evaluate(Predictions)print("Area Under ROC (AUC):", auc)

输出:

Area Under ROC (AUC): 0.15714285714285714

完整代码

#!/usr/bin/python3
# Create By GF 2024-01-07# 在这个例子中, 我们使用 VectorAssembler 将多个特征列合并为一个特征向量, 并使用 RandomForestClassifier 构建随机森林模型。
# 最后, 我们使用 BinaryClassificationEvaluator 评估模型性能, 通常使用 ROC 曲线下面积 (AUC) 作为评估指标。
# 请根据你的实际数据和问题调整特征列, 标签列以及其他参数。在实际应用中, 你可能需要进行更多的特征工程, 调参和模型评估。import datetime
import pprint
# --------------------------------------------------
import pyspark
# --------------------------------------------------
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType
# --------------------------------------------------
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline# 编写 "向 spark.sql 的 Row 对象添加字段和字段值" 函数。
def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row: """[Require] import pyspark[Example] >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10)>>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())>>> print(NewRow)Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Incom=5432.10, Weekday=4)"""# Convert Obj "pyspark.sql.types.Row" to Dict. # ----------------------------------------------Row_Dict = SrcRow.asDict()# Add a New Key in the Dictionary With the New Column Name and Value.# ----------------------------------------------Row_Dict[FldName] = FldVal# Convert Dict to Obj "pyspark.sql.types.Row". # ----------------------------------------------NewRow = pyspark.sql.types.Row(**Row_Dict)# ==============================================return NewRow# 编写 "判断股票涨跌" 函数。
def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int: if (ChgRate >= 0.0): return 1if (ChgRate <  0.0): return 0# ==============================================# End of Function.# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int: if (Short_MA >= Long_MA): return  1if (Short_MA == Long_MA): return  0if (Short_MA <= Long_MA): return -1# ==============================================# End of Function.# 编写 "返回星期几(中文)" 函数。
def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.datetime) -> str:"""[Require] import datetime[Explain] Python3 中 datetime.datetime 对象的 .weekday() 方法返回的是从 0 到 6 的数字 (0 代表周一, 6 代表周日)。"""Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]# ==============================================return Weekday_Str_Chinese[SrcDtm.weekday()]if __name__ == "__main__":# Spark 2.0 以上版本的 spark-shell 在启动时会自动创建一个名为 spark 的 SparkSession 对象。# 当需要手工创建时, SparkSession 可以由其伴生对象的 builder 方法创建出来。spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()# 调用 SparkSession 的 .read 方法读取 CSV 数据:# 其中 .option 是读取文件时的选项, 左边是 "键(Key)", 右边是 "值(Value)", 例如 .option("header", "true") 与 {header = "true"} 类同。SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")SDF.show()# 转换 Spark 中 DateFrame 数据类型。SDF = SDF.withColumn("Date",          col("Date").cast(DateType()))SDF = SDF.withColumn("Open",          col("Open").cast(DoubleType()))SDF = SDF.withColumn("High",          col("High").cast(DoubleType()))SDF = SDF.withColumn("Low",           col("Low").cast(DoubleType()))SDF = SDF.withColumn("Close",         col("Close").cast(DoubleType()))SDF = SDF.withColumn("Pre_Close",     col("Pre_Close").cast(DoubleType()))SDF = SDF.withColumn("Change",        col("Change").cast(DoubleType()))SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))SDF = SDF.withColumn("Volume",        col("Volume").cast(IntegerType()))SDF = SDF.withColumn("MA5",           col("MA5").cast(DoubleType()))SDF = SDF.withColumn("MA10",          col("MA10").cast(DoubleType()))# 输出 Spark 中 DataFrame 字段和数据类型。print("[Message] Changed Spark DataFrame Data Type:")SDF.printSchema()# 在 Spark 中将 DataFrame 转换为 RDD。CalcRDD = SDF.rdd# --------------------------------------------------# 调用自定义函数: 提取星期索引。CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))# ..................................................# 调用自定义函数: 返回星期几(中文)。CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))# ..................................................# 调用自定义函数: 判断股票涨跌。CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))# ..................................................# 判断股票短期均线和长期均线关系。CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))# 显示计算好的 RDD 前 5 行。print("[Message] Calculated RDD Top 5 Rows:")pprint.pprint(CalcRDD.take(5))# 在 Spark 中将 RDD 转换为 DataFrame。NewSDF = CalcRDD.toDF()print("[Message] Convert RDD to DataFrame and Filter Out Key Columns for Display:")NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()# 提取 标签(Label) 列: 将 Rise_Fall 列复制为 Label 列。NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))# 创建 特征向量(Features) 列: VectorAssembler 将多个特征合并为一个特征向量。FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")# 创建 特征向量(Features) 列: 拟合数据 (可选, 如果在模型训练时使用 Pipeline, 则无需在此步骤拟合数据, 当然也就无法在此步骤预览数据)。AssembledSDF = MyAssembler.transform(NewSDF)print("[Message] Assembled Label and Features for RandomForestClassifier:")AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()# 将数据集划分为 "训练集" 和 "测试集"。(TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)# 创建 随机森林分类器(RandomForestClassifier)。RFC = RandomForestClassifier(labelCol="Label", featuresCol="Features", numTrees=10)# 创建 Pipeline (可选): 将特征向量转换和随机森林模型组合在一起# 注意: 如果要使用 Pipeline, 则在创建 特征向量(Features)列 的时候不需要拟合数据, 否则会报 "Output column Features already exists." 的错误。#MyPipeline = Pipeline(stages=[MyAssembler, RFC])# 训练模型 (普通模式)。Model = RFC.fit(TrainingData)# 训练模型 (Pipeline 模式)。#Model = MyPipeline.fit(TrainingData)# 在测试集上进行预测。Predictions = Model.transform(TestData)# 删除不需要的列 (以免列数太多, 结果显示拥挤, 不好观察)。Predictions = Predictions.drop("Open")Predictions = Predictions.drop("High")Predictions = Predictions.drop("Low")Predictions = Predictions.drop("Close")Predictions = Predictions.drop("Pre_Close")Predictions = Predictions.drop("Turnover_Rate")Predictions = Predictions.drop("Volume")Predictions = Predictions.drop("Weekday(Idx)")Predictions = Predictions.drop("Weekday(CN)")print("[Message] Prediction Results on The Test Data Set for RandomForestClassifier:")Predictions.show()# 使用 BinaryClassificationEvaluator 评估模型性能。MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")auc = MyEvaluator.evaluate(Predictions)print("Area Under ROC (AUC):", auc)

其它

在这个例子中, 我们使用 VectorAssembler 将多个特征列合并为一个特征向量, 并使用 RandomForestClassifier 构建随机森林模型。

最后, 我们使用 BinaryClassificationEvaluator 评估模型性能, 通常使用 ROC 曲线下面积 (AUC) 作为评估指标。

请根据你的实际数据和问题调整特征列, 标签列以及其他参数。在实际应用中, 你可能需要进行更多的特征工程, 调参和模型评估。

总结

以上就是关于 金融数据 PySpark-3.0.3随机森林(RandomForestClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/792887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cisco路由器配置IPv6 Manual隧道

Cisco路由器配置IPv6 Manual隧道 IPv6与IPv4共存的方式 IPv6与IPv4共存方式大致有三种&#xff1a; 双栈&#xff1a;要求网络中所有设备均同时支持IPv4和IPv6转换&#xff1a;转换这种方式将IPv6协议的报头转换成IPv4协议报头。隧道&#xff1a;假定两个IPv6节点要使用IPv6…

flink源码编译-job提交

1、启动standalone集群的taskmanager standalone集群中的taskmanager启动类为 TaskManagerRunner 2 打开master启动类 通过 ctrln快捷键&#xff0c;找到、并打开类&#xff1a; org.apache.flink.runtime.taskexecutor.TaskManagerRunner 3 修改运⾏配置 基本完全按照mas…

RabbitMQ在云原生环境中部署和应用实践

一、RabbitMQ和云原生技术的关系 RabbitMQ是一种开源的、实现了先进的消息队列协议&#xff08;AMQP&#xff09;的消息队列软件。而云原生技术就是为在公共云、私有云以及其他各种云环境提供应用的一种方法。RabbitMQ和云原生技术在分布式系统和微服务架构中都起到了关键作用…

NIUSHOP完美运营版商城 虚拟商品全功能商城 全能商城小程序 智慧商城系统 全品类百货商城

完美运营版商城/拼团/团购/秒杀/积分/砍价/实物商品/虚拟商品等全功能商城 干干净净 没有一丝多余收据 还没过手其他站 还没乱七八走的广告和后门 后台可以自由拖曳修改前端UI页面 还支持虚拟商品自动发货等功能 挺不错的一套源码 前端UNIAPP 后端PHP 一键部署版本 源码免费…

腾讯云4核8G服务器性能怎么样?能用来干什么?

腾讯云4核8G服务器多少钱&#xff1f;腾讯云4核8G轻量应用服务器12M带宽租用价格646元15个月&#xff0c;活动页面 txybk.com/go/txy 活动链接打开如下图所示&#xff1a; 腾讯云4核8G服务器优惠价格 这台4核8G服务器是轻量应用服务器&#xff0c;详细配置为&#xff1a;轻量4核…

ros小问题之rosdep update time out问题

在另外一篇ROS 2边学边练系列的文章里有写碰到这种问题的解决方法&#xff08;主要参考了其他博主的文章&#xff0c;只是针对ROS 2做了些修改调整&#xff09;&#xff0c;此处单拎出来方便查找。 在ROS 2中执行rosdep update时&#xff0c;报出如下错误&#xff1a; 其实原因…

elsint报错Delete `␍`eslintprettier/prettier

一&#xff0c;原因 这篇博客写得很清楚&#xff1a;解决VSCode Delete ␍eslint(prettier/prettier)错误_vscode 删除cr-CSDN博客 还有这篇文章&#xff0c;解决办法很详细&#xff1a;滑动验证页面 二&#xff0c;解决办法 根目录下新建.prettierrc.js文件 module.exports…

谷歌AI搜索革新:探索高级搜索服务背后的未来趋势

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

vulhub中Apache Solr RemoteStreaming 文件读取与SSRF漏洞复现

Apache Solr 是一个开源的搜索服务器。在Apache Solr未开启认证的情况下&#xff0c;攻击者可直接构造特定请求开启特定配置&#xff0c;并最终造成SSRF或任意文件读取。 访问http://your-ip:8983即可查看Apache Solr后台 1.访问http://your-ip:8983/solr/admin/cores?indexI…

一致性hash问题(负载均衡原理)

一致性哈希问题 简介 一致性Hash是一种特殊的Hash算法&#xff0c;由于其均衡性、持久性的映射特点&#xff0c;被广泛的应用于负载均衡领域&#xff0c;如nginx和memcached都采用了一致性Hash来作为集群负载均衡的方案。 本文将介绍一致性Hash的基本思路&#xff0c;并讨论其…

gpt国内怎么用?最新版本来了

claude 3 opus面世后&#xff0c;这几天已经有许多应用&#xff0c;而其精确以及从不偷懒&#xff08;截止到2024年3月11日还没有偷懒&#xff09;的个性&#xff0c;也使得我们可以用它来首次完成各种需要多轮对话的尝试。 今天我们想要进行的一项尝试就是—— 如何从一个不知…

Python Django ORM使用简单的增,删,改,查

Django ORM&#xff08;对象关系映射&#xff09;是Django框架中用于与数据库交互的一个核心组件。它提供了一种方便、直观的方式来定义、查询和操作数据库中的数据。 1. 定义模型&#xff08;Model&#xff09; 首先&#xff0c;你需要通过定义模型来告诉Django你的数据应该…

正则表达式完全指南:语法、用法及JavaScript实例

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

生成器、迭代器、可迭代对象

生成器、迭代器、可迭代对象 生成器 函数体中包含yield关键字的就是生成器 把生成 器传给 next(...) 函数时&#xff0c;生成器函数会向前&#xff0c;执行函数定义体中的 下一个 yield 语句&#xff0c;返回产出的值&#xff0c;并在函数定义体的当前位置暂停。等到再次遇到n…

深度学习与神经网络:从基础到前沿

深度学习与神经网络是人工智能领域中的重要分支&#xff0c;其应用范围涵盖图像识别、语音识别、自然语言处理等多个领域&#xff0c;对于推动人工智能技术的发展具有重要意义。本文将从深度学习的基础原理开始&#xff0c;逐步探讨神经网络的结构、训练方法&#xff0c;以及在…

认识 Redis 与 分布式

Redis 官网页面 Redis官网链接 Redis 的简介 Redis 是一个在内存中存储数据的中间件 一方面用于作为数据库&#xff0c;另一方面用于作为数据缓存&#xff0c;适用于分布式系统中 Redis 基于网络&#xff0c;进行进程间通信&#xff0c;把自己内存中的变量给别的进程&#xf…

Leetcode 300. 最长递增子序列

心路历程&#xff1a; 经典的子串/子序列的DP问题&#xff0c;这道题需要按照最后一个元素包含在子序列的角度去建模比较好做。 状态&#xff1a;以nums[i]为结尾的最长严格递增子序列的长度 动作候选集&#xff1a;每一个[0, i)之间满足比nums[i]小的元素 返回值&#xff1a…

Python超市商品管理系统

系统需要用户先登录&#xff0c;再进行操作&#xff0c;其中包含一下功能菜单 1、显示商品列表 2、增加商品信息 3、删除商品 4、设置商品折扣 5、修改商品价格信息 6、退出 a、使用列表嵌套字典的方式保存用户数据&#xff08;包含用户名、密码、姓名&#xff09;&#xff1…

C#/WPF Inno Setup打包程序

Inno Setup介绍 Inno Setup 是一个免费的 Windows 安装程序制作软件。第一次发表是在 1997 年&#xff0c;现在已经更新到Inno Setup 6了。Inno Setup是一个十分简单实用的打包小工具&#xff0c;可以按照我们自己的意愿设置功能&#xff0c;稳定性也很好。 官方网址&#xff1…

F - 创新型机器猫 高性能战斗机器人(遇到过的题,做个笔记)

我的代码&#xff1a; #include <iostream> #include <vector> using namespace std; int main() {string str;cin >> str;int dxy[][2] { {0,1},{1,0},{0,-1},{-1,0} }; //设置偏移量&#xff0c;按照右转顺序&#xff1a;北->东->南->西int now…