超简单白话文机器学习 - 回归树树剪枝(含算法介绍,公式,源代码实现以及调包实现)

1. 回归树

1.1 算法介绍

大家看到这篇文章时想必已经对树这个概念已经有基础了,如果不是很了解的朋友可以看看笔者的这篇文章:

超简单白话文机器学习-决策树算法全解(含算法介绍,公式,源代码实现以及调包实现)_白话决策树-CSDN博客

对于回归树的建立,我们一般使用CART回归树,CART(Classification and Regression Trees)回归树是一种用于连续值预测的树模型。它通过递归地分裂数据集,以最小化预测误差为目标,最终生成一棵树结构的模型。

CART回归树的构建核心是选择最佳分裂点通过计算MSE进行衡量。

1. 选择最佳分裂点,对每个特征尝试所有的分裂点,计算分裂后各个数据集的均方误差。

2. 计算分裂前后的总MSE:

其中,n为总样本数,各分子分别是左子节点和右子节点的样本数。

3. 递归分裂,对每个子节点重复上述步骤直到满足停止条件(例如达到最大深度或叶节点中的样本数少于阈值)

获得最佳划分特征之后,需要确定分裂节点的阈值,需要最小化目标函数

1. 首先对于最佳划分特征中的数值进行迭代。

2. 对于该特征特定数值进行分裂的样本进行错误率的计算。

3. 汇总后选择错误率最小的数值作为阈值选择。

2. 树剪枝概述

2.1 预剪枝

2.1.1 算法

预剪枝的核心是在生成决策树的过程中提前停止树的增长。计算当前的划分是否能带来模型泛化能力的提升,如果不能,则不再继续生长子树。

有如下几种方法:

( 1 )当树到达一定深度的时候,停止树的生长。
( 2 )当到达当前结点的样本数量小于某个阈值的时候,停止树的生长。
( 3 )计算每次分裂对测试集的准确度提升,当小于某个阈值的时候 ,不再继续扩展。

2.2 后剪枝

2.2.1 算法

首先我们先讲后剪枝的伪代码用口水话进行呈现:

基于已有的树切分测试数据:

1. 如果存在任一子集是一棵树,则在该子集递归剪枝过程

2. 计算将当前两个叶节点合并后的误差

3. 计算不合并的误差

4. 如果合并可以降低误差,就合并

剪枝策略:

如果剪枝后的叶节点误差小于或等于未剪枝子树的误差,则进行剪枝,即将该内部节点变为叶节点。继续评估和剪枝树中的其他节点,直到不再有可以进一步剪枝的节点。

误差的衡量方式有多种,回归树的误差衡量我们一般选择MSE。

2.2.2 代价复杂度剪枝

前文我们已经讲了,防止过拟合的方法之一时,对决策树进行剪枝,即减少树的分支。 剪枝防止过拟合使得在测试集上的表现更好。

将公式呈现在这里:

让我们用白话文转化一下这个公式:

评价一棵树的得分由两部分组成,第一部分为SSR,一种预测错误率的衡量方式。第二部分代表决策树T的叶子结点个数,阿尔法是自定义指数,需要通过交叉验证的方式得到最佳参数,不同的参数影响最终所生成的树。

举个例子:

对于这四个树我们取得了他们总体的SSR值,假设我们的参数值为1000,计算树的得分。

选取得分最小的树作为我们的预测模型,即第一棵树拥有四个叶子节点。改变参数值会选择不同的预测模型,让我们计算在什么参数值下会分别指向哪一棵树。

在不同参数值的条件下,我们使用测试集迭代进行交叉验证,根据测试集最后的得分我们选择最佳参数作为判断标准,最终构造我们的预测树模型。

3. 手写代码实现

3.1 回归树

