机器学习作业3____决策树(CART算法)

目录

一、简介

 二、具体步骤

样例:

三、代码

四、结果

五、问题与解决


一、简介

CART(Classification and Regression Trees)是一种常用的决策树算法,可用于分类和回归任务。这个算法由Breiman等人于1984年提出,它的主要思想是通过递归地将数据集划分为两个子集,然后在每个子集上继续划分,直到满足某个停止条件为止。

CART算法在分类和回归问题上表现良好,并且能够处理多种数据类型(包括离散型和连续型特征)。由于其简单、易于理解和实现,以及在一些应用中的良好性能,CART算法被广泛应用于实践中。

 二、具体步骤

  1. 计算整体数据集的基尼指数:

    • 首先,计算整个数据集的基尼指数。基尼指数表示从数据集中随机选择两个样本,它们属于不同类别的概率。对于每个节点,基尼指数可以通过以下公式计算: \text{Gini}(D) = 1 - \sum_{i=1}^k (p_i)^2 ( D ) 是当前节点的数据集,( k ) 是类别的数量,( p_i ) 是第 ( i ) 类样本在数据集 ( D ) 中的频率。
  2. 选择最佳特征和切分点:

    • 对于每个特征,遍历其所有可能的取值作为切分点。
    • 对于每个切分点,将数据集分为两个子集:左子集(特征值小于等于切分点)和右子集(特征值大于切分点)。
    • 计算基尼指数来衡量使用当前特征和切分点进行划分后的加权基尼指数。数学上,对于一个特征 ( A ) 的某个切分点 ( t ),其左子集和右子集的基尼指数可以计算为: \text{Gini}(A, t) = \frac{|D{\text{l}}|}{|D|} \times \text{Gini}(D_{\text{l}}) + \frac{|D_{\text{r}}|}{|D|} \times \text{Gini}(D_{\text{r}})
    • 选择使得基尼指数最小的特征和切分点作为当前节点的划分依据。
  3. 递归划分子集:

    • 根据选择的最佳特征和切分点,将当前节点的数据集划分为两个子集:左子集和右子集。
    • 对每个子集递归地重复步骤 1 和步骤 2,直到达到停止条件,例如达到最大深度、节点样本数量小于预设阈值等。
  4. 停止条件:

    • 决策树构建过程中,需要设定停止条件,以防止过度拟合或无限生长。常见的停止条件包括:达到最大深度、节点样本数量小于预设阈值、节点基尼指数低于阈值等。
  5. 剪枝:

    • 在决策树生长完成后,可以应用剪枝来降低树的复杂度和提高泛化能力。剪枝的目标是通过移除部分节点或子树来减小模型的复杂度,常见的剪枝方法有预剪枝和后剪枝。
    • 预剪枝:预剪枝是在决策树构建过程中,在决策树生长过程中进行判断并提前终止树的生长。在预剪枝中,可以设置一些停止生长的条件,例如限制树的最大深度、节点中最小样本数、基尼不纯度的阈值等。当达到任何一个预设条件时,就停止分裂节点并将该节点标记为叶子节点,不再继续向下生长,从而避免过拟合。
    • 后剪枝:后剪枝是在决策树构建完成之后,对已生成的决策树进行修剪来减少过拟合。后剪枝通过剪掉一些子树或者将子树替换为叶子节点来减少树的复杂度,从而提高泛化能力。后剪枝的过程通常是自底向上地遍历决策树,然后对每个内部节点尝试剪枝,判断剪枝后的决策树性能是否提升,如果提升则进行剪枝操作。

    • 预剪枝和后剪枝各有优劣势:预剪枝可以在构建树的过程中避免过拟合,但可能会导致欠拟合,因为它在生长时就限制了树的复杂度。后剪枝在构建完整个树后进行修剪,更容易实现,但可能会由于过拟合而导致剪枝效果不佳。

  6. 预测:

    • 使用生成的决策树对新样本进行分类预测。
    • 从根节点开始,根据特征值逐步向下遍历树的分支,直到到达叶子节点,然后将叶子节点的预测值作为样本的预测结果。

