python机器学习——决策树

决策树

# 模块导入
from sklearn.tree import ExtraTreeRegressor as ETR, DecisionTreeRegressor as DTR

ExtraTreeRegressorDecisionTreeRegressor是scikit-learn库中的两种回归模型,用于拟合和预测连续型目标变量。

决策树是一种基于树结构的机器学习算法,用于解决分类和回归问题。它通过对数据的特征进行一系列判断和分支,逐步将数据集划分成不同的子集,最终得到一个基于特征的树形结构,用于预测新数据的类别或数值。

决策树算法的核心思想是在每个节点上选择最有价值的特征进行划分,使得子节点间的纯度尽可能高,同时保持树的简单性。在分类任务中,纯度通常指子节点中样本所属类别的比例;在回归任务中,纯度通常指子节点中样本目标变量的方差或标准差

决策树在处理离散型和连续型特征时有不同的处理方式,其中最常见的方法是使用信息增益、信息增益比、基尼指数等指标进行节点划分。对于过拟合问题,可以通过剪枝、随机森林等方法进行处理。

两种树对比

DecisionTreeClassifier和DecisionTreeRegressor是决策树算法的两个变体,用于解决分类和回归问题。它们的主要区别在于所解决的问题类型和输出类型。

  1. DecisionTreeClassifier(决策树分类器):
    • 问题类型:DecisionTreeClassifier用于解决分类问题,即将样本分为不同的类别。
    • 输出类型:其输出是一个离散的类别标签,表示样本属于哪个类别。
  2. DecisionTreeRegressor(决策树回归器):
    • 问题类型:DecisionTreeRegressor用于解决回归问题,即预测连续目标变量的值。
    • 输出类型:其输出是一个连续的数值,表示预测的目标变量的值。

除了上述区别,DecisionTreeClassifier和DecisionTreeRegressor在算法实现上也有一些略微的差异:

  • 分割准则:分类树通常使用基尼系数(Gini index)或熵(entropy)来度量特征的重要性,以选择最佳的分割点。而回归树通常使用平方误差(mean squared error)作为分割准则。
  • 剪枝策略:为了防止过拟合,决策树通常需要进行剪枝。对于分类树来说,常用的剪枝策略有预剪枝和后剪枝。对于回归树来说,通常采用贪心策略进行自底向上的剪枝。

总结起来,DecisionTreeClassifier和DecisionTreeRegressor主要区别在于解决的问题类型和输出类型。前者适用于分类问题,输出离散类别标签;后者适用于回归问题,输出连续数值。

决策树中的专业名词

  • 节点(Node):表示决策树上的一个数据处理单元,包含一个或多个子节点和一个父节点。
  • 根节点(Root Node):表示决策树的起始节点,没有父节点。
  • 叶节点(Leaf Node):表示决策树的终止节点,没有子节点。
  • 内部节点(Internal Node):表示除根节点和叶节点外的其他节点,拥有一个或多个子节点。
  • 特征(Feature):表示决策树划分节点时使用的属性或特征值,可以是离散型或连续型。
  • 阈值(Threshold):表示用于划分连续型特征的阈值,通常是根据特征值的中位数或平均值确定的。
  • 深度(Depth):表示决策树从根节点到某个节点的路径长度,根节点的深度为0。
  • 路径(Path):表示从根节点到叶节点的一条路径,由一系列节点和边组成。
  • 分支(Branch):表示从一个节点到其子节点的一条边。
  • 剪枝(Pruning):表示对决策树进行修剪,以防止模型过拟合。常用的剪枝方法有预剪枝(Pre-Pruning)和后剪枝(Post-Pruning)。
  • 信息增益(Information Gain):表示在某个节点上划分前后数据集的信息熵差异,用于选择最佳划分特征。
  • 基尼指数(Gini Index):表示在某个节点上划分前后数据集的基尼系数差异,用于选择最佳划分特征。

DecisionTreeRegressor

导入模块

from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor

创建模型对象

dtr = DecisionTreeRegressor(max_depth=None, criterion='mse', splitter='best', random_state=None)

