【机器学习】回归树

回归树是一种用于数值型目标变量的监督学习算法,通过将特征空间划分为多个区域,并在每个区域内使用简单的预测模型(如区域均值)来进行回归。回归树以“递归划分-计算区域均值”的方式逐层生成树节点,最终形成叶节点预测值。相比于线性回归,回归树更适合处理非线性和复杂数据结构。

回归树的基本原理

在回归树中,每个节点执行以下操作:

  • 选择最优特征及分割点:通过最小化均方误差(Mean Squared Error, MSE)等标准选择最佳分割特征和分割点。
  • 分割数据:根据选择的分割特征将数据划分成两部分,形成左子节点和右子节点。
  • 递归分割:对子节点进行递归分割,直至满足停止条件(如最大深度或最小样本数)。

分割准则

均方误差(MSE)

在回归树中,常用均方误差(MSE)作为分割准则:
MSE = 1 N ∑ i = 1 N ( y i − y ˉ ) 2 \text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \bar{y})^2 MSE=N1i=1N(yiyˉ)2
其中,( y_i ) 是样本 ( i ) 的实际值,( \bar{y} ) 是区域内样本的平均值。分割点选择通过最小化分割前后数据的 MSE 来完成。

回归树的构建步骤

  1. 选择最佳分割特征与分割点:遍历每个特征和可能的分割点,计算分割后的MSE,选择使MSE最小的分割特征和点。
  2. 递归分割数据:在左、右子节点递归执行上述过程,形成新的分支节点。
  3. 生成叶节点:一旦满足停止条件,将当前节点的预测值设为该区域中所有样本的均值。

用 Numpy 实现回归树

以下代码展示了如何用 Numpy 实现一个基本的回归树,并通过均方误差来确定分割点。

import numpy as np
import matplotlib.pyplot as plt# 计算均方误差(MSE)
def mean_squared_error(y):return np.var(y) * len(y)# 数据集分割
def split_dataset(X, y, feature, threshold):left_mask = X[:, feature] <= thresholdright_mask = ~left_maskreturn X[left_mask], y[left_mask], X[right_mask], y[right_mask]# 查找最佳分割特征和分割点
def best_split(X, y):best_mse = float("inf")best_feature, best_threshold = None, Nonefor feature in range(X.shape[1]):thresholds = np.unique(X[:, feature])for threshold in thresholds:_, y_left, _, y_right = split_dataset(X, y, feature, threshold)if len(y_left) == 0 or len(y_right) == 0:continuemse_split = mean_squared_error(y_left) + mean_squared_error(y_right)if mse_split < best_mse:best_mse = mse_splitbest_feature = featurebest_threshold = thresholdreturn best_feature, best_threshold# 回归树类
class RegressionTree:def __init__(self, max_depth=3, min_samples_split=2):self.max_depth = max_depthself.min_samples_split = min_samples_splitself.tree = Nonedef fit(self, X, y, depth=0):if len(y) < self.min_samples_split or depth >= self.max_depth:return np.mean(y)feature, threshold = best_split(X, y)if feature is None:return np.mean(y)left_X, left_y, right_X, right_y = split_dataset(X, y, feature, threshold)left_node = self.fit(left_X, left_y, depth + 1)right_node = self.fit(right_X, right_y, depth + 1)self.tree = {"feature": feature, "threshold": threshold, "left": left_node, "right": right_node}return self.treedef predict_sample(self, x, tree):if not isinstance(tree, dict):return treeif x[tree["feature"]] <= tree["threshold"]:return self.predict_sample(x, tree["left"])else:return self.predict_sample(x, tree["right"])def predict(self, X):return np.array([self.predict_sample(x, self.tree) for x in X])# 生成示例数据
np.random.seed(0)
X = np.random.rand(100, 1) * 10  # 特征数据
y = 2 * X.flatten() + np.random.randn(100) * 2  # 标签数据# 训练回归树
tree = RegressionTree(max_depth=4, min_samples_split=5)
tree.fit(X, y)# 预测并可视化
X_test = np.linspace(0, 10, 100).reshape(-1, 1)
y_pred = tree.predict(X_test)plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred, color="red", label="回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("回归树预测示意图")
plt.legend()
plt.show()

在代码中,我们首先通过遍历各个特征和分割点来选择最优分割点,使得均方误差最小。然后在每个节点递归进行分割,直至达到设定的深度或最小样本数。最终通过构建的树结构进行预测。

使用 Sklearn 的回归树

Scikit-Learn 提供了 DecisionTreeRegressor 来实现回归树模型,可以大大简化建模过程。

