Spark MLlib快速入门(1)逻辑回归、Kmeans、决策树、Pipeline、交叉验证

Spark MLlib快速入门(1)逻辑回归、Kmeans、决策树案例

除了scikit-learn外,在spark中也提供了机器学习库,即Spark MLlib。

在Spark MLlib机器学习库提供两套算法实现的API:基于RDD API和基于DataFrame API。今天,主要介绍下DataFrame API的使用,不涉及算法的原理。

主要提供的算法如下:

  • 分类

    • 逻辑回归、贝叶斯支持向量机
  • 聚类

    • K-均值
  • 推荐

    • 交替最小二乘法
  • 回归

    • 线性回归
    • 决策树、随机森林

1 Spark MLlib中逻辑回归在鸢尾花数据集上的应用

鸢尾花数据集,总共150条数据,分为三种类别的鸢尾花。

鸢尾花数据集属于分类算法,构建分类模型,此处使用逻辑回归分类算法构建分类模型,进行预测。

全部基于DataFrame API算法库和特征工程函数使用。

使用的spark版本为2.3。

1.1 读取数据

package com.yyds.tags.ml.classificationimport org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.{Normalizer, StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.storage.StorageLevelobject IrisClassification {def main(args: Array[String]): Unit = {// 构建SparkSession实例对象val spark: SparkSession = SparkSession.builder().appName(this.getClass.getSimpleName.stripSuffix("$")).master("local[4]").config("spark.sql.shuffle.partitions",4).getOrCreate()import spark.implicits._// TODO step1 -> 读取数据val isrsSchema: StructType = new StructType().add("sepal_length",DoubleType,nullable = true).add("sepal_width",DoubleType,nullable = true).add("petal_length",DoubleType,nullable = true).add("petal_width",DoubleType,nullable = true).add("category",StringType, nullable = true)val rawIrisDF: DataFrame =  spark.read.option("sep",",")// 当首行不是列名称时候,需要自动设置schema.option("header","false").option("inferSchema","false").schema(isrsSchema).csv("datas/iris/iris.data")rawIrisDF.printSchema()rawIrisDF.show(10,truncate = false)}}
root|-- sepal_length: double (nullable = true)|-- sepal_width: double (nullable = true)|-- petal_length: double (nullable = true)|-- petal_width: double (nullable = true)|-- category: string (nullable = true)+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|category   |
+------------+-----------+------------+-----------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|
|5.4         |3.9        |1.7         |0.4        |Iris-setosa|
|4.6         |3.4        |1.4         |0.3        |Iris-setosa|
|5.0         |3.4        |1.5         |0.2        |Iris-setosa|
|4.4         |2.9        |1.4         |0.2        |Iris-setosa|
|4.9         |3.1        |1.5         |0.1        |Iris-setosa|
+------------+-----------+------------+-----------+-----------+

1.2 特征工程

    // TODO step2 -> 特征工程/*1、类别转换数值类型类别特征索引化 -> label2、组合特征值features: Vector*/// 1、类别特征转换 StringIndexerval indexerModel: StringIndexerModel = new StringIndexer().setInputCol("category").setOutputCol("label").fit(rawIrisDF)val df1: DataFrame = indexerModel.transform(rawIrisDF)// 2、组合特征值 VectorAssemblerval assembler: VectorAssembler = new VectorAssembler()// 设置特征列名称.setInputCols(rawIrisDF.columns.dropRight(1)).setOutputCol("raw_features")val rawFeaturesDF: DataFrame = assembler.transform(df1)// 3、特征值正则化,使用L2正则val normalizer: Normalizer = new Normalizer().setInputCol("raw_features").setOutputCol("features").setP(2.0)val featuresDF: DataFrame = normalizer.transform(rawFeaturesDF)// 将数据集缓存,LR算法属于迭代算法,使用多次featuresDF.persist(StorageLevel.MEMORY_AND_DISK).count()featuresDF.printSchema()featuresDF.show(10, truncate = false)
root|-- sepal_length: double (nullable = true)|-- sepal_width: double (nullable = true)|-- petal_length: double (nullable = true)|-- petal_width: double (nullable = true)|-- category: string (nullable = true)|-- label: double (nullable = true)|-- raw_features: vector (nullable = true)|-- features: vector (nullable = true)

在这里插入图片描述

1.3 训练模型

    // TODO step3 -> 模型训练val lr: LogisticRegression = new LogisticRegression()// 设置列名称.setLabelCol("label").setFeaturesCol("features").setPredictionCol("prediction")// 设置迭代次数.setMaxIter(10).setRegParam(0.3) // 正则化参数.setElasticNetParam(0.8) // 弹性网络参数:L1正则和L2正则联合使用val lrModel: LogisticRegressionModel = lr.fit(featuresDF)

1.4 模型预测

    // TODO step4 -> 使用模型预测val predictionDF: DataFrame = lrModel.transform(featuresDF)predictionDF// 获取真实标签类别和预测标签类别.select("label", "prediction").show(10)

在这里插入图片描述

1.5 模型评估

 // TODO step5 -> 模型评估:准确度 = 预测正确的样本数 / 所有的样本数import  org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorval evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("accuracy")# accuracy = 0.9466666666666667println(s"accuracy = ${evaluator.evaluate(predictionDF)}")

1.6 模型的保存与加载

   // TODO step6 ->  模型调优,此处省略// TODO step7 ->  模型保存与加载val modelPath = s"datas/models/lrModel-${System.currentTimeMillis()}"// 保存模型lrModel.save(modelPath)// 加载模型val loadLrModel = LogisticRegressionModel.load(modelPath)// 模型预测loadLrModel.transform(Seq(Vectors.dense(Array(5.1,3.5,1.4,0.2))).map(x => Tuple1.apply(x)).toDF("features")).show(10, truncate = false)// 应用结束,关闭资源spark.stop()

在这里插入图片描述

2 Spark MLlib中KMeans在鸢尾花数据集上的应用

2.1 读取数据集

iris_kmeans.txt数据如下

1 1:5.1 2:3.5 3:1.4 4:0.2
1 1:4.9 2:3.0 3:1.4 4:0.2
1 1:4.7 2:3.2 3:1.3 4:0.2
1 1:4.6 2:3.1 3:1.5 4:0.2
1 1:5.0 2:3.6 3:1.4 4:0.2
1 1:5.4 2:3.9 3:1.7 4:0.4
1 1:4.6 2:3.4 3:1.4 4:0.3
1 1:5.0 2:3.4 3:1.5 4:0.2
1 1:4.4 2:2.9 3:1.4 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:5.4 2:3.7 3:1.5 4:0.2
1 1:4.8 2:3.4 3:1.6 4:0.2
1 1:4.8 2:3.0 3:1.4 4:0.1
1 1:4.3 2:3.0 3:1.1 4:0.1
1 1:5.8 2:4.0 3:1.2 4:0.2
1 1:5.7 2:4.4 3:1.5 4:0.4
1 1:5.4 2:3.9 3:1.3 4:0.4
1 1:5.1 2:3.5 3:1.4 4:0.3
1 1:5.7 2:3.8 3:1.7 4:0.3
1 1:5.1 2:3.8 3:1.5 4:0.3
1 1:5.4 2:3.4 3:1.7 4:0.2
1 1:5.1 2:3.7 3:1.5 4:0.4
1 1:4.6 2:3.6 3:1.0 4:0.2
1 1:5.1 2:3.3 3:1.7 4:0.5
1 1:4.8 2:3.4 3:1.9 4:0.2
1 1:5.0 2:3.0 3:1.6 4:0.2
1 1:5.0 2:3.4 3:1.6 4:0.4
1 1:5.2 2:3.5 3:1.5 4:0.2
1 1:5.2 2:3.4 3:1.4 4:0.2
1 1:4.7 2:3.2 3:1.6 4:0.2
1 1:4.8 2:3.1 3:1.6 4:0.2
1 1:5.4 2:3.4 3:1.5 4:0.4
1 1:5.2 2:4.1 3:1.5 4:0.1
1 1:5.5 2:4.2 3:1.4 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:5.0 2:3.2 3:1.2 4:0.2
1 1:5.5 2:3.5 3:1.3 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:4.4 2:3.0 3:1.3 4:0.2
1 1:5.1 2:3.4 3:1.5 4:0.2
1 1:5.0 2:3.5 3:1.3 4:0.3
1 1:4.5 2:2.3 3:1.3 4:0.3
1 1:4.4 2:3.2 3:1.3 4:0.2
1 1:5.0 2:3.5 3:1.6 4:0.6
1 1:5.1 2:3.8 3:1.9 4:0.4
1 1:4.8 2:3.0 3:1.4 4:0.3
1 1:5.1 2:3.8 3:1.6 4:0.2
1 1:4.6 2:3.2 3:1.4 4:0.2
1 1:5.3 2:3.7 3:1.5 4:0.2
1 1:5.0 2:3.3 3:1.4 4:0.2
2 1:7.0 2:3.2 3:4.7 4:1.4
2 1:6.4 2:3.2 3:4.5 4:1.5
2 1:6.9 2:3.1 3:4.9 4:1.5
2 1:5.5 2:2.3 3:4.0 4:1.3
2 1:6.5 2:2.8 3:4.6 4:1.5
2 1:5.7 2:2.8 3:4.5 4:1.3
2 1:6.3 2:3.3 3:4.7 4:1.6
2 1:4.9 2:2.4 3:3.3 4:1.0
2 1:6.6 2:2.9 3:4.6 4:1.3
2 1:5.2 2:2.7 3:3.9 4:1.4
2 1:5.0 2:2.0 3:3.5 4:1.0
2 1:5.9 2:3.0 3:4.2 4:1.5
2 1:6.0 2:2.2 3:4.0 4:1.0
2 1:6.1 2:2.9 3:4.7 4:1.4
2 1:5.6 2:2.9 3:3.6 4:1.3
2 1:6.7 2:3.1 3:4.4 4:1.4
2 1:5.6 2:3.0 3:4.5 4:1.5
2 1:5.8 2:2.7 3:4.1 4:1.0
2 1:6.2 2:2.2 3:4.5 4:1.5
2 1:5.6 2:2.5 3:3.9 4:1.1
2 1:5.9 2:3.2 3:4.8 4:1.8
2 1:6.1 2:2.8 3:4.0 4:1.3
2 1:6.3 2:2.5 3:4.9 4:1.5
2 1:6.1 2:2.8 3:4.7 4:1.2
2 1:6.4 2:2.9 3:4.3 4:1.3
2 1:6.6 2:3.0 3:4.4 4:1.4
2 1:6.8 2:2.8 3:4.8 4:1.4
2 1:6.7 2:3.0 3:5.0 4:1.7
2 1:6.0 2:2.9 3:4.5 4:1.5
2 1:5.7 2:2.6 3:3.5 4:1.0
2 1:5.5 2:2.4 3:3.8 4:1.1
2 1:5.5 2:2.4 3:3.7 4:1.0
2 1:5.8 2:2.7 3:3.9 4:1.2
2 1:6.0 2:2.7 3:5.1 4:1.6
2 1:5.4 2:3.0 3:4.5 4:1.5
2 1:6.0 2:3.4 3:4.5 4:1.6
2 1:6.7 2:3.1 3:4.7 4:1.5
2 1:6.3 2:2.3 3:4.4 4:1.3
2 1:5.6 2:3.0 3:4.1 4:1.3
2 1:5.5 2:2.5 3:4.0 4:1.3
2 1:5.5 2:2.6 3:4.4 4:1.2
2 1:6.1 2:3.0 3:4.6 4:1.4
2 1:5.8 2:2.6 3:4.0 4:1.2
2 1:5.0 2:2.3 3:3.3 4:1.0
2 1:5.6 2:2.7 3:4.2 4:1.3
2 1:5.7 2:3.0 3:4.2 4:1.2
2 1:5.7 2:2.9 3:4.2 4:1.3
2 1:6.2 2:2.9 3:4.3 4:1.3
2 1:5.1 2:2.5 3:3.0 4:1.1
2 1:5.7 2:2.8 3:4.1 4:1.3
3 1:6.3 2:3.3 3:6.0 4:2.5
3 1:5.8 2:2.7 3:5.1 4:1.9
3 1:7.1 2:3.0 3:5.9 4:2.1
3 1:6.3 2:2.9 3:5.6 4:1.8
3 1:6.5 2:3.0 3:5.8 4:2.2
3 1:7.6 2:3.0 3:6.6 4:2.1
3 1:4.9 2:2.5 3:4.5 4:1.7
3 1:7.3 2:2.9 3:6.3 4:1.8
3 1:6.7 2:2.5 3:5.8 4:1.8
3 1:7.2 2:3.6 3:6.1 4:2.5
3 1:6.5 2:3.2 3:5.1 4:2.0
3 1:6.4 2:2.7 3:5.3 4:1.9
3 1:6.8 2:3.0 3:5.5 4:2.1
3 1:5.7 2:2.5 3:5.0 4:2.0
3 1:5.8 2:2.8 3:5.1 4:2.4
3 1:6.4 2:3.2 3:5.3 4:2.3
3 1:6.5 2:3.0 3:5.5 4:1.8
3 1:7.7 2:3.8 3:6.7 4:2.2
3 1:7.7 2:2.6 3:6.9 4:2.3
3 1:6.0 2:2.2 3:5.0 4:1.5
3 1:6.9 2:3.2 3:5.7 4:2.3
3 1:5.6 2:2.8 3:4.9 4:2.0
3 1:7.7 2:2.8 3:6.7 4:2.0
3 1:6.3 2:2.7 3:4.9 4:1.8
3 1:6.7 2:3.3 3:5.7 4:2.1
3 1:7.2 2:3.2 3:6.0 4:1.8
3 1:6.2 2:2.8 3:4.8 4:1.8
3 1:6.1 2:3.0 3:4.9 4:1.8
3 1:6.4 2:2.8 3:5.6 4:2.1
3 1:7.2 2:3.0 3:5.8 4:1.6
3 1:7.4 2:2.8 3:6.1 4:1.9
3 1:7.9 2:3.8 3:6.4 4:2.0
3 1:6.4 2:2.8 3:5.6 4:2.2
3 1:6.3 2:2.8 3:5.1 4:1.5
3 1:6.1 2:2.6 3:5.6 4:1.4
3 1:7.7 2:3.0 3:6.1 4:2.3
3 1:6.3 2:3.4 3:5.6 4:2.4
3 1:6.4 2:3.1 3:5.5 4:1.8
3 1:6.0 2:3.0 3:4.8 4:1.8
3 1:6.9 2:3.1 3:5.4 4:2.1
3 1:6.7 2:3.1 3:5.6 4:2.4
3 1:6.9 2:3.1 3:5.1 4:2.3
3 1:5.8 2:2.7 3:5.1 4:1.9
3 1:6.8 2:3.2 3:5.9 4:2.3
3 1:6.7 2:3.3 3:5.7 4:2.5
3 1:6.7 2:3.0 3:5.2 4:2.3
3 1:6.3 2:2.5 3:5.0 4:1.9
3 1:6.5 2:3.0 3:5.2 4:2.0
3 1:6.2 2:3.4 3:5.4 4:2.3
3 1:5.9 2:3.0 3:5.1 4:1.8
package com.yyds.tags.ml.clusteringimport org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.sql.{DataFrame, SparkSession}/*** 使用KMeans算法对鸢尾花数据进行聚类操作*/
object IrisClusterTest {def main(args: Array[String]): Unit = {val spark = SparkSession.builder().appName(this.getClass.getSimpleName.stripSuffix("$")).master("local[2]").config("spark.sql.shuffle.partitions", "2").getOrCreate()import org.apache.spark.sql.functions._import spark.implicits._// 1. 读取鸢尾花数据集val irisDF: DataFrame = spark.read.format("libsvm").load("datas/iris/iris_kmeans.txt")irisDF.printSchema()irisDF.show(10, truncate = false)}}
root|-- label: double (nullable = true)|-- features: vector (nullable = true)+-----+-------------------------------+
|label|features                       |
+-----+-------------------------------+
|1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|
|1.0  |(4,[0,1,2,3],[4.6,3.1,1.5,0.2])|
|1.0  |(4,[0,1,2,3],[5.0,3.6,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[5.4,3.9,1.7,0.4])|
|1.0  |(4,[0,1,2,3],[4.6,3.4,1.4,0.3])|
|1.0  |(4,[0,1,2,3],[5.0,3.4,1.5,0.2])|
|1.0  |(4,[0,1,2,3],[4.4,2.9,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.9,3.1,1.5,0.1])|
+-----+-------------------------------+
only showing top 10 rows

2.2 模型训练

// 2. 构建KMeans算法val kmeans: KMeans = new KMeans()// 设置输入特征列名称和输出列的名名称.setFeaturesCol("features").setPredictionCol("prediction")// 设置K值为3.setK(3)// 设置最大的迭代次数.setMaxIter(20)// 3. 应用数据集训练模型, 获取转换器val kMeansModel: KMeansModel = kmeans.fit(irisDF)// 获取聚类的簇中心点kMeansModel.clusterCenters.foreach(println)
[5.88360655737705,2.7409836065573776,4.388524590163936,1.4344262295081969]
[5.005999999999999,3.4180000000000006,1.4640000000000002,0.2439999999999999]
[6.853846153846153,3.0769230769230766,5.715384615384615,2.053846153846153]

2.3 模型评估和预测

   // 4. 模型评估val wssse: Double = kMeansModel.computeCost(irisDF)println(s"WSSSE = ${wssse}")// 5. 使用模型预测val predictionDF: DataFrame = kMeansModel.transform(irisDF)predictionDF.show(10, truncate = false)// 应用结束,关闭资源spark.stop()
+-----+-------------------------------+----------+
|label|features                       |prediction|
+-----+-------------------------------+----------+
|1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.6,3.1,1.5,0.2])|1         |
|1.0  |(4,[0,1,2,3],[5.0,3.6,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[5.4,3.9,1.7,0.4])|1         |
|1.0  |(4,[0,1,2,3],[4.6,3.4,1.4,0.3])|1         |
|1.0  |(4,[0,1,2,3],[5.0,3.4,1.5,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.4,2.9,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.9,3.1,1.5,0.1])|1         |
+-----+-------------------------------+----------+

3 Spark MLlib中决策树入门案例

决策树学习采用的是 自顶向下 的递归方法 ,其基本思想是以信息熵为度量构造一颗熵值下降最快的树,到叶子节点处,熵值为0。其具有可读性、分类速度快的优点,是一种有监督学习。

最早提及决策树思想的是Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法,以及Breiman等人在1984年提出的CART算法。

决策树算法是机器学习算法中非常重要的算法之一,既可以分类又可以回归,其中还可以构建出集成学习算法。

由于决策树分类模型 DecisionTreeClassificationModel 属于概率分类模型ProbabilisticClassificationModel ,所以构建模型时要求数据集中标签label必须从0开始

在这里插入图片描述

上述数据集中特征:退款和婚姻状态,都是类别类型特征,需要将其转换为数值特征,数值从0开始计算。

针对 特征:退款 来说,将其转换为【0,1】两个值,不能是【1,2】数值。

3.1 读取数据

package com.yyds.tags.ml.classificationimport org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}object DecisionTreeTest {def main(args: Array[String]): Unit = {val spark = SparkSession.builder().appName(this.getClass.getSimpleName.stripSuffix("$")).master("local[4]").getOrCreate()import org.apache.spark.sql.functions._import spark.implicits._// 1. 加载数据val dataframe: DataFrame = spark.read.format("libsvm").load("datas/iris/sample_libsvm_data.txt")dataframe.printSchema()dataframe.show(10, truncate = false)spark.stop()}}

在这里插入图片描述

3.2 特征工程

    // 2. 特征工程:特征提取、特征转换及特征选择// a. 将标签值label,转换为索引,从0开始,到 K-1val labelIndexer: StringIndexerModel = new StringIndexer().setInputCol("label").setOutputCol("index_label").fit(dataframe)val df1: DataFrame = labelIndexer.transform(dataframe)// b. 对类别特征数据进行特殊处理, 当每列的值的个数小于设置K,那么此列数据被当做类别特征,自动进行索引转换val featureIndexer: VectorIndexerModel = new VectorIndexer().setInputCol("features").setOutputCol("index_features").setMaxCategories(4).fit(df1)val df2: DataFrame = featureIndexer.transform(df1)df2.printSchema()df2.show(10, truncate = false)
root|-- label: double (nullable = true)|-- features: vector (nullable = true)|-- index_label: double (nullable = true)|-- index_features: vector (nullable = true)

3.3 训练模型

    // 3. 划分数据集:训练数据和测试数据val Array(trainingDF, testingDF) = df2.randomSplit(Array(0.8, 0.2))// 4. 使用决策树算法构建分类模型val dtc: DecisionTreeClassifier = new DecisionTreeClassifier().setLabelCol("index_label").setFeaturesCol("index_features")// 设置决策树算法相关超参数.setMaxDepth(5).setMaxBins(32)       // 此值必须大于等于类别特征类别个数.setImpurity("gini")  // 也可以是香农熵:entropyval dtcModel: DecisionTreeClassificationModel = dtc.fit(trainingDF)println(dtcModel.toDebugString)
DecisionTreeClassificationModel (uid=dtc_338073100075) of depth 1 with 3 nodesIf (feature 406 <= 72.0)Predict: 1.0Else (feature 406 > 72.0)Predict: 0.0

3.4 模型评估

    // 5. 模型评估,计算准确度val predictionDF: DataFrame = dtcModel.transform(testingDF)predictionDF.printSchema()predictionDF.select($"label", $"index_label", $"probability", $"prediction").show(10, truncate = false)val evaluator = new MulticlassClassificationEvaluator().setLabelCol("index_label").setPredictionCol("prediction").setMetricName("accuracy")val accuracy: Double = evaluator.evaluate(predictionDF)println(s"Accuracy = $accuracy")
Accuracy = 0.8823529411764706

4、ML Pipeline

管道 Pipeline 概念:将多个Transformer转换器Estimators模型学习器按照 依赖顺序 组工作流WorkFlow形式,方面数据集的特征转换和模型训练及预测。

将上面的决策树分类代码,改为使用 Pipeline 构建模型与预测。

在这里插入图片描述

package com.yyds.tags.ml.classificationimport org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}object PipelineTest {def main(args: Array[String]): Unit = {val spark = SparkSession.builder().appName(this.getClass.getSimpleName.stripSuffix("$")).master("local[4]").getOrCreate()import org.apache.spark.sql.functions._import spark.implicits._// 1. 加载数据val dataframe: DataFrame = spark.read.format("libsvm").load("datas/iris/sample_libsvm_data.txt")//dataframe.printSchema()//dataframe.show(10, truncate = false)// 划分数据集:训练集和测试集val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))// 2. 构建管道Pipeline// a. 将标签值label,转换为索引,从0开始,到 K-1val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("index_label").fit(dataframe)// b. 对类别特征数据进行特殊处理, 当每列的值的个数小于设置K,那么此列数据被当做类别特征,自动进行索引转换val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("index_features").setMaxCategories(4).fit(dataframe)// c. 使用决策树算法构建分类模型val dtc: DecisionTreeClassifier = new DecisionTreeClassifier().setLabelCol("index_label").setFeaturesCol("index_features")// 设置决策树算法相关超参数.setMaxDepth(5).setMaxBins(32) // 此值必须大于等于类别特征类别个数.setImpurity("gini")// d. 创建Pipeline,设置Stage(转换器和模型学习器)val pipeline: Pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtc))// 3. 训练模型val pipelineModel: PipelineModel = pipeline.fit(trainingDF)// 获取决策树分类模型val dtcModel: DecisionTreeClassificationModel =pipelineModel.stages(2).asInstanceOf[DecisionTreeClassificationModel]println(dtcModel.toDebugString)// 4. 模型评估val predictionDF: DataFrame = pipelineModel.transform(testingDF)predictionDF.printSchema()predictionDF.select($"label", $"index_label", $"probability", $"prediction").show(20, truncate = false)val evaluator = new MulticlassClassificationEvaluator().setLabelCol("index_label").setPredictionCol("prediction").setMetricName("accuracy")val accuracy: Double = evaluator.evaluate(predictionDF)println(s"Accuracy = $accuracy")// 应用结束,关闭资源spark.stop()}}

5、模型调优

使用决策树算法训练模型时,可以调整相关超参数,结合训练验证(Train-Validation Split)交叉验证(Cross-Validation),获取最佳模型。

5.1 训练验证

将数据集划分为两个部分 ,静态的划分,一个用于训练模型,一个用于验证模型

通过评估指标,获取最佳模型,超参数设置比较好。

在这里插入图片描述

// 无论使用何种验证方式通过调整算法超参数来进行模型调优,需要使用工具类ParamGridBuilder 将 超参数封装到Map集合中
import org.apache.spark.ml.tuning.ParamGridBuilderval paramGrid: Array[ParamMap] = new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)).build()// 使用训练验证 TrainValidationSplit 方式获取最佳模型
val trainValidationSplit = new TrainValidationSplit().setEstimator(lr)                      // 也可以是pipeline.setEvaluator(new RegressionEvaluator) // 评估器.setEstimatorParamMaps(paramGrid)      // 超参数// 80% of the data will be used for training and the remaining 20% for validation..setTrainRatio(0.8)

训练验证的使用

package com.yyds.tags.ml.classificationimport org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.{DataFrame, SparkSession}object HPO {/*** 调整算法超参数,找出最优模型* @param dataframe 数据集* @return*/def trainBestModel(dataframe: DataFrame): PipelineModel = {// a. 特征向量化val assembler: VectorAssembler = new VectorAssembler().setInputCols(Array("color", "product")).setOutputCol("raw_features")// b. 类别特征进行索引val indexer: VectorIndexer = new VectorIndexer().setInputCol("raw_features").setOutputCol("features").setMaxCategories(30)// .fit(dataframe)// c. 构建决策树分类器val dtc: DecisionTreeClassifier = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label").setPredictionCol("prediction")// d. 构建Pipeline管道流实例对象val pipeline: Pipeline = new Pipeline().setStages(Array(assembler, indexer, dtc))// e. 构建参数网格,设置超参数的值val paramGrid: Array[ParamMap] = new ParamGridBuilder().addGrid(dtc.maxDepth, Array(5, 10)).addGrid(dtc.impurity, Array("gini", "entropy")).addGrid(dtc.maxBins, Array(32, 64)).build()// f. 多分类评估器val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction")// 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy.setMetricName("accuracy")// g. 训练验证val trainValidationSplit = new TrainValidationSplit().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid)// 80% of the data will be used for training and the remaining 20% for validation..setTrainRatio(0.8)// h. 训练模型val model: TrainValidationSplitModel =trainValidationSplit.fit(dataframe)// i. 获取最佳模型返回model.bestModel.asInstanceOf[PipelineModel]}}

5.2 交叉验证(K折)

将数据集划分为两个部分 ,动态的划分为K个部分数据集,其中1份数据集为验证数据集,其他K-1分数据为训练数据集,调整参数训练模型。

在这里插入图片描述

/*** 采用K-Fold交叉验证方式,调整超参数获取最佳PipelineModel模型* @param dataframe 数据集* @return*/def trainBestPipelineModel(dataframe: DataFrame): PipelineModel = {// a. 特征向量化val assembler: VectorAssembler = new VectorAssembler().setInputCols(Array("color", "product")).setOutputCol("raw_features")// b. 类别特征进行索引val indexer: VectorIndexer = new VectorIndexer().setInputCol("raw_features").setOutputCol("features").setMaxCategories(30)// .fit(dataframe)// c. 构建决策树分类器val dtc: DecisionTreeClassifier = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label").setPredictionCol("prediction")// d. 构建Pipeline管道流实例对象val pipeline: Pipeline = new Pipeline().setStages(Array(assembler, indexer, dtc))// e. 构建参数网格,设置超参数的值val paramGrid: Array[ParamMap] = new ParamGridBuilder().addGrid(dtc.maxDepth, Array(5, 10)).addGrid(dtc.impurity, Array("gini", "entropy")).addGrid(dtc.maxBins, Array(32, 64)).build()// f. 多分类评估器val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction")// 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy.setMetricName("accuracy")// g. 构建交叉验证实例对象val crossValidator: CrossValidator = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(3)// h. 训练模式val crossValidatorModel: CrossValidatorModel =  crossValidator.fit(dataframe)// i. 获取最佳模型val pipelineModel: PipelineModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel]// j. 返回模型pipelineModel}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker k8s

Docker docker到底与一般的虚拟机有什么不同呢&#xff1f; 我们知道一般的linux系统即GNU/Linux系统包括两个部分&#xff0c;linux系统内核GNU提供的大量自由软件&#xff0c;而centos就是众多GNU/Linux系统中的一个。 虚拟机会在宿主机上虚拟出一个完整的操作系统与宿主机完…

在 3ds Max 中对链模型进行摆放姿势处理

推荐&#xff1a; NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 建模和“摆姿势”3D链可能看起来是一项繁琐的工作&#xff0c;但实际上可以通过使用阵列工具并将链中的链接视为骨骼来轻松完成。在本教程中&#xff0c;我将向您展示如何对链条进行建模&#xff0c;并通过…

oled拼接屏在柳州的户外广告中有哪些应用展现?

柳州oled拼接屏是一种高端的显示屏&#xff0c;它采用了OLED技术&#xff0c;具有高亮度、高对比度、高色彩饱和度、高刷新率等优点&#xff0c;能够呈现出更加真实、清晰、细腻的图像效果。 同时&#xff0c;柳州oled拼接屏还具有拼接功能&#xff0c;可以将多个屏幕拼接在一…

vue element select下拉框回显展示数字

vue element select下拉框回显展示数字 问题截图&#xff1a; 下拉框显示数字可以从数据类型来分析错误&#xff0c;接收的数据类型是字符串&#xff0c;但是value是数字类型 <el-form-item prop"classifyLabelId" :label"$t(item.classifyLabelId)"…

GUI-Menu菜单实例

运行代码&#xff1a; //GUI-Menu菜单实例 #include"std_lib_facilities.h" #include"GUI/Simple_window.h" #include"GUI/GUI.h" #include"GUI/Graph.h" #include"GUI/Point.h"struct Lines_window :Window {Lines_window…

常见的网络攻击

​ 1.僵木蠕毒 攻击业内习惯把僵尸网络、木马、蠕虫、感染型病毒合称为僵木蠕毒。从攻击路径来看&#xff0c;蠕虫和感染型病毒通过自身的能力进行主动传播&#xff0c;木马则需要渠道来进行投放&#xff0c;而由后门木马&#xff08;部分具备蠕虫或感染传播能力&#xff09;构…

Mybatis架构简介

文章目录 1.整体架构图2. 基础支撑层2.1 类型转换模块2.2 日志模块2.3 反射工具模块2.4 Binding 模块2.5 数据源模块2.6缓存模块2.7 解析器模块2.8 事务管理模块3. 核心处理层3.1 配置解析3.2 SQL 解析与 scripting 模块3.3 SQL 执行3.4 插件4. 接口层1.整体架构图 MyBatis 分…

智能优化算法——灰狼优化算法(PythonMatlab实现)

目录 1 灰狼优化算法基本思想 2 灰狼捕食猎物过程 2.1 社会等级分层 2.2 包围猎物 2.3 狩猎 2.4 攻击猎物 2.5 寻找猎物 3 实现步骤及程序框图 3.1 步骤 3.2 程序框图 4 Python代码实现 5 Matlab实现 1 灰狼优化算法基本思想 灰狼优化算法是一种群智能优化算法&#xff0c;它的…

java版工程项目管理系统 Spring Cloud+Spring Boot+Mybatis+Vue+ElementUI+前后端分离 功能清单

Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下&#xff1a; 首页 工作台&#xff1a;待办工作、消息通知、预警信息&#xff0c;点击可进入相应的列表 项目进度图表&#xff1a;选择&#xff08;总体或单个&#xff09;项目显示…

cocos creator Richtext点击事件

组件如图 添加ts自定义脚本&#xff0c;定义onClickFunc点击方法&#xff1a; import { Component, _decorator} from "cc";const { ccclass } _decorator; ccclass(RichTextComponent) export class RichTextComponent extends Component{public onClickFunc(even…

C++入门学习(2)

思维导图&#xff1a; 一&#xff0c;缺省参数 如何理解缺省参数呢&#xff1f;简单来说&#xff0c;缺省参数就是一个会找备胎的参数&#xff01;为什么这样子说呢&#xff1f;来看一个缺省参数就知道了&#xff01;代码如下&#xff1a; #include<iostream> using std…

【个人笔记】linux命令之ls

目录 Linux中一切皆文件ls命令常用参数常用命令lscpu lspci Linux中一切皆文件 理解参考&#xff1a;为什么说&#xff1a;Linux中一切皆文件&#xff1f; ls命令 ls&#xff08;英文全拼&#xff1a; list directory contents&#xff09;命令用于显示指定工作目录下之内容…

实现大文件传输的几种方法,并实现不同电脑间大文件传输

随着网络技术的快速发展&#xff0c;大文件的传输需求越来越多&#xff0c;如何在不同的电脑之间实现大文件的快速传输&#xff0c;是一个挑战&#xff0c;下面介绍几种常用的方法可以解决这个问题。 1、利用局域网传输&#xff1a;把两台电脑接入同一个网络环境&#xff0c;通…

每天一道大厂SQL题【Day27】脉脉真题实战(三)连续两天活跃用户

文章目录 每天一道大厂SQL题【Day27】脉脉真题实战(三)连续两天活跃用户每日语录第26题 中级题: 活跃时长的均值1. 需求列表思路分析 答案获取加技术群讨论附表文末SQL小技巧 后记 每天一道大厂SQL题【Day27】脉脉真题实战(三)连续两天活跃用户 大家好&#xff0c;我是Maynor。…

AtCoder Beginner Contest 310-D - Peaceful Teams(DFS)

Problem Statement There are N sports players. Among them, there are M incompatible pairs. The i-th incompatible pair (1≤i≤M) is the Ai​-th and Bi​-th players. You will divide the players into T teams. Every player must belong to exactly one team, an…

Web3.0:重新定义数字资产的所有权和交易方式

随着区块链技术的发展和应用&#xff0c;数字资产的概念已经逐渐深入人心。数字资产不仅包括加密货币&#xff0c;还包括数字艺术品、虚拟土地、游戏道具等各种形式的数字物品。然而&#xff0c;在传统的互联网环境下&#xff0c;数字资产的所有权和交易方式往往受到限制和约束…

SQL中为何时常见到 where 1=1?

你是否曾在 SELECT 查询中看到过 WHERE 11 条件。我在许多不同的查询和许多 SQL 引擎中都有看过。这条件显然意味着 WHERE TRUE&#xff0c;所以它只是返回与没有 WHERE 子句时相同的查询结果。此外&#xff0c;由于查询优化器几乎肯定会删除它&#xff0c;因此对查询执行时间没…

猿创征文|一文带你了解前端开发者工具

前端开发者工具目录 一、前言二、前端开发者工具——编译器&#xff08;含插件&#xff09;1、VS Code2、VS Code 必备插件3、WebStorm 三、前端开发者工具——UI 框架工具1、Element2、Vant 四、前端开发者工具——API 调试工具1、ApiPost 五、写在最后&#xff08;总结&#…

微服务sleuth+zipkin---链路追踪+nacos配置中心

目录 1.分布式链路追踪 1.1.链路追踪Sleuth介绍 1.2.如何完成sleuth 1.3.zipkin服务器 2.配置中心 2.1.常见配置中心组件 2.2.微服务集群共享一个配置文件 2.2.1实时刷新--配置中心数据 2.2.2.手动写一个实时刷新的配置类 ----刷新配置文件 2.3.多个微服务公用一个配…

【最新教程】树莓派安装系统及VNC远程桌面连接

大家好&#xff0c;今天就不给大家介绍PYTHONL ,今天我作为一个刚入坑树莓派的小白&#xff0c;整理了一下自己安装树莓派的整个过程&#xff0c;分享给大家。 目录 树莓派 准备工作&#xff1a; 树莓派远程ssh失败access denied 原因&#xff1a; 树莓派系统安装 1、下载…