参数说明

  • max_depth:决策树的最大深度,默认为None,表示不限制深度。
  • criterion:节点划分的标准,可选’mse’(均方误差)或’mae’(平均绝对误差),默认为’mse’。
  • splitter:节点划分的策略,可选’best’(最优划分)或’random’(随机划分),默认为’best’。
  • random_state:随机种子,用于重复实验。

拟合模型

dtr.fit(X,y)

X是一个二维数组或者数据框,其中每一行代表一个样本,每一列表示一个特征

y是目标变量向量,是一个一维数组或列表,其中每个元素表示一个样本的目标值

预测

y_pred_dtr = dtr.predict(X_test)

其中X_test是我们待预测的新特征矩阵

示例代码

数据分为训练集和测试集

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

train_test_split(X, y, test_size=0.2, random_state=42) 是一个常用的函数调用,用于将数据集 X 和对应的目标变量 y 划分为训练集和测试集。具体解释如下:

  • X:表示样本特征矩阵,其中每一行代表一个样本,每一列代表一个特征。
  • y:表示目标变量(或标签),是与样本特征矩阵 X 对应的目标值。
  • test_size=0.2:表示将数据集划分为训练集和测试集时,测试集的大小为全部数据的 20%。
  • random_state=42:表示设置随机数种子为 42,用于控制随机划分的重现性。

该函数会返回划分后的训练集和测试集,以便后续在机器学习模型中使用。具体返回结果会有以下四个元组:

  • X_train:表示划分后的训练集样本特征。
  • X_test:表示划分后的测试集样本特征。
  • y_train:表示划分后的训练集目标变量。
  • y_test:表示划分后的测试集目标变量。

通过这个函数可以确保训练集和测试集的划分是随机的,并且可以重复该划分过程。同时,通过指定随机数种子,可以使得每次运行时得到相同的划分结果,以保持实验的可重现性。

具体代码

from sklearn.datasets import load_diabetes # 导入糖尿病数据集
from sklearn.model_selection import train_test_split # 将数据划分为训练集和测试集
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error # 计算均方误差# 加载糖尿病数据集
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# X_train, X_test, y_train, y_test 分别是 训练集样本特征矩阵 测试集样本特征矩阵 训练集目标向量 测试集目标向量# 创建决策树回归模型dtr = DecisionTreeRegressor(max_depth=5, random_state=42)# 训练模型dtr.fit(X_train, y_train)# 预测测试集y_pred_dtr = dtr.predict(X_test)# 评估模型性能mse = mean_squared_error(y_test, y_pred_dtr)
print("均方误差 (MSE):", mse) # 3600 均方误差较大,需要改进模型

均方误差

均方误差(Mean Squared Error,MSE)是一种常用的回归模型评估指标。它用于衡量模型预测结果与真实值之间的差异程度,具体计算方式如下:

MSE = (1/n) * Σ(yᵢ - ŷᵢ)²

其中,n 是样本数量,yᵢ 是真实值,ŷᵢ 是模型的预测值。

MSE 的计算方法是将每个样本的预测误差平方后求和,再除以样本数量。因为误差被平方,所以 MSE 比较敏感,较大的误差会被放大,而较小的误差则相对较小。

对于 MSE 来说,**数值越小表示模型的预测结果与真实值之间的差异越小,模型的拟合能力越好。**当 **MSE 为0时,表示模型完全拟合了训练数据,但这可能意味着模型过于复杂,存在过拟合的风险。**通常情况下,我们希望选择一个使得 MSE 较小且在训练集和测试集上表现一致的模型。

需要注意的是,MSE 的值与数据集的单位相关,因此无法直接进行跨数据集的比较。在评估模型时,可以将 MSE 与其他模型的 MSE 进行比较,或者将其与问题的背景和要求相结合来进行评估,例如与实际误差的大小进行比较或与领域专家的知识相结合。

模型优化

要修改决策树的参数、进行剪枝以及使用基尼系数进行划分,使得模型更加优化,通常需要使用机器学习库来实现

首先,我们导入所需的库和数据集:

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

加载鸢尾花数据集并将其分为训练集和测试集:

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)

