机器学习:回归决策树(Python)

一、平方误差的计算

square_error_utils.py

import numpy as npclass SquareErrorUtils:"""平方误差最小化准则,选择其中最优的一个作为切分点对特征属性进行分箱处理"""@staticmethoddef _set_sample_weight(sample_weight, n_samples):"""扩展到集成学习,此处为样本权重的设置:param sample_weight: 各样本的权重:param n_samples: 样本量:return:"""if sample_weight is None:sample_weight = np.asarray([1.0] * n_samples)return sample_weight@staticmethoddef square_error(y, sample_weight):"""平方误差:param y: 当前划分区域的目标值集合:param sample_weight: 当前样本的权重:return:"""y = np.asarray(y)return np.sum((y - y.mean()) ** 2 * sample_weight)def cond_square_error(self, x, y, sample_weight):"""计算根据某个特征x划分的区域中y的误差值:param x: 某个特征划分区域所包含的样本:param y: x对应的目标值:param sample_weight: 当前x的权重:return:"""x, y = np.asarray(x), np.asarray(y)error = 0.0for x_val in set(x):x_idx = np.where(x == x_val)  # 按区域计算误差new_y = y[x_idx]  # 对应区域的目标值new_sample_weight = sample_weight[x_idx]error += self.square_error(new_y, new_sample_weight)return errordef square_error_gain(self, x, y, sample_weight=None):"""平方误差带来的增益值:param x: 某个特征变量:param y: 对应的目标值:param sample_weight: 样本权重:return:"""sample_weight = self._set_sample_weight(sample_weight, len(x))return self.square_error(y, sample_weight) - self.cond_square_error(x, y, sample_weight)

 二、树的结点信息封装


class TreeNode_R:"""决策树回归算法,树的结点信息封装,实体类:setXXX()、getXXX()"""def __init__(self, feature_idx: int = None, feature_val=None, y_hat=None, square_error: float = None,criterion_val=None, n_samples: int = None, left_child_Node=None, right_child_Node=None):"""决策树结点信息封装:param feature_idx: 特征索引,如果指定特征属性的名称,可以按照索引取值:param feature_val: 特征取值:param square_error: 划分结点的标准:当前结点的平方误差:param n_samples: 当前结点所包含的样本量:param y_hat: 当前结点的预测值:Ci:param left_child_Node: 左子树:param right_child_Node: 右子树"""self.feature_idx = feature_idxself.feature_val = feature_valself.criterion_val = criterion_valself.square_error = square_errorself.n_samples = n_samplesself.y_hat = y_hatself.left_child_Node = left_child_Node  # 递归self.right_child_Node = right_child_Node  # 递归def level_order(self):"""按层次遍历树...:return:"""pass# def get_feature_idx(self):#     return self.get_feature_idx()## def set_feature_idx(self, feature_idx):#     self.feature_idx = feature_idx

三、回归决策树CART算法实现

