决策树实验分析(分类和回归任务,剪枝,数据对决策树影响)

目录

1. 前言

2. 实验分析

        2.1 导入包

        2.2 决策树模型构建及树模型的可视化展示

        2.3 概率估计

        2.4 绘制决策边界

        2.5 决策树的正则化(剪枝)

        2.6 对数据敏感

        2.7 回归任务

        2.8 对比树的深度对结果的影响

        2.9 剪枝


1. 前言

        本文主要分析了决策树的分类和回归任务,对比一系列的剪枝的策略对结果的影响,数据对于决策树结果的影响。

        介绍使用graphaviz这个决策树可视化工具

2. 实验分析

        2.1 导入包

#1.导入包
import os
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')

        2.2 决策树模型构建及树模型的可视化展示

        下载安装包:https://graphviz.gitlab.io/_pages/Download/Download_windows.html

         选择一款安装,注意安装时要配置环境变量

        注意这里使用的是鸢尾花数据集,选择花瓣长和宽两个特征

#2.建立树模型
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:,2:] # petal legth and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X,y)
#3.树模型的可视化展示
from sklearn.tree import export_graphviz
export_graphviz(tree_clf,out_file='iris_tree.dot',feature_names=iris.feature_names[2:],class_names=iris.target_names,rounded=True,filled=True
)

        然后就可以使用graphviz包中的dot.命令工具将此文件转换为各种格式的如pdf,png,如 dot -Tpng iris_tree.png -o iris_tree.png

        可以去文件系统查看,也可以用python展示

from IPython.display import Image
Image(filename='iris_tree.png',width=400,height=400)

        分析:value表示每个节点所有样本中各个类别的样本数,用花瓣宽<=0.8和<=1.75 作为根节点划分,叶子节点表示分类结果,结果执行少数服从多数策略,gini指数随着分类进行在减小。

        2.3 概率估计

        估计类概率 输入数据为:花瓣长5厘米,宽1.5厘米的花。相应节点是深度为2的左节点,因此决策树因输出以下概率:

        iris-Setosa为0%(0/54)

        iris-Versicolor为90.7%(49/54)

        iris-Virginica为9.3%(5/54)

        

#4.概率估计
print(tree_clf.predict_proba([[5,1.5]]))
print(tree_clf.predict([[5,1.5]]))

        2.4 绘制决策边界

        

#5.绘制决策边界
from matplotlib.colors import ListedColormapdef plot_decision_boundary(clf,X,y,axes=[0,7.5,0,3],iris=True,legend=False,plot_training=True):#找两个特征 x1 x2x1s = np.linspace(axes[0],axes[1],100)x2s = np.linspace(axes[2],axes[3],100)#构建棋盘x1,x2 = np.meshgrid(x1s,x2s)#在棋盘中构建待测试数据X_new = np.c_[x1.ravel(),x2.ravel()]#将预测值算出来y_pred = clf.predict(X_new).reshape(x1.shape)#选择颜色custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])#绘制并填充不同的区域plt.contourf(x1,x2,y_pred,alpha=0.3,cmap=custom_cmap)if not iris:custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])plt.contourf(x1,x2,y_pred,alpha=0.8,cmap=custom_cmap2)#可以把训练数据展示出来if plot_training:plt.plot(X[:,0][y==0],X[:,1][y==0],'yo',label='Iris-Setosa')plt.plot(X[:,0][y==1],X[:,1][y==1],'bs',label='Iris-Versicolor')plt.plot(X[:,0][y==2],X[:,1][y==2],'g^',label='Iris-Virginica')if iris:plt.xlabel('Petal length',fontsize = 14)plt.ylabel('Petal width',fontsize = 14)else:plt.xlabel(r'$x_1$',fontsize=18)plt.ylabel(r'$x_2$',fontsize=18)if legend:plt.legend(loc='lower right',fontsize=14)plt.figure(figsize=(8,4))
plot_decision_boundary(tree_clf,X,y)
plt.plot([2.45,2.45],[0,3],'k-',linewidth=2)
plt.plot([2.45,7.5],[1.75,1.75],'k--',linewidth=2)
plt.plot([4.95,4.95],[0,1.75],'k:',linewidth=2)
plt.plot([4.85,4.85],[1.75,3],'k:',linewidth=2)
plt.text(1.40,1.0,'Depth=0',fontsize=15)
plt.text(3.2,1.80,'Depth=1',fontsize=13)
plt.text(4.05,0.5,'(Depth=2)',fontsize=11)
plt.title('Decision Tree decision boundareies')plt.show()

        可以看出三种不同颜色的代表分类结果,Depth=0可看作第一刀切分,Depth=1,2 看作第二刀,三刀,把数据集切分。

        2.5 决策树的正则化(剪枝)

        决策树的正则化

        DecisionTreeClassifier类还具有一些其他的参数类似地限制了决策树的形状

        min-samples_split(节点在分割之前必须具有的样本数)

        min-samples_leaf(叶子节点必须具有的最小样本数)

        max-leaf_nodes(叶子节点的最大数量)

        max_features(在每个节点处评估用于拆分的最大特征数)

        max_depth(树的最大深度)

