Python:以鸢尾花数据为例,介绍决策树算法

文章参考来源:

https://www.cnblogs.com/yanqiang/p/11600569.html

https://www.cnblogs.com/baby-lily/p/10646226.html

https://blog.csdn.net/liuziyuan333183/article/details/107399633


决策树算法

决策树算法主要有ID3, C4.5, CART这三种。

ID3算法从树的根节点开始,总是选择信息增益最大的特征,对此特征施加判断条件建立子节点,递归进行,直到信息增益很小或者没有特征时结束。
信息增益:特征 A 对于某一训练集 D 的信息增益 g(D,A)g(D,A) 定义为集合 D 的熵 H(D)H(D) 与特征 A 在给定条件下 D 的熵 H(D/A)H(D/A) 之差。
熵(Entropy)是表示随机变量不确定性的度量。

g(D,A)=H(D)−H(D∣A)g(D,A)=H(D)−H(D∣A)

C4.5是使用了信息增益比来选择特征,这被看成是 ID3 算法的一种改进。

但这两种算法都会导致过拟合的问题,需要进行剪枝。

决策树的修剪,其实就是通过优化损失函数来去掉不必要的一些分类特征,降低模型的整体复杂度。

CART 算法在生成树的过程中,分类树采用了基尼指数(Gini Index)最小化原则,而回归树选择了平方损失函数最小化原则。
CART 算法也包含了树的修剪,CART 算法从完全生长的决策树底端剪去一些子树,使得模型更加简单。

具体代码实现上,scikit-learn 提供的 DecisionTreeClassifier 类可以做多分类任务。

1. DecisionTreeClassifier API 的使用

和其他分类器一样,DecisionTreeClassifier 需要两个数组作为输入:
X: 训练数据,稀疏或稠密矩阵,大小为 [n_samples, n_features]
Y: 类别标签,整型数组,大小为 [n_samples]