from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error# 训练回归树
regressor = DecisionTreeRegressor(max_depth=4, min_samples_split=5)
regressor.fit(X, y)# 预测
y_pred_sklearn = regressor.predict(X_test)# 计算均方误差
mse = mean_squared_error(y, regressor.predict(X))
print("均方误差:", mse)# 可视化
plt.scatter(X, y, color="blue", label="训练数据")
plt.plot(X_test, y_pred_sklearn, color="red", label="Sklearn 回归树预测")
plt.xlabel("特征")
plt.ylabel("目标值")
plt.title("Sklearn 回归树预测示意图")
plt.legend()
plt.show()

总结

本文介绍了回归树的基本概念与实现,包括回归树的分割准则、MSE 计算、最佳分割点选择等细节。通过 Numpy 手动实现了一个简单的回归树模型,并展示了如何在 Scikit-Learn 中快速实现和使用回归树。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/884538.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity humanoid 模型头发动画失效问题

在上一篇【Unity实战笔记】第二十二 提到humanoid 模型会使原先的头发动画失效&#xff0c;如下图所示&#xff1a; 头发摆动的是generic模型和动画&#xff0c;不动的是humanoid模型和动画 一开始我是尝试过在模型Optimize Game objects手动添加缺失的头发骨骼的&#xff0c;奈…

基于MATLAB的战术手势识别

手势识别的研究起步于20世纪末&#xff0c;由于计算机技术的发展&#xff0c;特别是近年来虚拟现实技术的发展&#xff0c;手势识别的研究也到达一个新的高度。熵分析法是韩国的李金石、李振恩等人通过从背景复杂的视频数据中分割出人的手势形状&#xff0c;然后计算手型的质心…

CSS学习之Grid网格布局基本概念、容器属性

网格布局 网格布局&#xff08;Grid&#xff09;是将网页划分成一个个网格单元&#xff0c;可任意组合不同的网格&#xff0c;轻松实现各种布局效果&#xff0c;也是目前CSS中最强大布局方案&#xff0c;比Flex更强大。 基本概念 容器和项目 当一个 HTML 元素将 display 属性…

Yelp 数据集进行用户画像, 使用聚类做推荐

使用 Yelp 数据集进行用户画像&#xff08;User Profiling&#xff09;是一项有趣的任务&#xff0c;可以理解用户的偏好、行为和特征。以下是总结的一个基本的步骤&#xff0c;帮助构建用户画像 pandas 加载数据&#xff1a; import pandas as pd# 加载数据 users pd.read_…

JAVA题目笔记(十) 带有继承结构的JavaBean类