#6.决策树正则化
from sklearn.datasets import make_moons
X,y = make_moons(n_samples=100,noise=0.25,random_state=53)
plt.plot(X[:,0],X[:,1],"b.")
tree_clf1 = DecisionTreeClassifier(random_state=42)
tree_clf2 = DecisionTreeClassifier(random_state=42,min_samples_leaf=4)
tree_clf1.fit(X,y)
tree_clf2.fit(X,y)
plt.figure(figsize=(12,4))
plt.subplot(121)
plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('no restriction')
plt.subplot(122)
plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('min_samples_leaf={}'.format(tree_clf2.min_samples_leaf))

        可以看出在没有加限制条件之前,分类器要考虑每个点,模型变得复杂,容易过拟合。其他的一些参数读者可以自行尝试。

        2.6 对数据敏感

        决策树对于数据是很敏感的

        

#6.对数据敏感
np.random.seed(6)
Xs = np.random.rand(100,2) - 0.5
ys = (Xs[:,0] > 0).astype(np.float32) * 2angle = np.pi /4
rotation_matrix = np.array([[np.cos(angle),-np.sin(angle)],[np.sin(angle),np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs,ys)
tree_clf_sr.fit(Xsr,ys)plt.figure(figsize=(11,4))
plt.subplot(121)
plot_decision_boundary(tree_clf_s,Xs,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')plt.subplot(122)
plot_decision_boundary(tree_clf_sr,Xsr,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')plt.show()

         这里是把数据又旋转了45度,然而决策边界并没有也旋转45度,却是变复杂了。可以看出,对于复杂的数据,决策树是很敏感的。

        2.7 回归任务

#7.回归任务 
np.random.seed(42)
m = 200
X = np.random.rand(m,1)
y = 4 * (X-0.5)**2
y = y + np.random.randn(m,1) /10
plt.plot(X,y,'b.')
from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X,y)
from sklearn.tree import export_graphviz
export_graphviz(tree_reg,out_file='regression_tree.dot',feature_names=['X1'],rounded=True,filled=True
)
from IPython.display import Image
Image(filename='regression_tree.png',width=400,height=400)

 

         回归任务,这里的衡量标准就变成了均方误差。

        2.8 对比树的深度对结果的影响

#8.对比树的深度对结果的影响
from sklearn.tree import DecisionTreeRegressor
tree_reg1 = DecisionTreeRegressor(random_state=42,max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42,max_depth=3)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)def plot_regression_predictions(tree_reg,X,y,axes=[0,1,-0.2,1],ylabel='$y$'):x1 = np.linspace(axes[0],axes[1],500).reshape(-1,1)y_pred = tree_reg.predict(x1)plt.axis(axes)plt.xlabel('$X_1$',fontsize =18)if ylabel:plt.ylabel(ylabel,fontsize = 18,rotation=0)plt.plot(X,y,'b.')plt.plot(x1,y_pred,'r.-',linewidth=2,label=r'$\hat{y}$')plt.figure(figsize=(11,4))
plt.subplot(121)plot_regression_predictions(tree_reg1,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):plt.plot([split,split],[-0.2,1],style,linewidth = 2)
plt.text(0.21,0.65,'Depth=0',fontsize= 15)
plt.text(0.01,0.2,'Depth=1',fontsize= 13)
plt.text(0.65,0.8,'Depth=0',fontsize= 13)
plt.legend(loc='upper center',fontsize = 18)
plt.title('max_depth=2',fontsize=14)
plt.subplot(122)
plot_regression_predictions(tree_reg2,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):plt.plot([split,split],[-0.2,1],style,linewidth = 2)
for split in (0.0458,0.1298,0.2873,0.9040):plt.plot([split,split],[-0.2,1],linewidth = 1)
plt.text(0.3,0.5,'Depth=2',fontsize= 13)
plt.title('max_depth=3',fontsize=14)plt.show()

        不同的树的深度,对于结果产生极大的影响

        2.9 剪枝

        

#9.加一些限制
tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42,min_samples_leaf=10)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)x1 = np.linspace(0,1,500).reshape(-1,1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)plt.figure(figsize=(11,4))plt.subplot(121)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred1,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('No restrctions',fontsize =14)plt.subplot(122)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred2,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('min_samples_leaf={}'.format(tree_reg2.min_samples_leaf),fontsize =14)plt.show()

        一目了然。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/724765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java开发避坑指南,手把手教你写Java项目文档