from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
#clf = tree.DecisionTreeClassifier()
clf = tree.DecisionTreeClassifier(criterion="entropy"   #不纯度的计算方法。"entropy"表示使用信息熵;"gini"表示使用基尼系数,splitter="best"	#控制决策树中的随机选项。“best”表示在分枝时会优先选择重要的特征进行分枝;“random”表示分枝时会更加随机,常用来防止过拟合,max_depth=10	#限制树的最大深度,min_samples_split=5	#节点必须包含训练样本的个数,min_samples_leaf=1	#叶子最少包含样本的个数,min_weight_fraction_leaf=0.0,max_features=None	#限制分枝的特征个数,random_state=None	#输入任意数字会让模型稳定下来。加上random_state这个参数后,score就不会总是变化,max_leaf_nodes=None,min_impurity_decrease=0.0	#限制信息增益的大小,信息增益小于设定值分枝不会发生,min_impurity_split=None	#结点必须含有最小信息增益再划分,class_weight=None	#设置样本的权重,当正反样本差别较大时,又需要对少的样本进行精确估计时使用,搭配min_weight_fraction_leaf来剪枝,presort=False)clf = clf.fit(X, Y)

DecisionTreeClassifier参数如下:

函数的参数含义如下所示:

  • criterion:gini或者entropy,前者是基尼系数,后者是信息熵。
  • splitter: best or random 前者是在所有特征中找最好的切分点 后者是在部分特征中,默认的”best”适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐”random” 。
  • max_features:None(所有),log2,sqrt,N  特征小于50的时候一般使用所有的
  • max_depth:  int or None, optional (default=None) 设置决策随机森林中的决策树的最大深度,深度越大,越容易过拟合,推荐树的深度为:5-20之间。
  • min_samples_split:设置结点的最小样本数量,当样本数量可能小于此值时,结点将不会在划分。
  • min_samples_leaf: 这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。
  • min_weight_fraction_leaf: 这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝默认是0,就是不考虑权重问题。
  • max_leaf_nodes: 通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。
  • class_weight: 指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多导致训练的决策树过于偏向这些类别。这里可以自己指定各个样本的权重,如果使用“balanced”,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
  • min_impurity_split: 这个值限制了决策树的增长,如果某节点的不纯度(基尼系数,信息增益,均方差,绝对差)小于这个阈值则该节点不再生成子节点。即为叶子节点 。

模型拟合后,可以用于预测样本的分类

clf.predict([[2., 2.]])
array([1])

此外,可以预测样本属于每个分类(叶节点)的概率,(输出结果:0%,100%)

clf.predict_proba([[2., 2.]])
array([[0., 1.]])

DecisionTreeClassifier() 模型方法中也包含非常多的参数值。例如:

  • criterion = gini/entropy 可以用来选择用基尼指数或者熵来做损失函数。
  • splitter = best/random 用来确定每个节点的分裂策略。支持 “最佳” 或者“随机”。
  • max_depth = int 用来控制决策树的最大深度,防止模型出现过拟合。
  • min_samples_leaf = int 用来设置叶节点上的最少样本数量,用于对树进行修剪。

2. 由鸢尾花数据集构建决策树

鸢尾花数据集:
数据集名称的准确名称为 Iris Data Set,总共包含 150 行数据。每一行数据由 4 个特征值及一个目标值组成。
其中 4 个特征值分别为:萼片长度、萼片宽度、花瓣长度、花瓣宽度。
而目标值为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。

DecisionTreeClassifier 既可以用于二分类,也可以用于多分类。
对于鸢尾花数据集,可以如下构建决策树:

from sklearn.datasets import load_iris
from sklearn import tree
X, y = load_iris(return_X_y=True)
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

2.1 简单绘制决策树

拟合完后,可以用plot_tree()方法绘制出决策树来,如下图所示

tree.plot_tree(clf)

2.2 Graphviz形式输出决策树

也可以用 Graphviz 格式(export_graphviz)输出。
如果使用的是 conda 包管理器,可以用如下方式安装:

conda install python-graphviz
pip install graphviz

以下展示了用 Graphviz 输出上述从鸢尾花数据集得到的决策树,结果保存为 iris.pdf

import graphviz
iris = load_iris()
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("iris")

export_graphviz 支持使用参数进行视觉优化,包括根据分类或者回归值绘制彩色的结点,也可以使用显式的变量或者类名。
Jupyter Notebook 还可以自动内联呈现这些绘图。

dot_data = tree.export_graphviz(clf, out_file=None,feature_names=iris.feature_names,class_names=iris.target_names,filled=True, rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph

2.3 文本形式输出决策树

此外,决策树也可以使用 export_text 方法以文本形式输出,这个方法不需要安装其他包,也更加的简洁。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree.export import export_text
iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

3. 绘制决策平面

绘制由特征对构成的决策平面,决策边界由训练集得到的简单阈值组成。

print(__doc__)import numpy as np
import matplotlib.pyplot as pltfrom sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02# Load data
iris = load_iris()for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],[1, 2], [1, 3], [2, 3]]):# We only take the two corresponding featuresX = iris.data[:, pair]y = iris.target# Trainclf = DecisionTreeClassifier().fit(X, y)# Plot the decision boundaryplt.subplot(2, 3, pairidx + 1)x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),np.arange(y_min, y_max, plot_step))plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)plt.xlabel(iris.feature_names[pair[0]])plt.ylabel(iris.feature_names[pair[1]])# Plot the training pointsfor i, color in zip(range(n_classes), plot_colors):idx = np.where(y == i)plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],cmap=plt.cm.RdYlBu, edgecolor='black', s=15)plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")plt.figure()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf, filled=True)
plt.show()
Automatically created module for IPython interactive environment

4. 数据集划分及结果评估

数据集获取

from sklearn import datasets # 导入方法类iris = datasets.load_iris() # 加载 iris 数据集
iris_feature = iris.data # 特征数据
iris_target = iris.target # 分类数据

数据集划分

from sklearn.model_selection import train_test_splitfeature_train, feature_test, target_train, target_test = train_test_split(iris_feature, iris_target, test_size=0.33, random_state=42)

模型训练及预测

from sklearn.tree import DecisionTreeClassifierdt_model = DecisionTreeClassifier() # 所有参数均置为默认状态
dt_model.fit(feature_train,target_train) # 使用训练集训练模型
predict_results = dt_model.predict(feature_test) # 使用模型对测试集进行预测

结果评估

scores = dt_model.score(feature_test, target_test)
scores
1.0

参考文档

scikit-learn 1.10.1 DecisionTreeClassifier API User Guide
Example: a decision tree on the iris dataset

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/436427.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【转】CT球管小知识--热容量

