决策树的生成与剪枝

决策树的生成与剪枝

  • 决策树的生成
    • 生成决策树的过程
    • 决策树的生成算法
  • 决策树的剪枝
    • 决策树的损失函数
    • 决策树的剪枝算法
  • 代码

决策树的生成

生成决策树的过程

为了方便分析描述,我们对上节课中的训练样本进行编号,每个样本加一个ID值,如图所示:
在这里插入图片描述
从根结点开始生成决策树,先将上述训练样本(1-9)全部放在根节点中。

然后选择信息增益或信息增益比最大的特征向下分裂,按照所选特征取值的不同将训练样本分发到不同的子结点中。

例如所选特征有3个取值,则分裂出3个子结点,然后在子结点中重复上述过程,直到所有特征的信息增益(比)很小或者没有特征可选为止,完成决策树的构建,如下图所示:
在这里插入图片描述
图中的决策树共有2个内部结点和3个叶子结点,每个结点旁边的编号代表训练样本的 ID 值。内部结点代表样本的特征,叶子结点代表样本的预测类别,我们将叶子节点中训练样本占比最大的类作为决策树的预测标记。

在使用构建好的决策树在测试数据上分类时,只需要从根节点开始依次测试内部结点代表的特征即可得到测试样本的预测分类。

决策树的生成算法

下面我们先来总结一下 ID3分类决策树的生成算法:

输入:训练数据集 D 、特征集合 A 的信息增益阈值 ;输出:决策树 T

  1. 若 D 中的训练样本属于同一类,则 T 为单结点树,返回 D 中任意样本的类别。

  2. 若 A 中的特征为空,则 T 为单结点树,返回 D 中数量最多的类别。

  3. 使用信息增益在 A 中进行特征选择,若所选特征 A_i 的信息增益小于设定的阈值,则 T 为单结点树,返回 D 中数量最多的类别。

  4. 否则根据 A_i 的每一个取值,将 D 分成若干子集 D_i,将 D_i 中数量最多的类作为标记值,构建子结点,返回 T。

  5. 以 D_i 为训练集,{A - A_i} 为特征集,递归地调用上述步骤,得到子树 T_i,返回 T。

使用 C4.5 算法进行决策树的生成只需要将信息增益改成信息增益比即可。

决策树的剪枝

决策树的损失函数

决策树的叶子节点越多,模型越复杂。决策树的损失函数考虑了模型复杂度,我们可以通过优化其损失函数来对决策树进行剪枝。决策树的损失函数计算过程如下:

  1. 计算叶子结点 t 的样本类别经验熵
    在这里插入图片描述
    对于叶子结点 t 来说,其样本类别的经验熵越小, t 中训练样本的分类误差就越小。当叶子结点 t 中的训练样本为同一类别时,经验熵为零,分类误差为零。

  2. 计算决策树 T 在所有训练样本上的损失之和 C(T)
    在这里插入图片描述
    对于叶子结点 t 中的每一个训练样本,其类别标记都是随机变量 Y 的一个取值,这个取值的不确定性用信息熵来衡量,且可以用经验熵来估计。由上文可知,经验熵在一定程度上可以反映决策树在该样本上的预测损失,累加所有叶子结点上的训练样本损失即上图中的计算公式。

  3. 计算考虑模型复杂度的的决策树损失函数

在这里插入图片描述
决策树的叶子结点个数表示模型的复杂度,通过最小化上面的损失函数,一方面可以减少模型在训练样本上的预测误差,另一方面可以控制模型的复杂度,保证模型的泛化能力。

决策树的剪枝算法

  1. 计算决策树中每个结点的样本类别经验熵:
    在这里插入图片描述
    如上图所示,对于本课示例中的决策树,需要计算 5 个结点的经验熵。

  2. 遍历非叶子结点,剪枝相当于去除其子结点,自身变为叶子结点:

在这里插入图片描述
对于图中的非叶子结点(有工作?),剪枝后变为叶子结点,并通过多数表决的方法确定其类别标记。

以上就是这节课的所有内容了,实际上还有一种决策树算法:分类与回归树(classification and regression tree,简称 CART),它既可以用于分类也可以用于回归,同样包含了特征选择、决策树的生成与剪枝算法。

关于 CART 算法的内容,我们将在最后一章 XGBoost 中进行学习,下面请你来做一道关于信息增益比的题目,顺便回顾一下前面所学的知识

