Spark2.x 入门:决策树分类器

一、方法简介 ​

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的剪枝。

示例代码

我们以iris数据集(iris)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。决策树可以用于分类和回归,接下来我们将在代码中分别进行介绍。

1. 导入需要的包:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector,Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

2. 读取数据,简要分析:

导入spark.implicits._,使其支持把一个RDD隐式转换为一个DataFrame。我们用case class定义一个schema:Iris,Iris就是我们需要的数据的结构;然后读取文本文件,第一个map把每行的数据用“,”隔开,比如在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类;我们这里把特征存储在Vector中,创建一个Iris模式的RDD,然后转化成dataframe;然后把刚刚得到的数据注册成一个表iris,注册成这个表之后,我们就可以通过sql语句进行数据查询;选出我们需要的数据后,我们可以把结果打印出来查看一下数据。

scala> import spark.implicits._
import spark.implicits._scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
defined class Irisscala> val data = spark.read.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()scala> data.createOrReplaceTempView("iris")scala> val df = spark.sql("select * from iris")
df: org.apache.spark.sql.DataFrame = [features: vector, label: string]scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
Iris-setosa:[5.1,3.5,1.4,0.2]
Iris-setosa:[4.9,3.0,1.4,0.2]
Iris-setosa:[4.7,3.2,1.3,0.2]
Iris-setosa:[4.6,3.1,1.5,0.2]
Iris-setosa:[5.0,3.6,1.4,0.2]
Iris-setosa:[5.4,3.9,1.7,0.4]
Iris-setosa:[4.6,3.4,1.4,0.3]
......

3. 进一步处理特征和标签,以及数据分组:

//分别获取标签列和特征列,进行索引,并进行了重命名。
scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df) 
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_6c3c138d61bfscala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_08c01d7fd953//这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的。
scala> val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_11ce3220e43a//接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。
scala> val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]

4. 构建决策树分类模型:

//导入所需要的包
scala> import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModelscala> import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassifierscala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator//训练决策树模型,这里我们可以通过setter的方法来设置决策树的参数,也可以用ParamMap来设置(具体的可以查看spark mllib的官网)。具体的可以设置的参数可以通过explainParams()来获取。
scala> val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtClassifier: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_7948c1724433//在pipeline中进行设置
scala> val pipelinedClassifier = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
pipelinedClassifier: org.apache.spark.ml.Pipeline = pipeline_b5a49e693b35//训练决策树模型
scala> val modelClassifier = pipelinedClassifier.fit(trainingData)
modelClassifier: org.apache.spark.ml.PipelineModel = pipeline_b5a49e693b35//进行预测
scala> val predictionsClassifier = modelClassifier.transform(testData)
predictionsClassifier: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]//查看部分预测的结果
scala> predictionsClassifier.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

5. 评估决策树分类模型:

scala> val evaluatorClassifier = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
evaluatorClassifier: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_8059f30a8634scala> val accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
accuracy: Double = 0.94scala> println("Test Error = " + (1.0 - accuracy))
Test Error = 0.06000000000000005scala> val treeModelClassifier = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
treeModelClassifier: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesscala> println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

从上述结果可以看到模型的预测准确率为 0.94 以及训练的决策树模型结构。

6. 构建决策树回归模型:

//导入所需要的包
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressor//训练决策树模型
scala> val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtRegressor: org.apache.spark.ml.regression.DecisionTreeRegressor = dtr_e98e9ef10e22//在pipeline中进行设置
scala> val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtRegressor, labelConverter))
pipelineRegressor: org.apache.spark.ml.Pipeline = pipeline_9f0fb530c801//训练决策树模型
scala> val modelRegressor = pipelineRegressor.fit(trainingData)
modelRegressor: org.apache.spark.ml.PipelineModel = pipeline_9f0fb530c801//进行预测
scala> val predictionsRegressor = modelRegressor.transform(testData)
predictionsRegressor: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 4 more fields]//查看部分预测结果
scala> predictionsRegressor.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

7. 评估决策树回归模型:

