决策树(Decision Tree)

  • 决策树的原理

    决策树算法是一种基于树结构的分类和回归算法。它通过对数据集进行递归地分割,构建一个树形模型,用于预测目标变量的值。

    决策树的构建过程基于以下原理:
    1. 特征选择:选择最佳的特征来进行数据集的分割。常用的特征选择指标有信息增益、信息增益比、基尼指数等。
    2. 样本划分:根据选定的特征将数据集划分为不同的子集。每个子集中的样本具有相同的特征值。
    3. 递归构建:对每个子集递归地应用上述步骤,直到满足终止条件,例如达到最大深度、样本数小于阈值或者没有更多特征可供选择。
    4. 叶节点生成:当终止条件满足时,将子集划分为叶节点,该叶节点表示一个分类或回归结果。

    在分类问题中,每个叶节点表示一种类别;在回归问题中,每个叶节点表示一个数值。

    决策树算法具有以下优点:
    1、可解释性强,易于理解和解释生成的模型。
    2、能够处理离散和连续型特征。
    3、能够处理多输出问题。
    4、对缺失值和异常值具有一定的鲁棒性(鲁棒性是指一个系统或算法在面对噪声、异常值、失效或者意外情况时,能够保持稳定的表现)。

    然而,决策树算法也存在一些缺点:
    1、容易过拟合(过拟合(overfitting)是指机器学习模型过于复杂,过于精细地拟合了训练数据集中的每一个样本,以至于在新的数据集上表现不佳的现象),特别是在处理复杂数据集时。
    2、对于包含大量特征的数据集,决策树可能过于复杂,容易产生过度分支。
    3、对于类别数量不平衡的数据集,决策树可能偏向于拟合样本较多的类别。

    为了克服这些问题,可以采用剪枝技术、随机森林(小蒟蒻还没学)等方法对决策树进行改进和优化。

  • 决策树的构建

    1. ID3算法:ID3(Iterative Dichotomiser 3)是一种基于信息增益的决策树构建算法。它通过计算每个属性的信息增益来选择最优的划分属性,递归地构建决策树。

    2. C4.5算法:C4.5是ID3算法的改进版本,它使用信息增益率来选择最优的划分属性。与ID3相比,C4.5还能处理缺失值,并且支持连续型属性。

    3. CART算法:CART(Classification and Regression Trees)是一种既能构建分类树又能构建回归树的决策树算法。CART算法通过基尼指数或均方差来选择最优的划分属性,并采用二叉树结构进行构建。

    以上是常见的决策树构建方法,不同的算法在属性选择和剪枝策略上有所差异,选择适合问题需求的方法可以提高决策树的性能和准确度。

本文以CART方法为例,基于鸢尾花数据集实现CART分类决策树算法

一、CART决策树算法简介

    CART(Classification and Regression Trees)是一种常用的策树算法,可以用于分类和回归问题。CART算法通过对数据集进行递归的二分划分,构建出一棵二叉树模型。

    CART算法的划分准则是基于Gini指数或基尼不纯度。在分类问题中,Gini指数衡量了一个样本集合中不同类别样本的不均匀程度。选择划分特征时,CART算法通过计算每个特征的Gini指数,选择使得Gini指数最小的特征作为划分特征。

    CART算法通过递归地对数据集进行划分,直到满足停止条件。停止条件可以是达到最大深度、叶子节点的样本数小于某个阈值等。划分过程中,每次选择Gini指数最小的特征进行划分,并将数据集按照该特征的取值分为两部分。

    对于分类问题,CART算法构建出的决策树可以用于预测新样本的类别。对于回归问题,CART算法构建出的决策树可以用于预测新样本的实数值。

    CART算法具有较好的可解释性和较高的准确性,在实际应用中被广泛使用。它能够处理离散和连续特征,并且对异常值和缺失值具有较好的鲁棒性。同时,CART算法还可以通过剪枝来提高模型的泛化能力,避免过拟合问题。

二、基尼系数

    基尼系数代表模型的不纯度,基尼系数越小,则不纯度越低。

    在分类问题中,假设有K个类,样本点属于第k类的概率为pk,则概率分布的基尼系数定义为:

                                  {\mathop{\rm Gini}\nolimits} (p) = \sum\limits_{k = 1}^K {​{p_k}(1 - {p_k})} = 1 - \sum\limits_{k = 1}^K {p_k^2}

     对给定的样本集合D,其基尼指数:

                                

三、CART决策树生成

1. 初始化:将整个训练集作为输入数据,选择一个目标变量(分类问题中是分类变量,回归问题中是连续变量)。

2. 选择最佳切分特征:根据某种指标(如信息增益、基尼系数等),计算每个特征的切分点,选择最佳的特征和切分点作为当前节点的切分标准。