前言 作为一个有丰富经验的微服务系统架构师&#xff0c;经常有人问我&#xff0c;“应该选择RabbitMQ还是Kafka&#xff1f;” 基于某些原因&#xff0c; 许多开发者会把这两种技术当做等价的来看待。的确&#xff0c;在一些案例场景下选择RabbitMQ还是Kafka没什么差别&…

Android布局优化之include、merge、ViewStub的使用,7年老Android一次坑爹的面试经历

前言 开发10年&#xff0c;老码农&#xff0c;曾经是爱奇艺架构 点击领取完整开源项目《安卓学习笔记总结最新移动架构视频大厂安卓面试真题项目实战源码讲义》 师&#xff0c;东芝集团高级工程师&#xff0c;三星架构师。5年之内频繁被辞退。内心拔凉拔凉的&#xff0c;在这五…

Android开发新手入门教程,2024Android开发社招面试解答之性能优化

前言 本想今年辞掉工作大干一场&#xff0c;没想到碰到疫情&#xff0c;家里蹲了3个月…&#xff0c;还好字节能给一次机会。前阵子字节跳动的提前批开始了&#xff0c;看宣传是说有海量HC&#xff0c;机会多多&#xff0c;本着涨涨面经的心理&#xff0c;然后就投递了一下杭州…

论文阅读:Dataset Quantization

摘要 最先进的深度神经网络使用大量&#xff08;百万甚至数十亿&#xff09;数据进行训练。昂贵的计算和内存成本使得在有限的硬件资源上训练它们变得困难&#xff0c;特别是对于最近流行的大型语言模型 (LLM) 和计算机视觉模型 (CV)。因此最近流行的数据集蒸馏方法得到发展&a…

代码随想录算法训练营第11天

20. 有效的括号 方法&#xff1a; 1. 如果 !st.empty() return false2.如果st.top() ! s[i] return false3. 如果 st.empty() return false注意&#xff1a; 以下这种写法 不满足 题目要求的第二点&#xff0c;不能以正确的顺序闭合 if(s[i] st.top()){return true;s…

基于springboot+vue的图书管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

RabbitMQ如何保证消息不丢

如何保证Queue消息能不丢呢&#xff1f; RabbitMQ在接收到消息后&#xff0c;默认并不会立即进行持久化&#xff0c;而是先把消息暂存在内存中&#xff0c;这时候如果MQ挂了&#xff0c;那么消息就会丢失。所以需要通过持久化机制来保证消息可以被持久化下来。 队列和交换机的…

11. C语言标准函数库

C语言制定了一组使用方式通用的函数&#xff0c;称为C语言标准函数库&#xff0c;用于实现编程常用功能&#xff0c;标准函数库由编译器系统提供&#xff0c;并按功能分类存储在不同源代码文件中&#xff0c;调用标准库内函数时需要首先使用 #include 连接对应的源代码文件。 【…

Tensorflow2.0笔记 - 常见激活函数sigmoid,tanh和relu

本笔记主要记录常见的三个激活函数sigmoid&#xff0c;tanh和relu&#xff0c;关于激活函数详细的描述&#xff0c;可以参考这里&#xff1a; 详解激活函数&#xff08;Sigmoid/Tanh/ReLU/Leaky ReLu等&#xff09; - 知乎 import tensorflow as tf import numpy as nptf.__ve…

『python爬虫』ip代理池使用 协采云 账密模式(保姆级图文)