样例:

一些解释:

分裂阈值是决策树算法中用来划分数据集的一个值,它决定了将数据集分成两部分的标准。在每个节点上,决策树算法会选择一个特征和一个分裂阈值,将数据集分为两部分,使得分裂后的子集尽可能地纯净(即属于同一类别)。

假设有一个二维数据集,包含两个特征和一个类别:

X1X2Y
1.02.00
2.03.00
2.02.01
3.04.01
3.03.00

在构建过程中,需要选择一个特征和一个分裂阈值来将数据集划分为左右两个子集。假设现在可以选择特征X1,并将分裂阈值设为2.5。所有X1小于2.5的样本将被划分到左子集,而X1大于等于2.5的样本将被划分到右子集。

首先,我们根据选定的特征和分裂阈值将数据集划分成两个子集。

左子集(X1 < 2.5):

X1X2Y
1.02.00
2.03.00
2.02.01

右子集(X1 >= 2.5):

X1X2Y
3.04.01
3.03.00

对于左子集:

  • 类别0的频率:p_0 = \frac{2}{3} = 0.67
  • 类别1的频率:p_1 = \frac{1}{3} = 0.33

左子集的基尼指数为:

Gini_{left} = 1 - (0.67^2 + 0.33^2)= 0.4422

对于右子集:

  • 类别0的频率:p_0 = \frac{1}{2} = 0.5
  • 类别1的频率:p_1 = \frac{1}{2} = 0.5

右子集的基尼指数为:

Gini_{right} = 1 - (0.5^2 + 0.5^2) = 0.5

计算加权基尼指数

Weighted Gini Index = \frac{3}{5} \times Gini_{left} + \frac{2}{5} \times Gini_{right}= \frac{3}{5} \times 0.4422 + \frac{2}{5} \times 0.5 = 0.66532

在这个例子中,选定特征X1和分裂阈值2.5的加权基尼指数为约0.66532。