Heat Unit 简称HU&#xff0c;为DR、CT等医疗设备中球管的热容量单位。如&#xff0c;Varian球管RAD14的热容量为300kHU。设备工作时&#xff0c;X线管两极之间要承受极高的电压&#xff0c;并通过一定量电流&#xff0c;高速电子束撞击阳极靶面&#xff0c;将产生大量热能。X线…

一键锁屏_ios快捷指令一键登录校园网(桂航为例,哆点认证)

&#xff08;鄙人水平很有限&#xff0c;所学的专业也和此无关&#xff0c;文中有的东西可能会说错&#xff0c;但我尽量用简单的方式说。请多指教&#xff09;现在很多高校现在晚上断电断网&#xff0c;最烦恼的事莫过于第二天早上起床眯着眼摸出手机输入账号密码登录校园网的…

【转】一篇文章完整了解CT成像技术(完整版)

1&#xff0e;CT的发明与发展 1.1 CT的发明 CT是计算机断层摄影术&#xff08;Computed Tomography&#xff0c;CT&#xff09;的简称&#xff0c;是继1895年伦琴发现X线以来&#xff0c;医学影像学发展史上的一次革命。 CT的发明可以追溯到1917年。当时&#xff0c;奥地利数…

Pandas数据可视化工具:图表工具-Seaborn

内容来源&#xff1a;https://www.jiqizhixin.com/articles/2019-01-30-15 简介 在本文中&#xff0c;我们将研究Seaborn&#xff0c;它是Python中另一个非常有用的数据可视化库。Seaborn库构建在Matplotlib之上&#xff0c;并提供许多高级数据可视化功能。 尽管Seaborn库可以…

图解WinCE6.0下的内核驱动和用户驱动

图解WinCE6.0下的内核驱动和用户驱动 在《WinCE驱动程序的分类》中曾提到&#xff0c;WinCE6.0的流驱动既可以加载到内核态也可以加载到用户态。下面通过一组图片简单说明一下这两种驱动的关系。 首先编写一个流驱动WCEDrv&#xff0c;代码如下。 代码 #include <windows.h&…

人体轮廓_女性人体油画轮廓柔和生动,优美动人,你喜欢吗?

人体油画是艺术和时代的产物&#xff0c;也是艺术结晶的重要体现&#xff0c;在文艺复兴以前&#xff0c;人体艺术大都以雕塑形式来表现&#xff0c;在此之后&#xff0c;人们都以意大利威尼斯绘画为代表&#xff0c;艺术家们开始以色彩塑造人体绘画艺术。随着时代进步和人们对…

机器学习分类模型中的评价指标介绍:准确率、精确率、召回率、ROC曲线

文章来源&#xff1a;https://blog.csdn.net/wf592523813/article/details/95202448 1 二分类评价指标 准确率&#xff0c;精确率&#xff0c;召回率&#xff0c;F1-Score&#xff0c; AUC, ROC, P-R曲线 1.1 准确率&#xff08;Accuracy&#xff09; 评价分类问题的性能指标…

【转】AI-900认证考试攻略

架构师的信仰系列文章&#xff0c;主要介绍我对系统架构的理解&#xff0c;从我的视角描述各种软件应用系统的架构设计思想和实现思路。 从程序员开始&#xff0c;到架构师一路走来&#xff0c;经历过太多的系统和应用。做过手机游戏&#xff0c;写过编程工具&#xff1b;做过…

300plc与组态王mpi通讯_S7-300与S7-200之间的MPI通信

通信说明S7-200PLC与S7-300PLC之间采用MPI通讯方式时&#xff0c;S7-200PLC中不需要编写任何与通讯有关的程序&#xff0c;只需要将要交换的数据整理到一个连续的V 存储区当中即可&#xff0c;而S7-300PLC中需要在组织块OB1(或是定时中断组织块OB35)当中调用系统功能X_GET(SFC6…

ORA-01114: 将块写入文件 35 时出现 IO 错误

参考文档&#xff1a; https://blog.csdn.net/z_x_1000/article/details/17263077 https://www.cnblogs.com/login2012/p/5775602.html https://www.iteye.com/blog/yangyangcom-2200174 一、问题背景 最开始发现应用服务打不开&#xff0c;于是登录服务器发现Oracle数据关…