创建决策树分类器对象,并设置参数:

clf = DecisionTreeClassifier(criterion='gini', max_depth=None, random_state=42)

其中,criterion参数设置了用于划分的准则,这里选择了基尼系数(gini index)。max_depth参数表示树的最大深度,设置为None表示不限制深度。random_state参数用于确定每次运行时的随机性,以便结果可重复。

拟合(训练)决策树分类器:

clf.fit(X_train, y_train)

使用训练好的模型进行预测:

y_pred = clf.predict(X_test)

计算预测的准确率:

accuracy = accuracy_score(y_test, y_pred)
print("准确率:", accuracy)

接下来,我们可以进行剪枝。

首先,我们可以使用预剪枝设置max_depth参数限制树的最大深度:

clf_pruned = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)

然后,重复之前的拟合、预测和准确率计算过程:

clf_pruned.fit(X_train, y_train)
y_pred_pruned = clf_pruned.predict(X_test)
accuracy_pruned = accuracy_score(y_test, y_pred_pruned)
print("剪枝后的准确率:", accuracy_pruned)

最后,基于基尼系数的划分是决策树算法的默认选择,所以不需要额外的代码来设置它。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/133243.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VLAN与配置

VLAN与配置 什么是VLAN 以最简单的形式为例。如下图,此时有4台主机处于同一局域网中,很明显这4台主机是能够直接通讯。但此时我需要让处于同一局域网中的PC3和PC4能通讯,PC5和PC6能通讯,并且PC3和PC4不能与PC5和PC6通讯。 为了实…

笔记本电脑 禁用/启用 自带键盘

现在无论办公还是生活 很多人都会选择笔记本电脑 但很多人喜欢机械键盘 或者 用一些外接键盘 但是很多时候我们想操作 会碰到笔记本原来的键盘导致错误操作 那么 我们就需要将笔记本原来的键盘禁用掉 我们先以管理员身份运行命令窗口 然后 有两个命令 禁用默认键盘 sc conf…

你犯过程序员容易犯的这些错误吗?快来看看!

一、前言 写了20多年代码,我见过不下于4位数的程序员,我觉得程序员的能力水平可以分为4个阶段:线性级、逻辑级、架构级和工程级。 同样的在这些人当中,我也发现了8个程序员最常见的陋习,基本上可以覆盖90%的人&#…

GPT学习笔记

百度的文心一言 阿里的通义千问 通过GPT能力,提升用户体验和产品力 GPT的出现是AI的iPhone时刻。2007年1月9日,第一代iPhone发布,开启移动互联网时代。新一轮的产业革命。 GPT模型发展时间线: Copilot - 副驾驶 应用&#xf…

大数据毕业设计选题推荐-家具公司运营数据分析平台-Hadoop-Spark-Hive

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

在Windows或Mac上安装并运行LLAMA2

LLAMA2在不同系统上运行的结果 LLAMA2 在windows 上运行的结果 LLAMA2 在Mac上运行的结果 安装Llama2的不同方法 方法一: 编译 llama.cpp 克隆 llama.cpp git clone https://github.com/ggerganov/llama.cpp.git 通过conda 创建或者venv. 下面是通过conda 创建…

我的崽崽跑着跑就长大了

一瞬间感觉你都长这么大了,看着你骑单车的背影,不知不觉心里感觉到有点酸酸的,回头想想看着你,一个人带你在累,在苦都值得,萌娃骑车。 你的可爱能治愈我的一切不快乐。

大数据商城人流数据分析与可视化 - python 大数据分析 计算机竞赛

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 基于大数据的基站数据分析与可视化 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🥇学长这里给一个题目综合评分(每项满分5分) 难度…

今天零九的雪

下雪了, 漫天大雪。 好大的雪。 远处的鸡鸣, 在一个高嗓门的公鸡的带领下, 应该是醒了。 苍茫的天空下, 白蒙蒙的雾笼罩着山坳, 挡住了远处的山峰, 哦,那不是雾,是雪。 初春的微…

Linux内核移植之主频设置