三、代码

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_scoreclass DecisionTreeClassifier:def __init__(self, max_depth=None, min_samples_split=2):"""初始化决策树分类器参数:- max_depth: 决策树的最大深度,控制树的生长。默认为None,表示不限制深度。- min_samples_split: 内部节点再划分所需的最小样本数。默认为2。"""self.max_depth = max_depthself.min_samples_split = min_samples_splitdef fit(self, X, y):"""根据训练数据拟合模型参数:- X: 训练数据的特征数组。- y: 训练数据的标签数组。"""self.n_classes = len(np.unique(y))self.n_features = X.shape[1]self.tree_ = self._grow_tree(X, y)def _grow_tree(self, X, y, depth=0):"""递归地构建决策树参数:- X: 当前节点的特征数组。- y: 当前节点的标签数组。- depth: 当前节点的深度。"""# 计算每个类别的样本数n_samples_per_class = [np.sum(y == i) for i in range(self.n_classes)]# 预测当前节点的类别为样本数最多的类别predicted_class = np.argmax(n_samples_per_class)if depth < self.max_depth and X.shape[0] >= self.min_samples_split:best_gini = float('inf')best_feature = Nonebest_threshold = None# 遍历每个特征for feature in range(self.n_features):unique_values = np.unique(X[:, feature])# 遍历每个特征值作为分裂阈值for threshold in unique_values:y_left = y[X[:, feature] < threshold]y_right = y[X[:, feature] >= threshold]if len(y_left) == 0 or len(y_right) == 0:continue# 计算基尼不纯度gini = self._gini_impurity(y_left, y_right)# 选择最小基尼不纯度对应的特征和阈值if gini < best_gini:best_gini = ginibest_feature = featurebest_threshold = threshold# 如果存在可以降低基尼不纯度的分裂,则继续构建子树if best_gini < float('inf'):left_indices = X[:, best_feature] < best_thresholdX_left, y_left = X[left_indices], y[left_indices]X_right, y_right = X[~left_indices], y[~left_indices]left_subtree = self._grow_tree(X_left, y_left, depth + 1)right_subtree = self._grow_tree(X_right, y_right, depth + 1)return {'feature': best_feature, 'threshold': best_threshold,'left': left_subtree, 'right': right_subtree}# 当无法继续分裂时,返回当前节点的预测类别return {'predicted_class': predicted_class}def _gini_impurity(self, y_left, y_right):"""计算基尼不纯度参数:- y_left: 左子节点的标签数组。- y_right: 右子节点的标签数组。"""n_left, n_right = len(y_left), len(y_right)n_total = n_left + n_rightp_left = np.array([np.sum(y_left == c) / n_left for c in range(self.n_classes)])p_right = np.array([np.sum(y_right == c) / n_right for c in range(self.n_classes)])# 计算左右节点的基尼不纯度gini_left = 1.0 - np.sum(p_left ** 2)gini_right = 1.0 - np.sum(p_right ** 2)# 计算加权基尼不纯度gini = (n_left / n_total) * gini_left + (n_right / n_total) * gini_rightreturn ginidef predict(self, X):"""对输入数据进行预测参数:- X: 待预测数据的特征数组。返回:- 预测的标签数组。"""return np.array([self._predict(inputs) for inputs in X])def _predict(self, inputs):"""递归地预测单个样本的标签参数:- inputs: 单个样本的特征数组。返回:- 预测的标签。"""node = self.tree_while 'predicted_class' not in node:feature_value = inputs[node['feature']]# 根据特征值和阈值判断进入左子树还是右子树if feature_value < node['threshold']:node = node['left']else:node = node['right']return node['predicted_class']# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 初始化决策树分类器
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X_train, y_train)
print("test:"+str(X_test)+"\n         pre:"+str(y_test))
print("Predictions:", clf.predict(X_test))accuracy = accuracy_score(y_test, clf.predict(X_test))
print("准确率:"+str(accuracy*100)+'%')

四、结果

本次实验采用的是python自带的鸢尾花数据集,将数据集8:2分为训练集和测试集,将树的最大深度设置为2,得到的结果如下:

可以看到只有一个点出现了错误,预测的效果不错。

五、问题与解决

问题1.

  1. 过拟合:决策树容易在训练数据上过拟合,即模型过于复杂,过度拟合训练数据中的噪声或特定样本,导致在测试数据上表现不佳。欠拟合:与过拟合相反,如果决策树过于简单,可能无法捕捉数据中的复杂关系,导致在训练和测试数据上都表现不佳。

  2.  解决减缓过拟合:降低模型复杂度、增加训练数据量、使用正则化技术、特征选择等。减缓欠拟合:增加模型复杂度、添加更多特征、使用更复杂的模型等

        

 问题2.

  1. 内存消耗问题:如果数据集过大或者决策树过深,可能导致内存消耗过大,甚至导致程序崩溃或者运行缓慢。

  2. 解决:限制最大深度:通过设置决策树的最大深度来限制树的复杂度,从而减少内存消耗。限制叶子节点数量:设置叶子节点的最小样本数,以避免树过于深。使用剪枝:在树的训练过程中或者之后对树进行剪枝,去掉不必要的分支和节点。分批处理数据:将数据集划分为小批次,并逐批进行训练和预测,以减少一次性处理的内存需求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/4691.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何让Ubuntu上的MySQL开发更便捷

前言 作为一款开源的数据库开发与数据库管理协同工具&#xff0c;&#xff08;OceanBase Developer Center&#xff0c;简称ODC&#xff09;&#xff0c;针对MySQL数据源&#xff0c;已提供了涵盖SQL开发、变更风险管控、数据安全合规等多个方面的功能&#xff0c;从而为MySQL…

新媒体运营-----短视频运营-----PR视频剪辑----视频调色