scala> val evaluatorRegressor = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse")
evaluatorRegressor: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_162861380a26scala> val rmse = evaluatorRegressor.evaluate(predictionsRegressor)
rmse: Double = 0.2449489742783178scala> println("Root Mean Squared Error (RMSE) on test data = " + rmse)
Root Mean Squared Error (RMSE) on test data = 0.2449489742783178scala> val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
treeModelRegressor: org.apache.spark.ml.regression.DecisionTreeRegressionModel = DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesscala> println("Learned regression tree model:\n" + treeModelRegressor.toDebugString)
Learned regression tree model:
DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/53684.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美术馆订票门票预约展览预约售票订票百度图表计算机毕业设计/springboot/javaWEB/J2EE/MYSQL数据库/vue前后分离小程序

1. 需求分析 首先&#xff0c;明确需求&#xff1a; 功能&#xff1a;门票预约、展览预约、售票、查询等系统&#xff1a;前后端分离的小程序技术栈&#xff1a;Spring Boot (后端)、Vue.js (前端)、MySQL (数据库) 2. 设计系统架构 设计系统的整体架构&#xff0c;包括前后…

web项目如何部署到服务器上呢?——麻烦的方法

只需关注web项目如何部署到服务器上&#xff0c;因为服务器运行时就可以访问web项目了。 一、麻烦的方法 1、首先启动服务器 &#xff08;1&#xff09;找到bin文件夹 &#xff08;2&#xff09;双击运行startup.bat文件 &#xff08;3&#xff09;运行之后的界面如下&#…

Dart 3.5更新对普通开发者有哪些影响?

哈喽&#xff0c;我是老刘 Flutter 3.24以及Dart 3.5不久前发布了。 突然觉得时间过得好快。六年前刚开始使用Flutter 1.0的场景还在眼前。 之前写了一篇文章盘点Flutter 3.24的新功能对普通开发者有哪些影响。Flutter 3.24 对普通开发者有哪些影响&#xff1f;https://mp.wei…

vivado 设置物理约束

设置物理约束 在本实验中&#xff0c;您将为CPU网表设计创建物理约束&#xff0c;观察中的操作 GUI转换为Tcl命令。使用Tcl命令&#xff0c;可以轻松编写复杂的操作脚本 用于在流动的不同阶段重复使用。 注意&#xff1a;如果您从实验1继续&#xff0c;并且您的设计已打开&…

面试—JVM

目录 JVM内存结构 类的生命周期 双亲委派机制 打破双亲委派机制 垃圾回收机制 判断垃圾回收算法 垃圾回收算法 G1垃圾回收器 JVM内存结构 程序计数器 记录要执行的字节码指令的地址&#xff0c;可以控制程序指令的进行&#xff0c;实现分支、跳转、异常等 在多线程执行…

Centos7.9 安装Elasticsearch 8.15.1(图文教程)

本章教程,主要记录在Centos7.9 安装Elasticsearch 8.15.1的整个安装过程。 一、下载安装包 下载地址: https://www.elastic.co/cn/downloads/past-releases/elasticsearch-8-15-1 你可以通过手动下载然后上传到服务器,也可以直接使用在线下载的方式。 wget https://artifacts…

Python世界:力扣题43大数相乘算法实践

Python世界&#xff1a;力扣题43大数相乘算法实践 任务背景思路分析方案1方案2方案3方案4无测试套主调测试套主调 本文小结 任务背景 问题来自力扣题目43&#xff1a;字符串相乘&#xff0c;大意如下&#xff1a; Given two non-negative integers num1 and num2 represented a…

【学术会议征稿】2024年智能驾驶与智慧交通国际学术会议(IDST 2024)

2024年智能驾驶与智慧交通国际学术会议(IDST 2024) 2024 International Conference on Intelligent Driving and Smart Transportation 智能驾驶和智慧交通利用新兴技术&#xff0c;使城市出行更加方便、更具成本效益且更安全。在此背景下&#xff0c;由中南大学主办的2024年…

LLMs技术 | 整合Ollama实现本地LLMs调用

前言 近两年AIGC发展的非常迅速&#xff0c;从刚开始的只有ChatGPT到现在的很百家争鸣。从开始的大参数模型&#xff0c;再到后来的小参数模型&#xff0c;从一开始单一的文本模型到现在的多模态模型等等。随着一起进步的不仅仅是模型的多样化&#xff0c;还有模型的使用方式。…

65、Python之函数高级:装饰器实战,通用日志记录功能的动态添加

引言 从系统开发的规范性来说&#xff0c;日志的记录是一个规范化的要求&#xff0c;但是&#xff0c;有些程序员会觉得麻烦&#xff0c;反而不愿意记录日志&#xff0c;还是太年轻了…… 其实&#xff0c;如果个人保护意识稍微强一些&#xff0c;一定会主动进行日志的记录的…

