机器学习 | 决策树 Decision Tree

—— 分而治之,逐个击破

                把特征空间划分区域

                每个区域拟合简单模型

                分级分类决策


1、核心思想和原理

  • 举例:
    • 特征选择、节点分类、阈值确定


2、信息嫡

       

        熵本身代表不确定性,是不确定性的一种度量。

        熵越大,不确定性越高,信息量越高。

       

        为什么用log?—— 两种解释,可能性的增长呈指数型;log可以将乘法变为加减法。

        

        联合熵 的物理意义:观察一个多变量系统获得的信息量。

        条件熵 的物理意义:知道其中一个变量的信息后,另一个变量的信息量。

                给定了训练样本 X ,分类标签中包含的信息量是什么。

         

        

        信息增益(互信息)

                代表了一个特征能够为一个系统带来多少信息。

        

        

        熵的分类

        

        

        熵的本质:特殊的衡量分布的混乱程度与分散程度的距离

        

        

        二分类信息熵:

二分类信息熵

import numpy as np
import matplotlib.pyplot as plt
def entropy(p):return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
plot_x = np.linspace(0.001, 0.999, 100)
plt.plot(plot_x, entropy(plot_x))
plt.show()

        

 

         决策树的本质

        


 3、决策树分类代码实现

 

数据集

from sklearn.datasets import load_irisiris = load_iris()
x = iris.data[:, 1:3]
y = iris.target
plt.scatter(x[:,0], x[:,1], c = y)
plt.show()

 

3.1、sklearn中的决策树

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')
clf.fit(x, y)

DecisionTreeClassifier

DecisionTreeClassifier(criterion='entropy', max_depth=2)

决策边界绘制的代码: 

def decision_boundary_plot(X, y, clf):axis_x1_min, axis_x1_max = X[:,0].min() - 1, X[:,0].max() + 1axis_x2_min, axis_x2_max = X[:,1].min() - 1, X[:,1].max() + 1x1, x2 = np.meshgrid( np.arange(axis_x1_min,axis_x1_max, 0.01) , np.arange(axis_x2_min,axis_x2_max, 0.01))z = clf.predict(np.c_[x1.ravel(),x2.ravel()])z = z.reshape(x1.shape)from matplotlib.colors import ListedColormapcustom_cmap = ListedColormap(['#F5B9EF','#BBFFBB','#F9F9CB'])plt.contourf(x1, x2, z, cmap=custom_cmap)plt.scatter(X[:,0], X[:,1], c=y)plt.show()
decision_boundary_plot(x, y, clf)

from sklearn.tree import plot_tree
plot_tree(clf)
[Text(0.4, 0.8333333333333334, 'X[1] <= 2.45\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]'),Text(0.2, 0.5, 'entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]'),Text(0.6, 0.5, 'X[1] <= 4.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]'),Text(0.4, 0.16666666666666666, 'entropy = 0.154\nsamples = 45\nvalue = [0, 44, 1]'),Text(0.8, 0.16666666666666666, 'entropy = 0.497\nsamples = 55\nvalue = [0, 6, 49]')]

 


 

3.2、最优划分条件

from collections import Counter
Counter(y)
Counter({0: 50, 1: 50, 2: 50})
def calc_entropy(y):counter = Counter(y)sum_ent = 0for i in counter:p = counter[i] / len(y)sum_ent += (-p * np.log2(p))return sum_ent
calc_entropy(y)
1.584962500721156
def split_dataset(x, y, dim, value):index_left = (x[:, dim] <= value)index_right = (x[:, dim] > value)return x[index_left], y[index_left], x[index_right], y[index_right]
def find_best_split(x, y):best_dim = -1best_value = -1best_entropy = np.infbest_entropy_left, best_entropy_right = -1, -1for dim in range(x.shape[1]):sorted_index = np.argsort(x[:, dim])for i in range(x.shape[0] - 1): # x列数value_left, value_right = x[sorted_index[i], dim], x[sorted_index[i + 1], dim]if value_left != value_right:value = (value_left + value_right) / 2x_left, y_left, x_right, y_right = split_dataset(x, y, dim, value)entropy_left, entropy_right = calc_entropy(y_left), calc_entropy(y_right)entropy = (len(x_left) * entropy_left + len(x_right) * entropy_right) / x.shape[0]if entropy < best_entropy:best_dim = dimbest_value = valuebest_entropy = entropybest_entropy_left, best_entropy_right = entropy_left, entropy_rightreturn best_dim, best_value, best_entropy, best_entropy_left, best_entropy_right
find_best_split(x, y)
(1, 2.45, 0.6666666666666666, 0.0, 1.0)
x_left, y_left, x_right, y_right = split_dataset(x, y, 1, 2.45)
find_best_split(x_right, y_right)
(1, 4.75, 0.34262624992678425, 0.15374218032876188, 0.4971677614160753)