新媒体运营-----短视频运营-----PR视频剪辑-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/138079659 文章目录 1. Lumetri调色&#xff0c;明暗对比度2. Lumetri调色&#xff0c;创意与矢量示波器2.1 创意2.2 矢量示波器 3. L…

视频美颜SDK与主播美颜工具的技术原理与应用场景分析

在直播视频领域中&#xff0c;视频美颜SDK和主播美颜工具发挥着至关重要的作用。本文将探讨这些工具的技术原理及其在不同应用场景中的应用。 一、视频美颜SDK的技术原理 1.1 图像处理技术 视频美颜SDK的核心技术之一是图像处理技术。根据用户设定的美颜参数进行相应的调整。…

Meta Llama 3 性能提升与推理服务部署

利用 NVIDIA TensorRT-LLM 和 NVIDIA Triton 推理服务器提升 Meta Llama 3 性能 我们很高兴地宣布 NVIDIA TensorRT-LLM 支持 Meta Llama 3 系列模型&#xff0c;从而加速和优化您的 LLM 推理性能。 您可以通过浏览器用户界面立即试用 Llama 3 8B 和 Llama 3 70B&#xff08;该…

SpringBoot 快速开始 Dubbo RPC

文章目录 SpringBoot 快速开始 Dubbo RPC下载 Nacos项目启动项目的创建创建主项目接口定义服务的创建Dubbo 服务提供者的创建服务的消费者创建 添加依赖给 Provider、Consumer 添加依赖 开始写代码定义接口在 Provider 中实现在 Consumer 里面使用创建启动类 注册中心配置启动 …

YOKOGAWA横河手操器维修hart通讯器YHC5150X-01

横河手操器设置注意事项&#xff1a;内藏指示计显示选择与单位设置 有如下 5 种显示模式及单位设置百分比显示、用户设置显示、用户设置和百分比交替显示、输入压力显示、输入压力和百分比交替显示。即应用在当没有输入时操作要求输出为20mA引压方向设置右/左侧高压&#xff0c…

Docker容器:数据管理与镜像的创建(主要基于Dockerfile)

目录 一、Docker 数据管理 1、数据卷&#xff08;Data Volumes&#xff09; 2、数据卷容器&#xff08;DataVolumes Containers&#xff09; 二、容器互联&#xff08;使用centos镜像&#xff09; 三、Docker 镜像的创建 1、基于现有镜像创建 2、基于本地模板创建 3、基…

QT Windows 实现调用Windows API获取ARP 表

简介 使用ping方式获取网络可访问或者存在的设备发现部分会无法ping通但实际网络上存在此设备&#xff0c; 但使用arp -a却可以显示出来&#xff0c; 所以现在使用windows API的方式获取arp 表。 实现 参考Windows提供的示例转化成Qt Qt .pro LIBS -liphlpapiLIBS -lws2_32…

R-Tree: 原理及实现代码

文章目录 R-Tree: 原理及实现代码1. R-Tree 原理1.1 R-Tree 概述1.2 R-Tree 结构1.3 R-Tree 插入与查询 2. R-Tree 实现代码示例&#xff08;Python&#xff09;结语 R-Tree: 原理及实现代码 R-Tree 是一种用于管理多维空间数据的数据结构&#xff0c;常用于数据库系统和地理信…

【CANoe示例分析】TCP Chat(CAPL) with TLS encription

1、工程路径 C:\Users\Public\Documents\Vector\CANoe\Sample Configurations 15.3.89\Ethernet\Simulation\TLSSimChat 在CANoe软件上也可以打开此工程:File|Help|Sample Configurations|Ethernet - Simulation of Ethernet ECUs|Basic AUTOSAR Adaptive(SOA) 2、示例目…

面试题:斐波那契数列