一、创建带有继承结构的标准JavaBean类(1) public class Worker {private String name;private int workid;private int salary;public Worker(){}public Worker(String name,int workid,int payment){this.namename;this.salarypayment;this.workidworkid;}public void eat(){…

keepalive+mysql8双主

1.概述 利用keepalived实现Mysql数据库的高可用&#xff0c;KeepalivedMysql双主来实现MYSQL-HA&#xff0c;我们必须保证两台Mysql数据库的数据完全一致&#xff0c;实现方法是两台Mysql互为主从关系&#xff0c;通过keepalived配置VIP&#xff0c;实现当其中的一台Mysql数据库…

【C++笔记】容器适配器及deque和仿函数

【C笔记】容器适配器及deque和仿函数 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;C笔记 文章目录 【C笔记】容器适配器及deque和仿函数前言一.容器适配器1.1什么是容器适配器1.2 STL标准库中stack和queue的底层结构 二.stack2.1stack类模…

centos7.X zabbix监控参数以及邮件报警和钉钉报警

1&#xff1a;zabbix安装 1.1 zabbix 环境要求 硬件配置: 2个CPU核心, 4G 内存, 50G 硬盘&#xff08;最低&#xff09; 操作系统: Linux centos7.2 x86_64 Python 2.7.x Mariadb Server ≥ 5.5.56 httpd-2.4.6-93.el7.centos.x86_64 PHP 5.4.161.2 zabbix安装版本 [rootnod…

基于向量检索的RAG大模型

一、什么是向量 向量是一种有大小和方向的数学对象。它可以表示为从一个点到另一个点的有向线段。例如&#xff0c;二维空间中的向量可以表示为 (&#x1d465;,&#x1d466;) &#xff0c;表示从原点 (0,0)到点 (&#x1d465;,&#x1d466;)的有向线段。 1.1、文本向量 1…

串口屏控制的自动滑轨(未完工)

序言 疫情期间自己制作了一个自动滑轨&#xff0c;基于无线遥控的&#xff0c;但是整体太大了&#xff0c;非常不方便携带&#xff0c;所以重新设计了一个新的&#xff0c;以2020铝型材做导轨的滑轨&#xff0c;目前2020做滑轨已经很成熟了&#xff0c;配件也都非常便宜&#x…

如何使用Get进行状态管理

文章目录 1. 概念介绍2. 思路与方法2.1 实现思路2.2 相关组件3. 示例代码4. 内容总结我们在上一章回中介绍了"使用get进行依赖管理"相关的内容,本章回中将介绍如何使用get进行状态管理一.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 在Flutter开发中状态管理…

计算机视觉常用数据集Cityscapes的介绍、下载、转为YOLO格式进行训练

我在寻找Cityscapes数据集的时候花了一番功夫&#xff0c;因为官网下载需要用公司或学校邮箱邮箱注册账号&#xff0c;等待审核通过后才能进行下载数据集。并且一开始我也并不了解Cityscapes的格式和内容是什么样的&#xff0c;现在我弄明白后写下这篇文章&#xff0c;用于记录…

033_Structure_Static_In_Matlab求解结构静力学问题两套方法

结构静力学问题 静力学问现在是已经很简单的问题&#xff0c;在材料各向同性的情况下&#xff0c;对于弹性固体材料&#xff0c;很容易通过有限元求解。特别是线弹性问题&#xff0c;方程的矩阵形式可以很容易的写出&#xff08;准确得说是很容易通过有限元表达&#xff09;&a…

rnn/lstm 项目实战

tip:本项目用到的数据和代码在https://pan.baidu.com/s/1Cw6OSSWJevSv7T1ouk4B6Q?pwdz6w2 1. RNN : 预测股价 任务&#xff1a;基于zgpa_train.csv数据,建立RNN模型,预测股价 1.完成数据预处理&#xff0c;将序列数据转化为可用于RNN输入的数据 2.对新数据zgpa_test.csv进…

jenkins 构建报错 mvn: command not found

首先安装过 maven&#xff0c;并且配置过环境变量 win r ,输入 cmd 键入 mvn -v 出现上图输出&#xff0c;则证明安装成功。 原因 jenkins 没有 maven 配置全局属性, 导致无法找到 mvn 命令。 解决方案 找到全局属性&#xff0c;点击新增&#xff0c;配置 MAVEN_HOME 路…

轮廓图【HTML+CSS+JavaScript】

给大家分享一个很好看的轮播图&#xff0c;这个也是之前看到别人写的效果感觉很好看&#xff0c;所以后面也自己实现了一下&#xff0c;在这里分享给大家&#xff0c;希望大家也可以有所收获 轮播图效果&#xff1a; 视频效果有点浑浊&#xff0c;大家凑合着看&#xff0c;大家…

ChatGPT变AI搜索引擎!以后还需要谷歌吗?

前言 在北京时间11月1日凌晨&#xff0c;正值ChatGPT两岁生日之际&#xff0c;OpenAI宣布推出最新的人工智能搜索体验&#xff01;具备实时网络功能&#xff01;与 Google 展开直接竞争。 ChatGPT搜索的推出标志着ChatGPT成功消除了即时信息这一最后的短板。 这项新功能可供 …

Netty 组件介绍 - ByteBuf

直接内存&堆内存 ByteBuf buffer ByteBufAllocator.DEFAULT.heapBuffer(10);ByteBuf byteBuf ByteBufAllocator.DEFAULT.directBuffer(10); 组成 ByteBuf维护了两个不同的索引&#xff0c;一个用于读取&#xff0c;一个用于写入。 写入 内存回收 堆内存使用的是JVM内…

都快2025年了,来看看哪个编程语言才是时下热门吧

早上好啊&#xff0c;大佬们&#xff0c;今天咱们不讲知识&#xff0c;今天我们来看看时下热门的编程语言都是哪些&#xff0c;大佬们又都是在学哪些语言呢。 最近一些朋友和我在讨论哪个编程语言是现在 最好用 最厉害 的编程语言。 有人说&#xff0c;Python简单好用&#xf…

【雷达信号数据集】雷达脉冲活动分段的多级学习算法【附下载链接】

摘要 无线电信号识别是电子战中的一项重要功能。电子战系统需要精确识别和定位雷达脉冲活动&#xff0c;以产生有效的对抗措施。尽管这些任务很重要&#xff0c;但基于深度学习的雷达脉冲活动识别方法在很大程度上仍未得到充分探索。虽然之前已经探索了用于雷达调制识别的深度…