3. 根据切分标准将数据集分割为两个子集:根据最佳切分特征和切分点,将数据集划分为两个子集,一个子集包含满足切分标准的样本,另一个子集包含不满足切分标准的样本。

4. 递归生成子树:对于每个子集,重复步骤2和步骤3,直到满足停止条件(如达到预定的树深度、样本数量小于阈值等)。

5. 构建决策树:将生成的子树连接到当前节点上,形成完整的决策树。

6. 剪枝处理:对生成的决策树进行剪枝处理,通过损失函数最小化或交叉验证选择最优的剪枝参数,以避免过拟合。

7. 输出决策树模型:将生成的决策树模型输出。 

以上是CART决策树生成算法的基本流程,它是一种基于贪心策略的自上而下生成方法。

四、CART决策树代码

手敲实现:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as npclass Node:          # 特征          阈值             标签        左子树     右子树def __init__(self, feature=None, threshold=None, label=None, left=None, right=None):self.feature = featureself.threshold = thresholdself.label = labelself.left = leftself.right = rightclass DecisionTree:def __init__(self, max_depth=None):self.max_depth = max_depth # 可以用于剪枝def _gini(self, y): # 计算基尼不纯度classes, counts = np.unique(y, return_counts=True) # 获取y中的所有唯一类别和它们对应的出现次数impurity = 1 - np.sum((counts / np.sum(counts)) ** 2) # 根据基尼不纯度的公式计算不纯度值return impuritydef _best_split(self, X, y):best_gini = float('inf')best_feature = None # 最佳特征best_threshold = None # 最佳分割点for feature in range(X.shape[1]): # X.shape[1],特征数量# thresholds代表鸢尾花的一列特征值(唯一值)thresholds:numpy.ndarray = np.unique(X[:, feature]) # 获取该特征的所有唯一值作为候选分割点的阈值。for threshold in thresholds: # 遍历每个候选分割点的阈值。对于每个阈值,根据特征值与阈值的比较,将样本分为左子集和右子集。# 对于每个阈值,根据特征值与阈值的比较,将样本分为左子集和右子集。left_indices:bool = X[:, feature] < threshold # left_indices 是一个布尔数组,表示哪些样本属于左子集。right_indices:bool = X[:, feature] >= thresholdleft_gini = self._gini(y[left_indices]) # 计算左子集的基尼系数(基尼不纯度)right_gini = self._gini(y[right_indices]) # 计算右子集的基尼系数(基尼不纯度)'''根据左右子集的基尼系数和样本数量计算加权平均基尼系数 gini,并与当前的最佳基尼系数 best_gini 进行比较。如果当前 gini 值小于 best_gini,则更新 best_gini、best_feature 和 best_threshold。'''gini = (left_gini * np.sum(left_indices) + right_gini * np.sum(right_indices)) / len(y) # 加权平均基尼系数if gini < best_gini:best_gini = ginibest_feature = featurebest_threshold = thresholdreturn best_feature, best_threshold # 返回找到的最佳特征和最佳分割点的信息。def _build_tree(self, X, y, depth):# 检查是否达到了最大深度或者标签只有一种类别,如果满足其中一种条件,就创建一个叶子节点,并将最常见的标签作为节点的标签值。if depth == self.max_depth or len(np.unique(y)) == 1: # 起到了剪枝的作用'''set(y):将列表 y 转换为一个集合。集合是无序且不包含重复元素的数据结构。y.tolist():将集合 y 转换为列表。这是因为集合对象本身不支持 count 方法,而列表对象支持。y.tolist().count:使用列表的 count 方法,返回每个元素在列表中出现的次数。max(set(y), key=y.tolist().count):使用 max 函数,根据元素的出现次数找出最大值,并返回该元素。最终,变量 label 将被赋值为列表 y 中出现次数最多的元素。'''label = max(set(y), key=y.tolist().count)return Node(label=label)feature, threshold = self._best_split(X, y) # 最佳特征和最佳分割点的信息。if feature is None or threshold is None: # 起到了剪枝的作用label = max(set(y), key=y.tolist().count)return Node(label=label)# 代码根据分割特征和阈值将训练数据划分为左子节点和右子节点的索引。left_indices = X[:, feature] < threshold # left_indices 是一个布尔数组,表示哪些样本属于左子集。right_indices = X[:, feature] >= threshold# 递归调用 _build_tree 方法,传入左子节点和右子节点对应的训练数据和标签,以及增加了深度的值。left_node = self._build_tree(X[left_indices], y[left_indices], depth+1)right_node = self._build_tree(X[right_indices], y[right_indices], depth+1)# 代码创建一个节点对象,并将分割特征、阈值、左子节点和右子节点赋值给节点对象,并返回该节点。node = Node(feature=feature, threshold=threshold, left=left_node, right=right_node)return nodedef fit(self, X, y): # fit方法用于训练模型self.root = self._build_tree(X, y, 0) # _build_tree方法用于构建决策树的节点def _predict_single(self, x, node):  # _predict_single方法用于递归地预测单个样本的类别# 如果 node.label 不是 None 的话,就返回 node.labelif node.label is not None:return node.labelif x[node.feature] < node.threshold:return self._predict_single(x, node.left)else:return self._predict_single(x, node.right)def predict(self, X): # 用于预测新的样本的类别y_pred = []for x in X:label = self._predict_single(x, self.root)y_pred.append(label)return np.array(y_pred)'''在预测过程中,根据节点的特征和阈值,将样本分配到左子树或右子树,直到叶子节点得到最终的预测结果。最终将预测结果存储在y_pred列表中,并返回一个numpy数组。'''# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建决策树模型并训练
tree = DecisionTree(max_depth=3)
tree.fit(X_train, y_train)# 预测测试集
y_pred = tree.predict(X_test)# 输出准确率
accuracy = np.sum(y_pred == y_test) / len(y_test)
print("准确率:", accuracy)'''
运行结果:
准确率: 1.0
'''