import numpy as np
from utils.square_error_utils import SquareErrorUtils
from utils.tree_node_R import TreeNode_R
from utils.data_bin_wrapper import DataBinsWrapperclass DecisionTreeRegression:"""回归决策树CART算法实现:按照二叉树构造1. 划分标准:平方误差最小化2. 创建决策树fit(),递归算法实现,注意出口条件3. 预测predict_proba()、predict() --> 对树的搜索4. 数据的预处理操作,尤其是连续数据的离散化,分箱5. 剪枝处理"""def __init__(self, criterion="mse", max_depth=None, min_sample_split=2, min_sample_leaf=1,min_target_std=1e-3, min_impurity_decrease=0, max_bins=10):self.utils = SquareErrorUtils()  # 结点划分类self.criterion = criterion  # 结点的划分标准if criterion.lower() == "mse":self.criterion_func = self.utils.square_error_gain  # 平方误差增益else:raise ValueError("参数criterion仅限mse...")self.min_target_std = min_target_std  # 最小的样本目标值方差,小于阈值不划分self.max_depth = max_depth  # 树的最大深度,不传参,则一直划分下去self.min_sample_split = min_sample_split  # 最小的划分结点的样本量,小于则不划分self.min_sample_leaf = min_sample_leaf  # 叶子结点所包含的最小样本量,剩余的样本小于这个值,标记叶子结点self.min_impurity_decrease = min_impurity_decrease  # 最小结点不纯度减少值,小于这个值,不足以划分self.max_bins = max_bins  # 连续数据的分箱数,越大,则划分越细self.root_node: TreeNode_R() = None  # 回归决策树的根节点self.dbw = DataBinsWrapper(max_bins=max_bins)  # 连续数据离散化对象self.dbw_XrangeMap = {}  # 存储训练样本连续特征分箱的端点def fit(self, x_train, y_train, sample_weight=None):"""回归决策树的创建,递归操作前的必要信息处理(分箱):param x_train: 训练样本:ndarray,n * k:param y_train: 目标集:ndarray,(n, ):param sample_weight: 各样本的权重,(n, ):return:"""x_train, y_train = np.asarray(x_train), np.asarray(y_train)self.class_values = np.unique(y_train)  # 样本的类别取值n_samples, n_features = x_train.shape  # 训练样本的样本量和特征属性数目if sample_weight is None:sample_weight = np.asarray([1.0] * n_samples)self.root_node = TreeNode_R()  # 创建一个空树self.dbw.fit(x_train)x_train = self.dbw.transform(x_train)self._build_tree(1, self.root_node, x_train, y_train, sample_weight)def _build_tree(self, cur_depth, cur_node: TreeNode_R, x_train, y_train, sample_weight):"""递归创建回归决策树算法,核心算法。按先序(中序、后序)创建的:param cur_depth: 递归划分后的树的深度:param cur_node: 递归划分后的当前根结点:param x_train: 递归划分后的训练样本:param y_train: 递归划分后的目标集合:param sample_weight: 递归划分后的各样本权重:return:"""n_samples, n_features = x_train.shape  # 当前样本子集中的样本量和特征属性数目# 计算当前数结点的预测值,即加权平均值,cur_node.y_hat = np.dot(sample_weight / np.sum(sample_weight), y_train)cur_node.n_samples = n_samples# 递归出口判断cur_node.square_error = ((y_train - y_train.mean()) ** 2).sum()# 所有的样本目标值较为集中,样本方差非常小,不足以划分if cur_node.square_error <= self.min_target_std:# 如果为0,则表示当前样本集合为空,递归出口3returnif n_samples < self.min_sample_split:  # 当前结点所包含的样本量不足以划分returnif self.max_depth is not None and cur_depth > self.max_depth:  # 树的深度达到最大深度return# 划分标准,选择最佳的划分特征及其取值best_idx, best_val, best_criterion_val = None, None, 0.0for k in range(n_features):  # 对当前样本集合中每个特征计算划分标准for f_val in sorted(np.unique(x_train[:, k])):  # 当前特征的不同取值region_x = (x_train[:, k] <= f_val).astype(int)  # 是当前取值f_val就是1,否则就是0criterion_val = self.criterion_func(region_x, y_train, sample_weight)if criterion_val > best_criterion_val:best_criterion_val = criterion_val  # 最佳的划分标准值best_idx, best_val = k, f_val  # 当前最佳特征索引以及取值# 递归出口的判断if best_idx is None:  # 当前属性为空,或者所有样本在所有属性上取值相同,无法划分returnif best_criterion_val <= self.min_impurity_decrease:  # 小于最小不纯度阈值,不划分returncur_node.criterion_val = best_criterion_valcur_node.feature_idx = best_idxcur_node.feature_val = best_val# print("当前划分的特征索引:", best_idx, "取值:", best_val, "最佳标准值:", best_criterion_val)# print("当前结点的类别分布:", target_dist)# 创建左子树,并递归创建以当前结点为子树根节点的左子树left_idx = np.where(x_train[:, best_idx] <= best_val)  # 左子树所包含的样本子集索引if len(left_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点left_child_node = TreeNode_R()  # 创建左子树空结点# 以当前结点为子树根结点,递归创建cur_node.left_child_Node = left_child_nodeself._build_tree(cur_depth + 1, left_child_node, x_train[left_idx],y_train[left_idx], sample_weight[left_idx])right_idx = np.where(x_train[:, best_idx] > best_val)  # 右子树所包含的样本子集索引if len(right_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点right_child_node = TreeNode_R()  # 创建右子树空结点# 以当前结点为子树根结点,递归创建cur_node.right_child_Node = right_child_nodeself._build_tree(cur_depth + 1, right_child_node, x_train[right_idx],y_train[right_idx], sample_weight[right_idx])def _search_tree_predict(self, cur_node: TreeNode_R, x_test):"""根据测试样本从根结点到叶子结点搜索路径,判定所属区域(叶子结点)搜索:按照后续遍历:param x_test: 单个测试样本:return:"""if cur_node.left_child_Node and x_test[cur_node.feature_idx] <= cur_node.feature_val:return self._search_tree_predict(cur_node.left_child_Node, x_test)elif cur_node.right_child_Node and x_test[cur_node.feature_idx] > cur_node.feature_val:return self._search_tree_predict(cur_node.right_child_Node, x_test)else:# 叶子结点,类别,包含有类别分布return cur_node.y_hatdef predict(self, x_test):"""预测测试样本x_test的预测值:param x_test: 测试样本ndarray、numpy数值运算:return:"""x_test = np.asarray(x_test)  # 避免传递DataFrame、list...if self.dbw.XrangeMap is None:raise ValueError("请先进行回归决策树的创建,然后预测...")x_test = self.dbw.transform(x_test)y_test_pred = []  # 用于存储测试样本的预测值for i in range(x_test.shape[0]):y_test_pred.append(self._search_tree_predict(self.root_node, x_test[i]))return np.asarray(y_test_pred)@staticmethoddef cal_mse_r2(y_test, y_pred):"""模型预测的均方误差MSE和判决系数R2:param y_test: 测试样本的真值:param y_pred: 测试样本的预测值:return:"""y_test, y_pred = y_test.reshape(-1), y_pred.reshape(-1)mse = ((y_pred - y_test) ** 2).mean()  # 均方误差r2 = 1 - ((y_pred - y_test) ** 2).sum() / ((y_test - y_test.mean()) ** 2).sum()return mse, r2def _prune_node(self, cur_node: TreeNode_R, alpha):"""递归剪枝,针对决策树中的内部结点,自底向上,逐个考察方法:后序遍历:param cur_node: 当前递归的决策树的内部结点:param alpha: 剪枝阈值:return:"""# 若左子树存在,递归左子树进行剪枝if cur_node.left_child_Node:self._prune_node(cur_node.left_child_Node, alpha)# 若右子树存在,递归右子树进行剪枝if cur_node.right_child_Node:self._prune_node(cur_node.right_child_Node, alpha)# 针对决策树的内部结点剪枝,非叶结点if cur_node.left_child_Node is not None or cur_node.right_child_Node is not None:for child_node in [cur_node.left_child_Node, cur_node.right_child_Node]:if child_node is None:# 可能存在左右子树之一为空的情况,当左右子树划分的样本子集数小于min_samples_leafcontinueif child_node.left_child_Node is not None or child_node.right_child_Node is not None:return# 计算剪枝前的损失值(平方误差),2表示当前结点包含两个叶子结点pre_prune_value = 2 * alphaif cur_node and cur_node.left_child_Node is not None:pre_prune_value += (0.0 if cur_node.left_child_Node.square_error is Noneelse cur_node.left_child_Node.square_error)if cur_node and cur_node.right_child_Node is not None:pre_prune_value += (0.0 if cur_node.right_child_Node.square_error is Noneelse cur_node.right_child_Node.square_error)# 计算剪枝后的损失值,当前结点即是叶子结点after_prune_value = alpha + cur_node.square_errorif after_prune_value <= pre_prune_value:  # 进行剪枝操作cur_node.left_child_Node = Nonecur_node.right_child_Node = Nonecur_node.feature_idx, cur_node.feature_val = None, Nonecur_node.square_error = Nonedef prune(self, alpha=0.01):"""决策树后剪枝算法(李航)C(T) + alpha * |T|:param alpha: 剪枝阈值,权衡模型对训练数据的拟合程度与模型的复杂度:return:"""self._prune_node(self.root_node, alpha)return self.root_node

 四、回归决策树算法的测试