代码

## 1. 创建数据集import pandas as pd
data = [['yes', 'no', '青年', '同意贷款'],['yes', 'no', '青年', '同意贷款'],['yes', 'yes', '中年', '同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'no', '青年', '不同意贷款'],['no', 'no', '青年', '不同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'yes', '中年', '同意贷款']]# 转为 dataframe 格式
df = pd.DataFrame(data)
# 设置列名
df.columns = ['有房?', '有工作?', '年龄', '类别']## 2. 经验熵的实现from math import log2
from collections import Counter
def H(y):'''y: 随机变量 y 的一组观测值,例如:[1,1,0,0,0]'''# 随机变量 y 取值的概率估计值probs = [n/len(y) for n in Counter(y).values()]# 经验熵:根据概率值计算信息量的数学期望return sum([-p*log2(p) for p in probs])## 3. 经验条件熵的实现def cond_H(a):'''a: 根据某个特征的取值分组后的 y 的观测值,例如:[[1,1,1,0],[0,0,1,1]]每一行表示特征 A=a_i 对应的样本子集'''# 计算样本总数sample_num = sum([len(y) for y in a])# 返回条件概率分布的熵对特征的数学期望return sum([(len(y)/sample_num)*H(y) for y in a])## 4. 特征选择函数
def feature_select(df,feats,label):'''df:训练集数据,dataframe 类型feats:候选特征集label:df 中的样本标记名,字符串类型'''# 最佳的特征与对应的信息增益比best_feat,gainR_max = None,-1# 遍历每个特征for feat in feats:# 按照特征的取值对样本进行分组,并取分组后的样本标记数组group = df.groupby(feat)[label].apply(lambda x:x.tolist()).tolist()# 计算该特征的信息增益:经验熵-经验条件熵gain = H(df[label].values) - cond_H(group)# 计算该特征的信息增益比gainR = gain / H(df[feat].values)# 更新最大信息增益比和对应的特征if gainR > gainR_max:best_feat,gainR_max = feat,gainRreturn best_feat,gainR_max ## 5. 决策树的生成函数
import pickle
def creat_tree(df,feats,label):'''df:训练集数据,dataframe 类型feats:候选特征集,字符串列表label:df 中的样本标记名,字符串类型'''# 当前候选的特征列表feat_list = feats.copy()# 若当前训练数据的样本标记值只有一种if df[label].nunique()==1:# 将数据中的任意样本标记返回,这里取第一个样本的标记值return df[label].values[0]# 若候选的特征列表为空时if len(feat_list)==0:# 返回当前数据样本标记中的众数,各类样本标记持平时取第一个return df[label].mode()[0]# 在候选特征集中进行特征选择feat,gain = feature_select(df,feat_list,label)# 若选择的特征信息增益太小,小于阈值 0.1if gain<0.1:# 返回当前数据样本标记中的众数return df[label].mode()[0]# 根据所选特征构建决策树,使用字典存储tree = {feat:{}}# 根据特征取值对训练样本进行分组g = df.groupby(feat)# 用过的特征要移除feat_list.remove(feat)# 遍历特征的每个取值 ifor i in g.groups:# 获取分组数据,使用剩下的候选特征集创建子树tree[feat][i] = creat_tree(g.get_group(i),feat_list,label)# 存储决策树pickle.dump(tree,open('tree.model','wb'))return tree# 6. 决策树的预测函数
def predict(tree,feats,x):'''tree:决策树,字典结构feats:特征集合,字符串列表x:测试样本特征向量,与 feats 对应'''# 获取决策树的根结点:对应样本特征root = next(iter(tree))# 获取该特征在测试样本 x 中的索引i = feats.index(root)# 遍历根结点分裂出的每条边:对应特征取值for edge in tree[root]:# 若测试样本的特征取值=当前边代表的特征取值if x[i]==edge:# 获取当前边指向的子结点child = tree[root][edge]# 若子结点是字典结构,说明是一颗子树if type(child)==dict:# 将测试样本划分到子树中,继续预测return predict(child,feats,x)# 否则子结点就是叶子节点else:# 返回叶子节点代表的样本预测值return child## 7. 在样例数据上测试# 获取特征名列表
feats = list(df.columns[:-1])
# 获取标记名
label = df.columns[-1]
# 创建决策树(此处使用信息增益比进行特征选择)
T = creat_tree(df,feats,label)
# 计算训练集上的预测结果
preds = [predict(T,feats,x) for x in df[feats].values]
# 计算准确率
acc = sum([int(i) for i in (df[label].values==preds)])/len(preds)
# 输出决策树和准确率
print(T,acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SpringBoot的疫苗在线预约功能实现十二

一、前言介绍&#xff1a; 1.1 项目摘要 随着全球公共卫生事件的频发&#xff0c;如新冠疫情的爆发&#xff0c;疫苗成为了预防和控制传染病的重要手段。传统的疫苗预约方式&#xff0c;如人工挂号或电话预约&#xff0c;存在效率低、易出错、手续繁琐等问题&#xff0c;无法…

MySQL基础 -----MySQL数据类型

目录 INT类型 tinyint类型 类型大小范围 测试tinyint类型数据 float类型 测试&#xff1a; 测试正常数据范围的数据 测试插入范围超过临界值的数据&#xff1a; 测试float类型的四舍五入 ​编辑 decimal类型 同样测试&#xff1a; 字符串类型 char类型 测试&…

代码开发相关操作

使用Vue项目管理器创建项目&#xff1a;&#xff08;vue脚手架安装一次就可以全局使用&#xff09; windowR打开命令窗口&#xff0c;输入vue ui&#xff0c;进入GUI页面&#xff0c;点击创建-> 设置项目名称&#xff0c;在初始化git下面输入&#xff1a;init project&…

如何在 Ubuntu 22.04 上安装和使用 Rust 编程语言环境

简介 Rust 是一门由 Mozilla 开发的系统编程语言&#xff0c;专注于性能、可靠性和内存安全。它在没有垃圾收集的情况下实现了内存安全&#xff0c;这使其成为构建对性能要求苛刻的应用程序&#xff08;如操作系统、游戏引擎和嵌入式系统&#xff09;的理想选择。 接下来&…

MybatisPlus-配置加密

配置加密 目前配置文件中的很多参数都是明文&#xff0c;如果开发人员发生流动&#xff0c;很容易导致敏感信息的泄露。所以MybatisPlus支持配置文件的加密和解密功能。 我们以数据库的用户名和密码为例。 生成秘钥 首先&#xff0c;我们利用AES工具生成一个随机秘钥&#…

记录:virt-manager配置Ubuntu arm虚拟机

virt-manager&#xff08;Virtual Machine Manager&#xff09;是一个图形用户界面应用程序&#xff0c;通过libvirt管理虚拟机&#xff08;即作为libvirt的图形前端&#xff09; 因为要在Linux arm环境做测试&#xff0c;记录下virt-manager配置arm虚拟机的过程 先在VMWare中…

ChatGPT客户端安装教程(附下载链接)

用惯了各类AI的我们发现每天打开网页还挺不习惯和麻烦&#xff0c;突然发现客户端上架了&#xff0c;懂摸鱼的人都知道这里面的道行有多深&#xff0c;话不多说&#xff0c;开整&#xff01; 以下是ChatGPT客户端的详细安装教程&#xff0c;适用于Windows和Mac系统&#xff1a…

Element@2.15.14-tree checkStrictly 状态实现父项联动子项,实现节点自定义编辑、新增、删除功能

背景&#xff1a;现在有一个新需求&#xff0c;需要借助树结构来实现词库的分类管理&#xff0c;树的节点是不同的分类&#xff0c;不同的分类可以有自己的词库&#xff0c;所以父子节点是互不影响的&#xff1b;同样为了选择的方便性&#xff0c;提出了新需求&#xff0c;选择…

概率论得学习和整理27:关于离散的数组 随机变量数组的均值,方差的求法3种公式,思考和细节。

目录 1 例子1&#xff1a;最典型的&#xff0c;最简单的数组的均值&#xff0c;方差的求法 2 例子1的问题&#xff1a;例子1只是1个特例&#xff0c;而不是普遍情况。 2.1 例子1各种默认假设&#xff0c;导致了求均值和方差的特殊性&#xff0c;特别简单。 2.2 我觉得 加权…

【HarmonyOS NEXT】Web 组件的基础用法以及 H5 侧与原生侧的双向数据通讯

关键词&#xff1a;鸿蒙、ArkTs、Web组件、通讯、数据 官方文档Web组件用法介绍&#xff1a;文档中心 Web 组件加载沙箱中页面可参考我的另一篇文章&#xff1a;【HarmonyOS NEXT】 如何将rawfile中文件复制到沙箱中_鸿蒙rawfile 复制到沙箱-CSDN博客 目录 如何在鸿蒙应用中加…

ASP.NET Core - 依赖注入 自动批量注入

依赖注入配置变形 随着业务的增长&#xff0c;我们项目工作中的类型、服务越来越多&#xff0c;而每一个服务的依赖注入关系都需要在入口文件通过Service.Add{}方法去进行注册&#xff0c;这将是非常麻烦的&#xff0c;入口文件需要频繁改动&#xff0c;而且代码组织管理也会变…

Spring Boot 3.X:Unable to connect to Redis错误记录

一.背景 最近在搭建一个新项目&#xff0c;本着有新用新的原则&#xff0c;项目选择到了jdk17SpringBoot3.4。但是在测试Redis连接的时候却遇到了以下问题&#xff1a; redis连不上了。于是我先去检查了配置文件的连接信息&#xff0c;发现没问题&#xff1b;再去检查配置类&am…

FFmpeg第一话:FFmpeg 简介与环境搭建

FFmpeg 探索之旅 一、FFmpeg 简介与环境搭建 二、FFmpeg 解码详解 第一话&#xff1a;FFmpeg 简介与环境搭建 FFmpeg 探索之旅一、前言二、FFmpeg 是什么&#xff1f;三、简单介绍其历史背景四、为什么用 C学习 FFmpeg&#xff1f;&#xff08;一&#xff09;高性能优势&#…

(vue)el-table在表头添加筛选功能

(vue)el-table在表头添加筛选功能 筛选前&#xff1a; 选择条件&#xff1a; 筛选后&#xff1a; 返回数据格式: 代码: <el-tableref"filterTable":data"projectData.list"height"540":header-cell-style"{border-bottom: 1px soli…

流程引擎Activiti性能优化方案

流程引擎Activiti性能优化方案 Activiti工作流引擎架构概述 Activiti工作流引擎架构大致分为6层。从上到下依次为工作流引擎层、部署层、业务接口层、命令拦截层、命令层和行为层。 基于关系型数据库层面优化 MySQL建表语句优化 Activiti在MySQL中创建默认字符集为utf8&…

Vue3源码笔记阅读1——Ref响应式原理

本专栏主要用于记录自己的阅读源码的过程,希望能够加深自己学习印象,也欢迎读者可以帮忙完善。接下来每一篇都会从定义、运用两个层面来进行解析 定义 运用 例子:模板中访问ref(1) <template><div>{{str}}</div> </template> <script> impo…

神经网络基础-神经网络搭建和参数计算

文章目录 1.构建神经网络2. 神经网络的优缺点 1.构建神经网络 在 pytorch 中定义深度神经网络其实就是层堆叠的过程&#xff0c;继承自nn.Module&#xff0c;实现两个方法&#xff1a; __init__方法中定义网络中的层结构&#xff0c;主要是全连接层&#xff0c;并进行初始化。…

Dcoker Redis哨兵模式集群介绍与搭建 故障转移 分布式 Java客户端连接

介绍 Redis 哨兵模式&#xff08;Sentinel&#xff09;是 Redis 集群的高可用解决方案&#xff0c;它主要用于监控 Redis 主从复制架构中的主节点和从节点的状态&#xff0c;并提供故障转移和通知功能。通过 Redis 哨兵模式&#xff0c;可以保证 Redis 服务的高可用性和自动故…

机器学习之交叉熵

交叉熵&#xff08;Cross-Entropy&#xff09;是机器学习中用于衡量预测分布与真实分布之间差异的一种损失函数&#xff0c;特别是在分类任务中非常常见。它源于信息论&#xff0c;反映了两个概率分布之间的距离。 交叉熵的数学定义 对于分类任务&#xff0c;假设我们有&#…

Scala的泛型界限

泛型界限 上限 泛型的上限&#xff0c;下限。对类型的更加具体的约束&#xff01; 如果给某个泛型设置了上界&#xff1a;这里的类型必须是上界 如果给某个泛型设置了下界&#xff1a;这里的类型必须是下界