4、基尼系数

        

        基尼系数运算稍快;

        物理意义略有不同,信息熵表示的是随机变量的不确定度;

                基尼系数表示在样本集合中一个随机选中的样本被分错的概率,也就是纯度。

                基尼系数越小,纯度越高。

        模型效果上差异不大。

        

二分类信息熵和基尼系数代码实现:

import numpy as np
import matplotlib.pyplot as plt
def entropy(p):return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
def gini(p):return 1 - p ** 2 - (1 - p) ** 2
plot_x = np.linspace(0.001, 0.999, 100)
plt.plot(plot_x, entropy(plot_x), color = 'blue')
plt.plot(plot_x, gini(plot_x), color = 'red')
plt.show()


5、决策树剪枝

Chapter-07/7-6 决策树剪枝.ipynb · 梗直哥/Machine-Learning - Gitee.com

为什么要剪枝?

                复杂度过高。

                        预测复杂度:O(logm)

                        训练复杂度:O(n x m x logm)

                        logm为数的深度,n为数据的维度。

                容易过拟合

                        为非参数学习方法。

 目标:

                降低复杂度

                解决过拟合

 手段:

                限制深度(结点层数)

                限制广度(叶子结点个数)

   —— 设置超参数

                        


6、决策树回归

        基于一种思想:相似输入必会产生相似输出。

        取节点平均值。

        

6.1、决策树回归代码实现

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')boston = datasets.load_boston()
x = boston.data
y = boston.target
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=233)
from sklearn.tree import DecisionTreeRegressorreg = DecisionTreeRegressor()
reg.fit(x_train,y_train)

DecisionTreeRegressor

DecisionTreeRegressor()
reg.score(x_test,y_test)
0.7410680140563546
reg.score(x_train,y_train)
1.0

6.2、绘制学习曲线

from sklearn.metrics import r2_scoreplt.rcParams["figure.figsize"] = (12, 8)
max_depth = [2, 5, 10, 20]for i, depth in enumerate(max_depth):reg = DecisionTreeRegressor(max_depth=depth)train_error, test_error = [], []for k in range(len(x_train)):reg.fit(x_train[:k+1], y_train[:k+1])y_train_pred = reg.predict(x_train[:k + 1])train_error.append(r2_score(y_train[:k + 1], y_train_pred))y_test_pred = reg.predict(x_test)test_error.append(r2_score(y_test, y_test_pred))plt.subplot(2, 2, i + 1)plt.ylim(0, 1.1)plt.title("Depth: {0}".format(depth))plt.plot([k + 1 for k in range(len(x_train))], train_error, color = "red", label = 'train')plt.plot([k + 1 for k in range(len(x_train))], test_error, color = "blue", label = 'test')plt.legend()plt.show()

6.3、网格搜索

from sklearn.model_selection import GridSearchCVparams = {'max_depth': [n for n in range(2, 15)],'min_samples_leaf': [sn for sn in range(3, 20)],
}grid = GridSearchCV(estimator = DecisionTreeRegressor(), param_grid = params, n_jobs = -1
)
grid.fit(x_train,y_train)

GridSearchCV

GridSearchCV(estimator=DecisionTreeRegressor(), n_jobs=-1,param_grid={'max_depth': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14],'min_samples_leaf': [3, 4, 5, 6, 7, 8, 9, 10, 11, 12,13, 14, 15, 16, 17, 18, 19]})

estimator: DecisionTreeRegressor

DecisionTreeRegressor()

DecisionTreeRegressor

DecisionTreeRegressor()
grid.best_params_
{'max_depth': 5, 'min_samples_leaf': 3}
grid.best_score_
0.7327442904059717
reg = grid.best_estimator_
reg.score(x_test, y_test)
0.781690085676063