目录 实现效果实现思路代码示例总结 欢迎关注 『python爬虫』 专栏&#xff0c;持续更新中 欢迎关注 『python爬虫』 专栏&#xff0c;持续更新中 实现效果 在官网原版demo基础上小改了一下,修正了接口错误(把2023改成2024就可以了),原版demo只能测试单个ip,我这里批量测试所有…

为什么MySQL中多表联查效率低,连接查询实现的原理是什么?

MySQL中多表联查效率低的原因主要涉及到以下几个方面&#xff1a; 数据量大: 当多个表通过连接查询时&#xff0c;如果这些表的数据量很大&#xff0c;那么查询就需要处理更多的数据&#xff0c;这自然会降低查询效率。 连接操作复杂性: 连接查询需要对参与连接的每个表中的数…

从零学习Linux操作系统 第三十二部分 ansible中剧本的应用

一、什么是playbook及playbook的组成 1.Playbook的功能 playbook 是由一个或多个play组成的列表 Playboot 文件使用YAML来写的 play就是一个个模块用列表的方式体现出来 playbook的语法是用YAML的预防进行书写的 2.YAML 简介 是一种表达资料序列的格式&#xff0c;类似XM…

【从零开始学GIS再到精通GIS】专题图制作-地图渲染-地图整饰

本篇主要介绍如何在gis中进行专题图制作-地图渲染-地图整饰&#xff1b;示例数据下载链接该网站更新了很多有关地理的数据。 1 数据准备&#xff1a;点、线、面等矢量数据、栅格数据的准备等&#xff08;下一更会详细介绍数据处理等方面的内容&#xff09;&#xff1b; 2 加载…

记录一则 线上域名证书更新及cdn证书更新

本篇为阿里云免费证书更新记录。 登录阿里云账号 搜索数字证书管理服务管理控制台 点击创建证书 输入你的域名 填写相关信息&#xff08;注&#xff1a;域名验证方式选择文件验证&#xff09; 等待审核通过&#xff08;时间不久&#xff0c;一般为半小时内&#xff09; …

Vue2高级篇

Vue高级 Vue生命周期 生命周期又称为生命周期回调函数、生命周期函数、生命周期钩子, 是Vue在运行过程中的关键时刻帮我们调用的一些指函数, 生命周期函数名字不可修改, 其中的this指向的是vm或组件实例对象. 常用的生命周期钩子: mounted: 发送ajax请求、启动定时器、绑定…

【Web安全】SQL各类注入与绕过

【Web安全】SQL各类注入与绕过 【Web安全靶场】sqli-labs-master 1-20 BASIC-Injection 【Web安全靶场】sqli-labs-master 21-37 Advanced-Injection 【Web安全靶场】sqli-labs-master 38-53 Stacked-Injections 【Web安全靶场】sqli-labs-master 54-65 Challenges 与62关二…

python并发编程:IO模型

一 IO模型 二 network IO 再说一下IO发生时涉及的对象和步骤。对于一个network IO \(这里我们以read举例\)&#xff0c;它会涉及到两个系统对象&#xff0c;一个是调用这个IO的process \(or thread\)&#xff0c;另一个就是系统内核\(kernel\)。当一个read操作发生时&#xff…

无代理方式实现VMware的迁移?详细解析

在当今数字化时代&#xff0c;数据的安全性和可用性对于企业至关重要。尤其是在VMware转变订阅策略后&#xff0c;原本永久订阅的产品转变为以年付费订阅的形式&#xff0c;导致客户不得不支付更多的费用&#xff0c;大幅增加了成本。同时&#xff0c;客户也对VMware未来发展前…

k8s-kubeapps图形化管理 21

结合harbor仓库 由于kubeapps不读取hosts解析&#xff0c;因此需要添加本地仓库域名解析&#xff08;dns解析&#xff09; 更改context为全局模式 添加repo仓库 复制ca证书 添加成功 图形化部署 更新部署应用版本 再次进行部署 上传nginx 每隔十分钟会自动进行刷新 在本地仓库…

人人都写过的6个bug

大家好&#xff0c;我是知微。 程序员写bug几乎是家常便饭&#xff0c;也是我们每个人成长过程中难以避免的一部分。 为了缓解这份“尴尬”&#xff0c;今天想和大家分享一些曾经都会遇到过的bug&#xff0c;让我们一起来看看这些“经典之作”。 1、数组越界 #include <…