test_decision_tree_R.py

import numpy as np
import matplotlib.pyplot as plt
from decision_tree_R import DecisionTreeRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressorobj_fun = lambda x: np.sin(x)
np.random.seed(0)
n = 100
x = np.linspace(0, 10, n)
target = obj_fun(x) + 0.3 * np.random.randn(n)
data = x[:, np.newaxis]  # 二维数组tree = DecisionTreeRegression(max_bins=50, max_depth=10)
tree.fit(data, target)
x_test = np.linspace(0, 10, 200)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)plt.figure(figsize=(14, 5))
plt.subplot(121)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(UnPrune) and MSE = %.5f R2 = %.5f" % (mse, r2))plt.subplot(122)
tree.prune(0.5)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(Prune) and MSE = %.5f R2 = %.5f" % (mse, r2))plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/677030.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Blender教程(基础)--试图的显示模式-22

一、透视模式&#xff08;AltZ&#xff09; 透视模式下可以实现选中透视的物体信息 发现选中了透视区的所有顶点 二、试图着色模式-显示网格边框 三、试图着色模式-显示实体 三、试图着色模式-材质预览 四、试图着色模式-显示渲染预览

Ps:直接从图层生成文件(图像资源)

通过Ps菜单&#xff1a;文件/导出/将图层导出到文件 Layers to Files命令&#xff0c;我们可以快速地将当前文档中的每个图层导出为同一类型、相同大小和选项的独立文件。 Photoshop 还提供了一个功能&#xff0c;可以基于文档中的图层或图层组的名称&#xff0c;自动生成指定大…