def regLeaf(dataset):return np.mean(dataset[:,-1]) #得到叶结点,目标变量的均值def regErr(dataset):return np.var(dataset[:,-1]) * np.shape(dataset)[0] #返回的是总方差def chooseBestSplit(dataset,leafType=regLeaf,errType=regErr,ops=(1,4)):tols = ops[0];tolN = ops[1] #tols是容许的误差下降值, yolN是切分的最少样本数if len(set(dataset[:,-1].T.tolist()[0])) == 1: #如果剩余特征为1return None,leafType(dataset) #直接返回叶子结点m,n = np.shape(dataset)S = errType(dataset) #数据集的总误差bestS = 100000; bestIndex=0;bestvalue = 0for featIndex in range(n-1):for splitVal in set(dataset[:,featIndex]): #对于某特征不同值的集合进行迭代mat0,mat1 = binSplitDataset(dataset,feat,splitVal)if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue #如果不满足最少切分样本树newS = errType(mat0) + errorType(mat1) #返回数据集的总方差if newS < bestS: #选择总方差最少的数据分类方式bestIndex = featbestvalue = splitValbestS = newSif (S - bestS) < tols: #如果小于要求的误差下降值,则直接返回叶子结点return None,leafType(dataset)mat0,mat1 = binSplitDataset(dataset,bestIndex,bestvalue)if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):return None, leafType(dataset)return bestIndex,bestValue

4. 调包实现

4.1 预剪枝

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score# 加载数据集
data = load_iris()
X = data.data
y = data.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 设置预剪枝条件
max_depth = 3  # 限制树的最大深度
min_samples_split = 4  # 分裂一个内部节点所需的最小样本数
min_samples_leaf = 2  # 叶节点所需的最小样本数# 初始化并训练决策树分类器
clf = DecisionTreeClassifier(random_state=42, max_depth=max_depth, min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf)clf.fit(X_train, y_train)# 预测并评估模型性能
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)print(f'预剪枝条件下的决策树分类器准确率: {accuracy:.4f}')# 可视化决策树(需要graphviz支持)
from sklearn.tree import export_graphviz
import graphvizdot_data = export_graphviz(clf, out_file=None, feature_names=data.feature_names,  class_names=data.target_names,  filled=True, rounded=True,  special_characters=True)  
graph = graphviz.Source(dot_data)  
graph.render("iris_prepruned_tree")  # 将树保存为PDF文件
graph  # 在Jupyter Notebook中显示决策树

4.2 后剪枝

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, cross_val_score
import matplotlib.pyplot as plt# 示例数据集
X = np.array([[2.7, 2.5], [1.3, 1.5], [3.2, 2.8], [3.8, 2.5], [2.9, 2.4],[6.5, 3.1], [7.1, 3.4], [6.0, 2.9], [7.6, 3.2], [6.3, 3.0]])
y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 生成完整的决策树
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)# 获取剪枝路径
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities# 遍历不同的剪枝参数,选择最佳剪枝
clfs = []
for ccp_alpha in ccp_alphas:clf = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)clf.fit(X_train, y_train)clfs.append(clf)# 交叉验证选择最佳剪枝参数
alpha_scores = [cross_val_score(clf, X_train, y_train, cv=2).mean() for clf in clfs]
best_clf = clfs[np.argmax(alpha_scores)]# 在测试集上评估最佳模型
test_score = best_clf.score(X_test, y_test)
print(f'Best alpha: {ccp_alphas[np.argmax(alpha_scores)]}')
print(f'Test set score: {test_score}')# 可视化剪枝路径
plt.figure(figsize=(10, 6))
plt.plot(ccp_alphas, alpha_scores, marker='o', drawstyle='steps-post')
plt.xlabel('Alpha')
plt.ylabel('Cross-validated accuracy')
plt.title('Alpha vs Cross-validated accuracy')
plt.show()

5. 剪枝的优点与局限性

5.1 预剪枝

5.1.1 优点

提高可解释性:便于理解。

减少计算复杂度:在构建树的过程中提前停止分裂,减少模型训练时间和计算资源的消耗。

防止过拟合:限制树的复杂度,提高模型的泛化能力。

5.1.2 局限性

次优决策:在树构建过程中基于局部信息作出决策,可能忽略了更深层次的潜在有用分裂。

信息丢失:某些潜在的重要特征和信息可能未能充分利用,导致模型的表达能力有限。

难以处理复杂模式:简单树结构可能无法捕捉复杂的决策边界,从而影响分类或回归的精度。

5.2 后剪枝

5.2.1 优点

后剪枝比预剪枝保留了更多的分支, 欠拟合风险小 , 泛化性能往往优于预剪枝决策树

5.2.2 局限性

训练时间开销大 :后剪枝过程是在生成完全决策树 之后进行的,需要自底向上对所有非叶结点逐一计算