一. Linux内核移植 正点原子 ALPHA开发板已经添加到 Linux内核里面去了,前面文章关于如何添加已经掌握。但是,还有一些驱动的问题需要修改。 正点原子 I.MX6U-ALPHA 开发板所使用的 I.MX6ULL 芯片主频都是 792MHz 的,也就是NXP 官方宣…

[pytorch]手动构建一个神经网络并且训练

0.写在前面 上一篇博客全都是说明类型的,实际代码能不能跑起来两说,谨慎观看.本文中直接使用fashions数据实现softmax的简单训练并且完成结果输出.实现一个预测并且观测到输出结果. 并且更重要的是,在这里对一些训练的过程,数据的形式,以及我们在softmax中主要做什么以及怎么…

Linux 指令心法(十五)`flash_eraseall` 擦除整个Flash存储器

文章目录 flash_eraseall作用flash_eraseall命令的主要特点和使用场景flash_eraseall命令应用方法flash_eraseall命令可以解决哪些问题?flash_eraseall命令使用时注意事项 flash_eraseall作用 这个命令可以擦除整个Flash存储器,将所有数据清除为初始状态。使用这个…

19.7 Boost Asio 传输序列化数据

序列化和反序列化是指将数据结构或对象转换为一组字节,以便在需要时可以将其存储在磁盘上或通过网络传输,并且可以在需要时重新创建原始对象或数据结构。 序列化是将内存中的对象转换为字节的过程。在序列化期间,对象的状态被编码为一组字节…

数字化转型:云表低代码开发助力制造业腾飞

数字化转型已成为制造业不可避免的趋势。为了应对市场快速变化、提高运营效率以及降低成本,制造业企业积极追求更加智能化、敏捷的生产方式。在这个转型过程中,低代码技术作为一种强大的工具,正逐渐崭露头角,有望加速制造业的数字…

rust持续学习 raw pointer 1

c里头 float a 1; float* p &a; 是可以直接 int * p1 (int*)p; 来强转类型做一些事情的经过了解 rust里是这么操作的 unsafe {std::mem::transmute::<origin_type, target_type>(raw_bytes) };比如上面是四个u8,可以拼一个u32 然后是函数指针这个东西 fn foo()…

Java设计模式——策略模式

1.策略模式简介 策略模式&#xff1a;策略模式是一种行为型模式&#xff0c;它将对象和行为分开&#xff0c;将行为定义为 一个行为接口 和 具体行为的实现。策略模式最大的特点是行为的变化&#xff0c;行为之间可以相互替换。每个if判断都可以理解为就是一个策略。本模式使得…

WindowsServer2019-搭建FTP服务器

这里写自定义目录标题 一、基础配置IP地址安装FTP服务检查连通性Windows10连接FTP服务 二、了解和使用FTP具体模块及其配置1、FTP IP地址和域限制2、FTP SSL设置3、FTP当前会话4、FTP防火墙5、FTP目录浏览6、FTP请求筛选7、FTP日志8、FTP身份验证9、FTP授权规则10、FTP消息11、…

useContext本身并不能直接向下传递方法,

useContext本身并不能直接向下传递方法&#xff0c;但可以通过将其包装在自定义 hook 中来实现。 例如&#xff0c;假设你有一个 context 叫做 MyContext&#xff0c;其中包含一个函数叫做 myFunction。你可以创建一个新的 hook 来暴露这个函数&#xff1a; jsx import { use…

云栖大会72小时沉浸式精彩回顾

计算&#xff0c;为了无法计算的价值 2023 杭州云栖大会震撼落幕 自2015年&#xff0c;云计算支撑着移动互联网创新 AI时代&#xff0c;继续支撑所有开发者的创新与梦想 当大会主题再次回归 让我们也打开时空隧道 一起回顾72小时云栖之旅 打造一朵AI时代最开放的云 随着…

C++中不允许复制的类

C中不允许复制的类 假设您需要模拟国家的政体。一个国家只能有一位总统&#xff0c;而 President 类面临如下风险&#xff1a; President ourPresident; DoSomething(ourPresident); // duplicate created in passing by value President clone; clone ourPresident; // dup…