7、优缺点和适用条件

优点:

        符合人类直观思维

        可解释性强

        能够处理数值型数据和分类型数据

        能够处理多输出问题

缺点:

        容易产生过拟合

        决策边界只能是水平或竖直方向

                

        不稳定,数据的微小变化可能生成完全不同的树


参考于

Chapter-07/7-4 决策树分类.ipynb · 梗直哥/Machine-Learning - Gitee.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/229761.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MIT6.S081-实验准备

实验全程在Vmware虚拟机 (镜像&#xff1a;Ubuntu-20.04-beta-desktop-amd64) 中进行 一、版本控制 1.1 将mit的实验代码克隆到本地 git clone git://g.csail.mit.edu/xv6-labs-2020 1.2 修改本地git配置文件 创建github仓库&#xff0c;记录仓库地址 我的仓库地址就是htt…

C# 基本桌面编程(二)

一、前言 本章为C# 基本桌面编程技术的第二节也是最后一节。前一节在下面这个链接 C# 基本桌面编程&#xff08;一&#xff09;https://blog.csdn.net/qq_71897293/article/details/135024535?spm1001.2014.3001.5502 二、控件布局 1 叠放顺序 在WPF当中布局&#xff0c;通…

饥荒Mod 开发(十):制作一把AOE武器

饥荒Mod 开发(九)&#xff1a;物品栏排列 饥荒Mod 开发(十一)&#xff1a;修改物品堆叠 前面的文章介绍了很多基础知识以及如何制作一个物品&#xff0c;这次制作一把武器&#xff0c;装备之后可以用来攻击怪物。 制作武器贴图和动画 1.1 制作贴图。 先准备一张武器的贴图&a…

解决:AttributeError: module ‘scipy.misc’ has no attribute ‘imsave’

解决&#xff1a;AttributeError: module ‘scipy.misc’ has no attribute ‘imsave’ 文章目录 解决&#xff1a;AttributeError: module scipy.misc has no attribute imsave背景报错问题报错翻译报错位置代码报错原因解决方法方法一 scipy版本回退&#xff08;不推荐&#…

MySQL_13.InonDB表空间

InnoDB 表空间介绍以及管理 1.mysql表空间类型 system tablespace 系统表空间 file-per-table tablespace 独立表空间 temporary tablespace 临时表空间 undo tablespace UNDO表空间 general tablespace…

React中props 和 state异同初探

在 React 中&#xff0c;props 和 state 是两个非常重要的概念&#xff0c;它们决定了组件的行为和渲染方式。 Props props&#xff08;属性&#xff09;是父组件传递给子组件的数据。它们类似于函数的参数&#xff0c;可以在组件内部被访问和使用&#xff0c;但不能被修改。…

鸿蒙4.0核心技术-WebGL开发

场景介绍 WebGL主要帮助开发者在前端开发中完成图形图像的相关处理&#xff0c;比如绘制彩色图形等。 接口说明 表1 WebGL主要接口列表 接口名描述canvas.getContext获取canvas对象上下文。webgl.createBuffer(): WebGLBuffernullwebgl.bindBuffer(target: GLenum, buffer: …

饥荒Mod 开发(十三):木牌传送

饥荒Mod 开发(十二)&#xff1a;一键制作 饥荒Mod 开发(十四)&#xff1a;制作屏幕弹窗 一键传送源码 饥荒的地图很大&#xff0c;跑地图太耗费时间和饥饿值&#xff0c;如果大部分时间都在跑图真的是很无聊&#xff0c;所以需要有一个能够传送的功能&#xff0c;不仅可以快速…

Web前端-JavaScript(js表达式)

文章目录 JavaScript基础第01天1.编程语言概述1.1 编程1.2 计算机语言1.2.1 机器语言1.2.2 汇编语言1.2.3 高级语言 1.4 翻译器 2.计算机基础2.1 计算机组成2.2 数据存储2.3 数据存储单位2.4 程序运行 3.初始JavaScript3.1 JavaScript 是什么3.2 JavaScript的作用3.3 HTML/CSS/…

《点云处理》平面拟合

前言 在众多点云处理算法中&#xff0c;其中关于平面拟合的算法十分广泛。本篇内容主要是希望总结归纳各类点云平面拟合算法&#xff0c;并且将代码进行梳理保存。 环境&#xff1a; VS2019 PCL1.11.1 1.RANSAC 使用ransac对平面进行拟合是非常常见的用法&#xff0c;PCL…

医疗智能化革命:AI技术引领医疗领域的创新进程

一、“AI”医疗的崛起 随着人工智能&#xff08;AI&#xff09;技术的崛起&#xff0c;"AI"医疗正在以惊人的速度改变着医疗行业的面貌。AI作为一种强大的工具&#xff0c;正在为医疗领域带来前所未有的创新和突破。它不仅在医学影像诊断、病理学分析和基因组学研究等…

Linux ls命令教程:如何有效地列出文件和目录(附案例详解和注意事项)

Linux ls命令介绍 ls是Linux中的基本命令之一&#xff0c;任何Linux用户都应该知道。ls命令列出文件系统中的文件和目录&#xff0c;并显示有关它们的详细信息。它是所有Linux发行版都安装的GNU核心实用程序包的一部分。 Linux ls命令适用的Linux版本 ls命令在所有Linux发行…

vertx写sip服务器

Vert.x SIP 模块默认使用 TCP 协议进行通信。如果您需要支持 UDP 协议&#xff0c;您需要自定义 SIP 协议栈&#xff0c;并在其中实现 UDP 传输。 以下是一个示例代码&#xff0c;演示如何在 Vert.x 中创建一个支持 UDP 的 SIP 服务器&#xff1a; import io.vertx.core.net.…

设计模式——状态模式

引言 状态模式是一种行为设计模式&#xff0c; 让你能在一个对象的内部状态变化时改变其行为&#xff0c; 使其看上去就像改变了自身所属的类一样。 问题 状态模式与有限状态机 的概念紧密相关。 其主要思想是程序在任意时刻仅可处于几种有限的状态中。 在任何一个特定状态中…

手拉手EasyExcel极简实现web上传下载(全栈)

环境介绍 技术栈 springbootmybatis-plusmysqleasyexcel 软件 版本 mysql 8 IDEA IntelliJ IDEA 2022.2.1 JDK 1.8 Spring Boot 2.7.13 mybatis-plus 3.5.3.2 EasyExcel是一个基于Java的、快速、简洁、解决大文件内存溢出的Excel处理工具。 他能让你在不用考虑性…

【shell脚本实战案例】sed替换文本中指定数据所在的行

目录 问题背景&#xff1a; 解决方法&#xff1a; 1.sed替换每行第一个出现的关键字 2.sed替换某一行出现的关键字 3.sed替换某一行第N次出现的关键字 4.sed替换文件中所有的关键字 5.sed将文件中所有关键字替换成空&#xff08;即关键字全部删除的另一种方式&#xff09…

19.Oracle 中count(1) 、count(*) 和count(列名) 函数的区别

count(1) and count(字段) 两者的主要区别是 count(1) 会统计表中的所有的记录数&#xff0c;包含字段为null 的记录。count(字段) 会统计该字段在表中出现的次数&#xff0c;忽略字段为null 的情况。 即不统计字段为null 的记录。 count(*) 和 count(1)和count(列名)区别 …

CentOS7安装教程

1.准备工作 安装虚拟机软件 VMware Workstation 17 Player下载CentOS7 镜像 2. 安装VMware Workstation 17 Player 官网 下载链接 下载好了安装包&#xff0c;双击安装包&#xff0c;傻瓜式安装一直下一步&#xff0c;安装即可。 3. 安装CentOS7 官网 推荐下载地址&…

Tekton 构建容器镜像

Tekton 构建容器镜像 介绍如何使用 Tektonhub 官方 kaniko task 构建docker镜像&#xff0c;并推送到远程dockerhub镜像仓库。 kaniko task yaml文件下载地址&#xff1a;https://hub.tekton.dev/tekton/task/kaniko 查看kaniko task yaml内容&#xff1a; 点击Install&…

RabbitMQ 消息持久化

默认情况下&#xff0c;exchange、queue、message 等数据都是存储在内存中的&#xff0c;这意味着如果 RabbitMQ 重启、关闭、宕机时所有的信息都将丢失。 RabbitMQ 提供了持久化来解决这个问题&#xff0c;持久化后&#xff0c;如果 RabbitMQ 发送 重启、关闭、宕机&#xff…