CleanMyMacX4.14.6如何清理mac垃圾内存

一直以来&#xff0c;苹果电脑的运行流畅度都很好&#xff0c;但是垃圾内存多了磁盘空间慢慢变少&#xff0c;还是会造成卡顿的。这篇文章就告诉大家电脑如何清理垃圾内存&#xff0c;电脑如何清理磁盘空间。 一、电脑如何清理垃圾内存 垃圾内存指的是各种缓存文件和系统垃圾…

Java图形化界面编程——事件处理 笔记

2.6 事件处理 前面介绍了如何放置各种组件&#xff0c;从而得到了丰富多彩的图形界面&#xff0c;但这些界面还不能响应用户的任何操作。比如单击前面所有窗口右上角的“X”按钮&#xff0c;但窗口依然不会关闭。因为在 AWT 编程中 &#xff0c;所有用户的操作&#xff0c;都必…

JMeter使用教程

作为一名开发工程师&#xff0c;当我们接到需求的时候&#xff0c;一般就是分析需要&#xff0c;确定思路&#xff0c;编码&#xff0c;自测&#xff0c;然后就可以让测试人员去测试了。在自测这一步&#xff0c;作为开发人员&#xff0c;很多时候就是测一下业务流程是否正确&a…

Python 小白的 Leetcode Daily Challenge 刷题计划 - 20240209(除夕)

368. Largest Divisible Subset 难度&#xff1a;Medium 动态规划 方案还原 Yesterdays Daily Challenge can be reduced to the problem of shortest path in an unweighted graph while todays daily challenge can be reduced to the problem of longest path in an unwe…

用Python来实现2024年春晚刘谦魔术

简介 这是新春的第一篇&#xff0c;今天早上睡到了自然醒&#xff0c;打开手机刷视频就被刘谦的魔术所吸引&#xff0c;忍不住用编程去模拟一下这个过程。 首先&#xff0c;声明的一点&#xff0c;大年初一不学习&#xff0c;所以这其中涉及的数学原理约瑟夫环大家可以找找其…

【新书推荐】7.3 for语句

本节必须掌握的知识点&#xff1a; 示例二十四 代码分析 汇编解析 for循环嵌套语句 示例二十五 7.3.1 示例二十四 ■for语句语法形式&#xff1a; for(表达式1;表达式2;表达式3) { 语句块; } ●语法解析&#xff1a; 第一步&#xff1a;执行表达式1&#xff0c;表达式1…

LabVIEW工业监控系统

LabVIEW工业监控系统 介绍了一个基于LabVIEW软件开发的工业监控系统。系统通过虚拟测控技术和先进的数据处理能力&#xff0c;实现对工业过程的高效监控&#xff0c;提升系统的自动化和智能化水平&#xff0c;从而满足现代工业对高效率、高稳定性和低成本的需求。 随着工业自…

BestEdrOfTheMarket:一个针对AVEDR绕过的训练学习环境