6. 应用前景

1. 医疗保健:

-疾病预测:回归树用于疾病的发生概率,基于病患的历史数据和体检报告进行精准预测

-治疗效果评测:预测不同治疗方案的效果,帮助医生制定个性化的治疗计划

2. 环境科学:

-气象预测:用于预测天气变化趋势,例如温度,降水量等

-环境监测:监测和预测空气质量,水质等环境指标

...

6. 参考资料

https://www.cnblogs.com/wuliytTaotao/p/10724118.html

机器学习-预剪枝和后剪枝-CSDN博客

回归树剪枝:代价复杂度剪枝

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15194.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BL121DT网关在智能电网分布式能源管理中的应用钡铼技术协议网关

随着全球能源结构的转型和智能电网技术的飞速发展&#xff0c;分布式能源管理系统在提高能源利用效率、促进可再生能源接入及保障电网稳定运行方面发挥着日益重要的作用。然而&#xff0c;分布式能源系统内设备种类繁多&#xff0c;通信协议各异&#xff0c;如何高效整合这些设…

如何从http免费升级到https

使用https协议开头是为了在用户访问网站时提供更安全的网络环境。相比http&#xff0c;使用https有数据加密、身份验证、保护隐私、搜索引擎优化等优势。一般获取https证书&#xff0c;则需要支付费用给证书颁发机构&#xff08;CA&#xff09;。还有一些免费的证书证书颁发机构…

解决 SpringBoot 的 Date、LocalDateTime 变成时间戳和数组的问题,创建自定义对象消息转换器

问题描述 SpringBoot 项目&#xff0c;当返回前端的数据类型为 Map 的时候&#xff0c;在 Map 中 put() 时间对象会出现以下问题&#xff1a; 传递的 Date 对象会变成时间戳传递的 LocalDateTime 对象会变成数组 问题复现 编写一个 Controller 方法&#xff0c;返回值为 Ma…

Java并发: 基于Unsafe的CAS实现无锁数据结构

在上一篇Java并发: 面临的挑战文章中说过CAS是解决原子性问题的方案之一。Unsafe提供了CAS的支持&#xff0c;支持实例化对象、访问私有属性、堆外内存访问、线程的启停等功能。 许多Java的并发类库都是基于Unsafe实现的&#xff0c;比如原子类AtomicInteger&#xff0c;并发数…

多线程(C++11)

多线程&#xff08;C&#xff09; 文章目录 多线程&#xff08;C&#xff09;前言一、std::thread类1.线程的创建1.1构造函数1.2代码演示 2.公共成员函数2.1 get_id()2.2 join()2.3 detach()2.4 joinable()2.5 operator 3.静态函数4.类的成员函数作为子线程的任务函数 二、call…

【Linux学习】深入探索进程等待与进程退出码和退出信号

文章目录 退出码return退出 进程的等待进程等待的方法 退出码 main函数的返回值&#xff1a;进程的退出码。 一般为0表示成功&#xff0c;非0表示失败。 每一个非0退出码都表示一个失败的原因&#xff1b; echo $&#xff1f;命令 作用&#xff1a;查看进程退出码。&#xf…

I.MX6ULL Linux C语言开发环境搭建(点灯实验)

系列文章目录 I.MX6ULL Linux C语言开发 I.MX6ULL Linux C语言开发 系列文章目录一、前言二、硬件原理分析三、构建步骤一、 C语言运行环境构建二、软件编写三、链接脚本 四、实验程序编写五、编译下载验证 一、前言 汇编语言编写 LED 灯实验&#xff0c;但是实际开发过程中汇…

Go语言的内存泄漏如何检测和避免?

文章目录 Go语言内存泄漏的检测与避免一、内存泄漏的检测1. 使用性能分析工具2. 使用内存泄漏检测工具3. 代码审查与测试 二、内存泄漏的避免1. 使用defer关键字2. 使用垃圾回收机制3. 避免循环引用4. 使用缓冲池 Go语言内存泄漏的检测与避免 在Go语言开发中&#xff0c;内存泄…

【已解决】C#设置Halcon显示区域Region的颜色