python_openCV_计算图片中的区域的黑色比例

希望对原始图片进行处理&#xff0c;然后计算图片上的黑色和白色的占比 上图&#xff0c; 原始图片 import numpy as np import cv2 import matplotlib.pyplot as pltdef cal_black(img_file):#功能&#xff1a; 计算图片中的区域的黑色比例#取图片中不同的位置进行计算&…

关于武汉芯景科技有限公司的IIC缓冲器芯片XJ4307开发指南(兼容LTC4307)

一、芯片引脚介绍 1.芯片引脚 2.引脚描述 二、系统结构图 三、功能描述 1.总线超时&#xff0c;自动断开连接 当 SDAOUT 或 SCLOUT 为低电平时&#xff0c;将启动内部定时器。定时器仅在相应输入变为高电平时重置。如果在 30ms &#xff08;典型值&#xff09; 内没有变为高…

国产芯片LT9211D:MIPI转LVDS转换器,分辨率高达3840x2160 30Hz,碾压其它同功能芯片

以下为LT9211D&#xff1a;MIPI TO LVDS的芯片简单介绍&#xff0c;供各位参考 Lontium LT9211D是一款高性能MIPI DSI/CSI-2到双端口LVDS转换器。LT9211D反序列化 输入MIPI视频数据&#xff0c;解码数据包&#xff0c;转换格式化的视频数据流到LVDS发射机输出AP与移动显示面板或…

基于STM32L431小熊派设计的智能花盆(微信小程序+腾讯云IOT)(223)

文章目录 一、前言1.1 项目介绍【1】项目背景【2】设计实现的功能【3】项目硬件模块组成1.2 设计思路【1】整体设计思路【2】ESP8266工作模式配置1.3 项目开发背景【1】选题的意义【2】可行性分析【3】参考文献1.4 开发工具的选择【1】设备端开发【2】上位机开发1.5 系统框架图…

ppt模板简约下载哪个?这些模板简约又大气

中秋节&#xff0c;作为中国传统节日中最具诗意的一个&#xff0c;月圆人团圆的美好寓意总是让人心生向往。 想在国际网站上宣传这一传统节日的独特魅力&#xff0c;却担心自己的PPT不够吸引人&#xff1f;别急&#xff0c;使用精美免费的ppt模板&#xff0c;可以让你的演示瞬…

创新性处理Java编程技术问题的策略

在Java编程领域&#xff0c;解决技术问题的方式不断进化。本文将探讨一些创新性和针对性的技术问题处理方法&#xff0c;帮助开发者高效地应对挑战&#xff0c;提高代码质量和开发效率。 1. 动态代理与反射机制的优化 Java的动态代理和反射机制为程序员提供了强大的功能&#…

【性能】DJANGO + REDIS 缓存提速

不加REDIS缓存时&#xff0c;每次访问都要读取数据库&#xff0c;当访问量非常大的时候&#xff0c; 就会有很多次的数据库查询&#xff0c;会造成访问速度变慢&#xff0c;服务器资源占用较多等问题。 当使用了缓存后&#xff0c;访问情况变成了如下&#xff1a;访问一个网址时…

用户登录和注销

在Linux系统中&#xff0c;用户登录和注销是一个常见的操作&#xff0c;涉及到用户账户管理和服务管理等多个方面。下面分别介绍用户在图形界面和命令行下的登录和注销流程。 图形界面下的登录和注销 登录 登录界面&#xff1a; 当用户启动计算机时&#xff0c;通常会看到一…

Python Flask_APScheduler定时任务的正确(最佳)使用

描述 APScheduler基于Quartz的一个Python定时任务框架&#xff0c;实现了Quartz的所有功能。最近使用Flask框架使用Flask_APScheduler来做定时任务&#xff0c;在使用过程当中也遇到很多问题&#xff0c;例如在定时任务调用的方法中需要用到flask的app.app_context()时&#…

无影云电脑:在最破的电脑上玩最顶配的游戏

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点 我对云电脑很感兴趣&#xff0c;这几天我深度体验了无影云电脑的个人版.&#xff0c;我给大家分享下。这款云电脑到底能不能替代你的笔记本?到底能不能改变人们使用电脑的方式? 先说结论&#xff1a; (1)从草根创…