【转】CT影像文件格式DICOM详解

CT影像文件格式DICOM详解 DICOM简介 DICOM&#xff08;Digital Imaging and Communications in Medicine&#xff09;即医学数字成像和通信&#xff0c;是医学图像和相关信息的国际标准&#xff08;ISO 12052&#xff09;。DICOM被广泛应用于放射医疗&#xff0c;心血管成像以…

fatal error lnk1120: 1 个无法解析的外部命令_3月1日七牛云存储割韭菜的应对方法...

前言早上起来看邮件&#xff0c;看到一封被七牛云割韭菜的公告&#xff1a;内心冰冰凉&#xff0c;不过大家都要吃饭的嘛总不能一直免费下去。所以来研究一下对于我们这种穷人应该如何应对。一、七牛CDN加速流程主要流程分析1、用户通过浏览器访问我的网站(腾讯云服务器)&#…

【转】DCM(DICOM)医学影像文件格式详解

1、 什么是DICOM&#xff1f; DICOM(DigitalImaging andCommunications inMedicine)是指医疗数字影像传输协定&#xff0c;是用于医学影像处理、储存、打印、传输的一组通用的标准协定。它包含了文件格式的定义以及网络通信协议。DICOM是以TCP/IP为基础的应用协定&#xff0c;并…

SM4对称加密算法及Java实现

文章来源&#xff1a;https://www.jianshu.com/p/5ec8464b0a1b 一、简介 与DES和AES算法类似&#xff0c;SM4算法是一种分组密码算法。 其分组长度为128bit&#xff0c;密钥长度也为128bit。 加密算法与密钥扩展算法均采用32轮非线性迭代结构&#xff0c;以字&#xff08;32位…

【转】DICOM网络协议(一)概述

转自&#xff1a;https://www.jianshu.com/p/8a0f0fe6a738 作者&#xff1a;我住的城市没有福合埕 DICOM (Digital Imaging and Communications in Medicine)即医学数字成像和通信&#xff0c;DICOM网络是基于TCP/IP的网络协议。通过DICOM将影像设备和存储管理设备连接起来。…

Windows进程系列(2) -- Svchost进程

在基于NT内核的Windows操作系统家族中&#xff0c;Svchost.exe是一个非常重要的进程。很多病毒、木马驻留系统与这个进程密切相关&#xff0c;因此深入了解该进程是非常有必要的。本文主要介绍Svchost进程的功能&#xff0c;以及与该进程相关的知识。      Svchost进程概述…

【转】DICOM入门(一)——语法

转自&#xff1a;https://www.jianshu.com/p/5db8933a25a4 作者&#xff1a;我住的城市没有福合埕 1.什么是DICOM DICOM(Digital Imaging and Communications in Medicine)即医学数字成像和传输协议&#xff0c;是用医疗影像&#xff08;CT 核磁共振 DR CR 超声等&#xff0…

1000并发 MySQL数据库_再送一波干货,测试2000线程并发下同时查询1000万条数据库表及索引优化...

继上篇文章《绝对干货&#xff0c;教你4分钟插入1000万条数据到mysql数据库表&#xff0c;快快进来》发布后在博客园首页展示得到了挺多的阅读量&#xff0c;我这篇文章就是对上篇文章的千万级数据库表在高并发访问下如何进行测试访问这篇文章的知识点如下:1.如何自写几十行代码…

【转】VTK修炼之道1_初识VTK

1.VTK是什么&#xff1f; Visualization ToolKit 3D计算机图形学、图象处理及可视化工具包 VTK使用C、面向对象技术开发&#xff1b;基于OpenGL&#xff0c;封装了OpenGL中的功能&#xff0c;屏蔽细节、便于交互、易于使用提供多种语言接口C&#xff0b;&#xff0b; 、Java 、…

HTTPS原理和对中间件攻击的预防

一、https/tls原理 HTTPS访问的三个阶段 第一阶段 认证站点 客户端向站点发起HTTPS请求&#xff0c;站点返回数字证书。客户端通过数字证书验证所访问的站点是真实的目标站点。 第二阶段 协商密钥 客户端与站点服务器协商此次会话的对称加密密钥&#xff0c;用于下一阶段的加…