关于BestEdrOfTheMarket BestEdrOfTheMarket是一个针对AV/EDR绕过的训练学习环境&#xff0c;广大研究人员和信息安全爱好者可以使用该项目研究和学习跟AV和EDR绕过相关的技术知识。 支持绕过的防御技术 1、多层API钩子&#xff1b; 2、SSH钩子&#xff1b; 3、IAT钩子&#x…

springboot176基于Spring Boot的装饰工程管理系统

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的 适用于计算机类毕业设计&#xff0c;课程设计参考与学习用途。仅供学习参考&#xff0c; 不得用于商业或者非法用途&#xff0c;否则&#xff0c;一切后果请用户自负。 看运行截图看 第五章 第四章 获取资料方式 **项…

【Make编译控制 01】程序编译与执行

目录 一、编译原理概述 二、编译过程分析 三、编译动静态库 四、执行过程分析 一、编译原理概述 make&#xff1a; 一个GCC工具程序&#xff0c;它会读 makefile 脚本来确定程序中的哪个部分需要编译和连接&#xff0c;然后发布必要的命令。它读出的脚本&#xff08;叫做 …

react中hook封装一个table组件 与 useColumns组件

目录 1&#xff1a;react中hook封装一个table组件依赖CommonTable / index.tsx使用组件效果 2&#xff1a;useColumns组件useColumns.tsx使用 1&#xff1a;react中hook封装一个table组件 依赖 cnpm i react-resizable --save cnpm i ahooks cnpm i --save-dev types/react-r…

开源微服务平台框架的特点是什么?

借助什么平台的力量&#xff0c;可以让企业实现高效率的流程化办公&#xff1f;低代码技术平台是近些年来较为流行的平台产品&#xff0c;可以帮助很多行业进入流程化办公新时代&#xff0c;做好数据管理工作&#xff0c;从而提升企业市场竞争力。流辰信息专业研发低代码技术平…

软件文档测试

1 文档测试的范围 软件产品由可运行的程序、数据和文档组成。文档是软件的一个重要组成部分。 在软件的整人生命周期中&#xff0c;会用到许多文档&#xff0c;在各个阶段中以文档作为前阶段工作成果的体现和后阶段工作的依据。 软件文档的分类结构图如下图所示&#xff1a; …

图灵之旅--二叉树堆排序

目录 树型结构概念树的表示形式 二叉树概念特殊的二叉树二叉树性质二叉树的存储二叉树的遍历前中后序遍历 优先级队列(堆)概念 优先级队列的模拟实现堆的性质概念堆的存储方式堆的创建 堆常用接口介绍PriorityQueue的特性PriorityQueue常用接口介绍优先级队列的构造插入/删除/获…

力扣刷题之旅:进阶篇(六)—— 图论与最短路径问题

力扣&#xff08;LeetCode&#xff09;是一个在线编程平台&#xff0c;主要用于帮助程序员提升算法和数据结构方面的能力。以下是一些力扣上的入门题目&#xff0c;以及它们的解题代码。 --点击进入刷题地址 引言 在算法的广阔天地中&#xff0c;图论是一个非常重要的领域。…

2万字曝光:华尔街疯狂抢购比特币背后

作者/来源&#xff1a;Mark Goodwin and whitney Webb BitcoinMagazine 编译&#xff1a;秦晋 全文&#xff1a;19000余字 在最近比特币ETF获得批准之后&#xff0c;贝莱德的拉里-芬克透露&#xff0c;很快所有东西都将被「ETF化」与代币化&#xff0c;不仅威胁到现有的资产和商…

【linux系统体验】-archlinux折腾日记

archlinux 一、系统安装二、系统配置及美化2.1 中文输入法2.2 安装virtualbox增强工具2.3 终端美化2.4 桌面面板美化 三、问题总结3.1 一、系统安装 安装步骤人们已经总结了很多很全: Arch Linux图文安装教程 大体步骤&#xff1a; 磁盘分区安装 Linux内核配置系统&#xff…

Nginx 配置 SSL证书

成功配置SSL证书后&#xff0c;您将能够通过HTTPS加密通道安全访问Nginx服务器。 一、准备材料 SSL证书绑定的域名已完成DNS解析&#xff0c;即您的域名与主机IP地址相互映射。您可以通过DNS验证证书工具&#xff0c;检测域名DNS解析是否生效。具体操作&#xff1a; 【1】登录…