题目描述&#xff1a; 写一个函数,输入n,求斐波那契数列的第n项.斐波那契数列定义如下: F(0) 0 F(1) 1 F(N) F(N - 1) F(N - 2), 其中 N > 1. 解题方法&#xff1a; 算法1: 利用递归实现,这个方法效率有严重问题,时间复杂度为O(2^n) long long Fibon(int n) {if (…

微软如何打造数字零售力航母系列科普03 - Mendix是谁?作为致力于企业低代码服务平台的领头羊,它解决了哪些问题?

一、Mendix 成立的背景 Mendix的成立是为了解决软件开发中最大的问题&#xff1a;业务和IT之间的脱节。这一挑战在各个行业和地区都很普遍&#xff0c;很简单&#xff1a;业务需求通常被描述为IT无法正确解释并转化为软件。业务和IT之间缺乏协作的原因是传统的代码将开发过程限…

WPF —— MVVM 指令执行不同的任务实例

标签页 设置两个按钮&#xff0c; <Button Content"修改状态" Width"100" Height"40" Background"red"Click"Button_Click"></Button><Button Content"测试"Width"100"Height"40&…

如何让用户听话?

​福格教授&#xff08;斯坦福大学行为设计实验室创始人&#xff09;通过深入研究人类行为20年&#xff0c;2007年用自己的名子命名&#xff0c;提出了一个行为模型&#xff1a;福格行为模型。 模型表明&#xff1a;人的行为发生&#xff0c;要有做出行为的动机和完成行为的能…

web安全---xss漏洞/beef-xss基本使用

what xss漏洞----跨站脚本攻击&#xff08;Cross Site Scripting&#xff09;&#xff0c;攻击者在网页中注入恶意脚本代码&#xff0c;使受害者在浏览器中运行该脚本&#xff0c;从而达到攻击目的。 分类 反射型---最常见&#xff0c;最广泛 用户将带有恶意代码的url打开&a…

二叉树理论和题目

二叉树的种类 在我们解题过程中二叉树有两种主要的形&#xff1a;满二叉树和完全二叉树。 满二叉树 满二叉树&#xff1a;如果一棵二叉树只有度为0的结点和度为 2 的结点&#xff0c;并且度为 0 的结点在同一层上&#xff0c;则这棵二叉树为满二叉树。 这棵二叉树为满二叉树…

如何使用SOCKS5代理?

SOCKS5 是一个代理协议&#xff0c;在使用TCP/IP协议通讯的前端机器和服务器机器之间扮演一个中介角色&#xff0c;使得内部网中的前端机器变得能够访问Internet网中的服务器&#xff0c;或者使通讯更加安全。那么&#xff0c;SOCKS5代理该如何使用呢&#xff1f; 首先需要获取…

Matlab实现CNN-LSTM模型,对一维时序信号进行分类

1、利用Matlab2021b训练CNN-LSTM模型&#xff0c;对采集的一维时序信号进行分类二分类或多分类 2、CNN-LSTM时序信号多分类执行结果截图 训练进度&#xff1a; 网络分析&#xff1a; 指标变化趋势&#xff1a; 代码下载方式&#xff08;代码含数据集与模型构建&#xff0c;附…

BERT一个蛋白质-季军-英特尔创新大师杯冷冻电镜蛋白质结构建模大赛-paipai

关联比赛: “创新大师杯”冷冻电镜蛋白质结构建模大赛 解决方案 团队介绍 paipai队、取自 PAIN AI&#xff0c;核心成员如我本人IvanaXu(IvanaXu GitHub)&#xff0c;从事于金融科技业&#xff0c;面向银行信用贷款的风控、运营场景。但我们团队先后打过很多比赛&#xf…

社交媒体数据恢复:Rocket Chat

Rocket.Chat 数据恢复方法 1. 数据备份 在探讨数据恢复方法之前&#xff0c;重要的是要了解Rocket.Chat有一个自动备份功能。这个备份功能可以将你的数据定期备份到/var/snap/rocketchat-server//backup.tgz1 。如果你的Rocket.Chat服务器已经启用了这个自动备份功能&#xf…