前言 在开发过程中&#xff0c;突然发现我需要显示的筛选区域的颜色是白色的&#xff0c;如下图示&#xff0c;这对我们来说不明显会导致我的二值化筛选的时候存在误差&#xff0c;因此我们需要更换成红色显示这样的话就可以更加的明显&#xff0c;二值化筛选更加的准确。 解…

java: 无法访问org.springframework.ldap.core.LdapTemplate

完整错误&#xff1a; java: 无法访问org.springframework.ldap.core.LdapTemplate错误的类文件: /E:/apache-maven-3.6.3/repository/org/springframework/ldap/spring-ldap-core/3.2.3/spring-ldap-core-3.2.3.jar!/org/springframework/ldap/core/LdapTemplate.class类文件具…

《2024年中国机器人行业投融资报告》| 附下载

近年来&#xff0c;国内机器人行业取得了显著的技术进步&#xff0c;包括人工智能、感知技术、自主导航等技术方面的突破&#xff0c;使得机器人能够更好地适应复杂环境和任务需求&#xff0c;带动了机器人行业加快发展。 当然&#xff0c;技术的进步是外在驱动因素&#xff0…

探索集合python(Set)的神秘面纱:它与字典有何不同?

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、集合&#xff08;Set&#xff09;与字典&#xff08;Dictionary&#xff09;的初识 1. …

L2-038 病毒溯源

详解代码 #include <iostream> #include <cstring> #include <algorithm>using namespace std;const int N 10010,M10010;int n; int h[N], e[M], ne[M], idx;//邻接表,h表示顶点&#xff0c;e表示当前边的终点&#xff0c;ne表示下一条边&#xff0c;idx当…

海外动态IP代理如何提高效率?

动态住宅IP代理之所以能够有效提升数据爬取的效率和准确性&#xff0c;主要归功于其提供的IP地址具有高度的匿名性和真实性。这些IP地址来自于真实的用户网络&#xff0c;因此相比于数据中心IP&#xff0c;它们更不容易被网站的安全系统标识为爬虫。此外&#xff0c;由于IP地址…

【vue-1】vue入门—创建一个vue应用

最近在闲暇时间想学习一下前端框架vue&#xff0c;主要参考以下两个学习资料。 官网 快速上手 | Vue.js b站学习视频 2.创建一个Vue3应用_哔哩哔哩_bilibili 一、创建一个vue3应用 <!DOCTYPE html> <html lang"en"> <head><meta charset&q…

NodeJS安装并生成Vue脚手架(保姆级)

文章目录 NodeJS下载配置环境变量Vue脚手架生成Vue脚手架创建项目Vue项目绑定git 更多相关内容可查看 NodeJS下载 下载地址&#xff1a;https://nodejs.org/en 下载的速度应该很快&#xff0c;下载完可以无脑安装&#xff0c;以下记得勾选即可 注意要记住自己的安装路径&…

【Linux】简单模拟C语言文件标准库FILE

&#x1f466;个人主页&#xff1a;Weraphael ✍&#x1f3fb;作者简介&#xff1a;目前正在学习c和算法 ✈️专栏&#xff1a;Linux &#x1f40b; 希望大家多多支持&#xff0c;咱一起进步&#xff01;&#x1f601; 如果文章有啥瑕疵&#xff0c;希望大佬指点一二 如果文章对…

R可视化:可发表的Y轴截断图

Y轴截断图by ggprism Y轴截断图by ggprism 介绍 ggplot2绘制Y轴截断图by ggprism加载R包 knitr::opts_chunk$set(message = FALSE, warning = FALSE)library(tidyverse) library(ggprism) library(patchwork)rm(list = ls()) options(stringsAsFactors = F) options(future.…

Go语言的中间件(middleware)是如何实现的?

文章目录 Go语言的中间件&#xff08;Middleware&#xff09;是如何实现的&#xff1f;中间件的工作原理中间件的实现步骤示例代码总结 Go语言的中间件&#xff08;Middleware&#xff09;是如何实现的&#xff1f; 在Go语言中&#xff0c;中间件&#xff08;Middleware&#…

springboot实现多开发环境匹配置(超级简洁没废话)

首先logbok-spring.xml里面的内容 <?xml version"1.0" encoding"UTF-8"?> <configuration><!-- 开发、测试环境 --><springProfile name"dev,test"><include resource"org/springframework/boot/logging/log…