掉包实现:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 将数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建决策树分类器
clf = DecisionTreeClassifier()# 在训练集上训练决策树模型
clf.fit(X_train, y_train)# 在测试集上进行预测
y_pred = clf.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("准确率:", accuracy)'''
运行结果:
准确率: 1.0
'''

与knn算法进行试验对比:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=50)# 使用决策树算法进行训练和预测
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X_train, y_train)
dt_pred = dt_clf.predict(X_test)
dt_accuracy = accuracy_score(y_test, dt_pred)# 使用knn算法进行训练和预测
knn_clf = KNeighborsClassifier(n_neighbors=5)
knn_clf.fit(X_train, y_train)
knn_pred = knn_clf.predict(X_test)
knn_accuracy = accuracy_score(y_test, knn_pred)# 输出结果
print("决策树算法准确率:", dt_accuracy)
print("knn算法准确率:", knn_accuracy)'''
运行结果:
决策树算法准确率: 0.9555555555555556
knn算法准确率: 0.9333333333333333
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/5886.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

堆排序与直接选择排序

目录 一、直接选择排序 1.基本思想 2.直接选择排序的特性总结 3.代码实现&#xff1a; 二、堆排序 1. 概念&#xff1a; 2.图像实现&#xff1a; 3.代码实现&#xff1a; 一、直接选择排序 1.基本思想 每一次从待排序的数据元素中选出最小&#xff08;或最大&#xff09…

MySQL多表连接查询练习

准备工作 创建student表 CREATE TABLE student ( id INT(10) NOT NULL UNIQUE PRIMARY KEY , name VARCHAR(20) NOT NULL , sex VARCHAR(4) , birth YEAR, department VARCHAR(20) , address VARCHAR(50) );创建score表 CREATE TABLE score ( id INT(10) NOT …

NetSuite ERP顾问的进阶之路

目录 1.修养篇 1.1“道”是什么&#xff1f;“器”是什么&#xff1f; 1.2 读书这件事儿 1.3 十年计划的力量 1.3.1 一日三省 1.3.2 顾问损益表 1.3.3 阶段课题 2.行为篇 2.1协作 2.2交流 2.3文档管理 2.4时间管理 3.成长篇 3.1概念能力 3.1.1顾问的知识结构 …

大数据学习05-Kafka分布式集群部署

系统环境&#xff1a;centos7 软件版本&#xff1a;jdk1.8、zookeeper3.4.8、hadoop2.8.5 本次实验使用版本 kafka_2.12-3.0.0 一、安装 Kafka官网 将安装包上传至linux服务器上 解压 tar -zxvf kafka_2.12-3.0.0.tgz -C /home/local/移动目录至kafka mv kafka_2.12-3.0…

护城河理论

护城河理论 护城河理论|来自股神巴菲特&#xff0c;是指投资的企业在某一方面的核心竞争力。 模型介绍 在2000年的伯克希尔哈撒韦的年会上&#xff0c;巴菲特说&#xff1a;让我们来把护城河作为一个伟大企业的首要标准&#xff0c;保持它的宽度&#xff0c;保持它不被跨越。我…

听GPT 讲K8s源代码--pkg(五)

在 Kubernetes 中&#xff0c;kubelet 是运行在每个节点上的主要组件之一&#xff0c;它负责管理节点上的容器&#xff0c;并与 Kubernetes 控制平面交互以确保容器在集群中按照期望的方式运行。kubelet 的代码位于 Kubernetes 代码库的 pkg/kubelet 目录下。 pkg/kubelet 目录…

数学建模-分类模型 Fisher线性判别分析

论文中1. 判别分析系数 2. 分类结果 多分类问题 勾选内容和上面一样

【C++】入门 --- 命名空间

文章目录 &#x1f36a;一、前言&#x1f369;1、C简介&#x1f369;2、C关键字 &#x1f36a;二、命名冲突&#x1f36a;三、命名空间&#x1f369;1、命名空间定义&#x1f369;2、命名空间的使用 &#x1f36a;四、C输入&输出 &#x1f36a;一、前言 本篇文章是《C 初阶…

Linux笔记——管道相关命令以及shell编程

文章目录 管道相关命令 目标 准备工作 1 cut 1.1 目标 1.2 路径 1.3 实现 2 sort 2.1 目标 2.2 路径 2.3 实现 第一步: 对字符串排序 第二步&#xff1a;去重排序 第三步: 对数值排序 默认按照字符串排序 升序 -n 倒序 -r 第四步: 对成绩排序【按照列排序】 …

ffmpeg中filter_query_formats函数解析

ffmpeg中filter_query_formats主要起一个pix fmt引用指定的功能。 下下结论&#xff1a; 先看几个结构体定义&#xff1a; //删除了一些与本次分析不必要的成员 struct AVFilterLink {AVFilterContext *src; ///< source filterAVFilterPad *srcpad; ///<…

PhpStudy靶场首页管理

PhpStudy靶场首页管理 一、源码一二、源码二三、源码三四、源码四 一、源码一 index.html <!DOCTYPE html> <html><head><meta charset"UTF-8"><title>靶场访问首页</title><style>body {background-color: #f2f2f2;colo…

JavaDemo——使用jks的https

java使用https主要就是设置下sslContext&#xff0c;sslContext初始化需要密钥管理器和信任管理器&#xff0c;密钥管理器用于管理本地证书和私钥&#xff0c;信任管理器用于验证远程服务器的证书&#xff0c;这两种管理器都需要KeyStore初始化&#xff0c;两种管理器可以按需只…

Ubuntu 网络配置指导手册

一、前言 从Ubuntu 17.10 Artful开始&#xff0c;Netplan取代ifupdown成为默认的配置实用程序&#xff0c;网络管理改成 netplan 方式处理&#xff0c;不在再采用从/etc/network/interfaces 里固定 IP 的配置 &#xff0c;配置写在 /etc/netplan/01-network-manager-all.yaml 或…

【事业单位-语言理解1】中心理解02

【事业单位-语言理解1】中心理解02 1.中心理解1.1 并列关系1.2 主题词1.3程度词&#xff0c;表示强调 二、标题填入题&#xff08;优先考虑主题词&#xff09;三、词句理解题 1.中心理解 解题思路 1.1 并列关系 涉及时间顺序 注意选项不要逻辑不当 并列关系的时候&…

行云创新 CloudOS 助力上汽乘用车企业云原生IT架构变革

近日&#xff0c;在2023架构可持续未来峰会成都制造业分会场上&#xff0c;上海汽车集团股份有限公司乘用车公司基础架构部主管茹洋带来了议题为《云原生时代上汽乘用车企业IT架构变革和实践》的精彩演讲。他从云原生对于企业IT架构的意义、企业IT架构变革的必要性入手&#xf…

C程序环境及预处理

​​​​​文章目录 一、程序的翻译环境和执行环境 1.程序编译过程 2.编译内部原理 3.执行环境 二、程序运行前的预处理 1.预定义符号归纳 2.define定义标识符 3.define定义宏 4.define替换规则 5.宏和函数的对比 三、头文件被包含的方式 四、练习&#xff1a;写一…

Vue3状态管理库Pinia——核心概念(Store、State、Getter、Action)

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

Java将数据集合转换为PDF

这里写自定义目录标题 将数据集合转换为pdf引入包工具类测试代码导出效果 将数据集合转换为pdf 依赖itext7包将数据集合转换导出为PDF文件 引入包 <properties><itext.version>7.1.11</itext.version> </properties><dependency><groupId&…

什么是HTTP 500错误,怎么解决

目录 什么是HTTP 500 HTTP 500错误的常见原因&#xff1a; 如何修复HTTP 500 总结 什么是HTTP 500 错误 HTTP 500内部服务器错误是指在客户端发出请求后&#xff0c;服务器在处理请求过程中发生了未知的问题&#xff0c;导致服务器无法完成请求。HTTP 500错误是一个通用的服…

Spring-缓存初步认识

Spring-缓存 简单介绍 缓存是一种介于数据永久存储介质和数据应用之间的数据临时存储介质缓存有效提高读取速度&#xff0c;加速查询效率 spring使用缓存方式 添加依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring…