GBDT 算法【python,机器学习,算法】

GBDT 即 Gradient Boosting Decision Tree 梯度提升树, 是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),
它通过构造一组弱的学习器(树),然后把多棵决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。具体实现步骤如下:

  1. 初始化基分类器。
  2. 以当前学习器的预测值为准,计算未正确预测的样本(即残差)。
  3. 使用残差构建下一棵决策树(主要思想:试图纠正前一个模型的错误,使其不断提升预测正确率)。
  4. 重复 2-3 步骤,直到满足终止条件为止(误差很小或者达到一定的迭代次数),结束迭代。
  5. 将迭代中的每个分类器产生的预测值相加,得到最终的预测结果。

下面是一个简单的示例,使用梯度提升算法和决策树分类器对手写数字数据进行对比分析:

# 导入sklearn内置数据集
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits# 导入手写数字数据
digits = load_digits()plt.figure(1, figsize=(3.5, 3.5), facecolor='white')
for i in range(10):for j in range(10):ax = plt.subplot(10, 10, 10 * i + j + 1)# 设置子图的位置ax.set_xticks([])# 隐藏横坐标# 隐藏纵坐标ax.set_yticks([])plt.imshow(digits.images[9 * i + j], cmap=plt.cm.gray_r,interpolation="nearest")
plt.show()# 导入sklearn中的模型验证类
from sklearn.model_selection import train_test_split# 使用train test_split函数自动分割训练数据集和测试数据集
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target,test_size=0.3)
# 导入sklearn模块中的决策树分类器类
from sklearn.tree import DecisionTreeClassifier# 定义一个决策树分类器对象
dtc = DecisionTreeClassifier()
dtc.fit(x_train, y_train)
# 导入sklearn模块中的梯度提升分类器类
from sklearn.ensemble import GradientBoostingClassifier# 定义一个梯度提升决策树分类器对象
gbc = GradientBoostingClassifier(n_estimators=30, learning_rate=0.8)
gbc.fit(x_train, y_train)
print("单棵决策树在训练集上的性能:%.3f" % dtc.score(x_train, y_train))
print("单棵决策树在测试集上的性能:%.3f" % dtc.score(x_test, y_test))
print("GBDT(T-30)在训练集上的性能:%.3f" % gbc.score(x_train, y_train))
print("GBDT(T-30)在测试集上的性能:%.3f" % gbc.score(x_test, y_test))
# 观察弱分类器数量对分类准确度的影响
# 弱分类器的最大值
T_max = 39
gbc_train_scores = []
gbc_test_scores = []
for i in range(1, T_max + 1):gbc = GradientBoostingClassifier(n_estimators=i, learning_rate=0.1)gbc.fit(x_train, y_train)gbc_train_scores.append(gbc.score(x_train, y_train))gbc_test_scores.append(gbc.score(x_test, y_test))# 绘制测试结果
import matplotlib.pyplot as plt# 解决图形中的中文显示乱码
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.matplotlib.rcParams['axes.unicode_minus'] = False
plt.figure()
# 解决图形中的坐标轴负号显示问题
plt.plot(range(1, T_max + 1), gbc_train_scores, color='r', label='训练集')
plt.plot(range(1, T_max + 1), gbc_test_scores, color='g', label='测试集')
plt.title("基学习器数量对GBDT性能的影响")
plt.xlabel("基分类器数量")
plt.ylabel("准确率")
plt.xlim(1, T_max)
plt.legend()
plt.show()

上面的代码演示了基学习器的数量对 GBDT 性能的影响。主要步骤如下:

  1. 导入训练数据。
  2. 将数据切分为两个集合:训练集和测试集。
  3. 使用不同数量的学期器对数据集进行拟合训练和预测。
  4. 绘制基学习器数量对 GBDT 性能的影响图像。

你可以根据实际需要对代码中的数据进行调整以适应不同的测试需要。
你可以根据实际需要对代码中的数据进行调整以适应不同的测试需要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/842737.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

史上最全的Linux常用命令、使用技巧汇总(超全面!简单明了!)

目录 常用Linux命令 --help ls pwd cd touch mkdir rm clear vim cat 基本命令 find 快捷键 小技巧 系统命令 reboot 常用Linux命令 --help 作 用:显示 命令的帮助信息 ls 作 用&…

STM32H750外设之ADC通道选择

目录 概述 1 通道选择功能介绍 2 通道选择( SQRx、 JSQRx) 2.1 通道复用 2.1.1 通道介绍 2.1.2 通道框图 2.2 转换分组 2.3 内部专用通道 3 通道预选寄存器 (ADCx_PCSEL) 3.1 功能介绍 3.2 预选通道寄存器 概述 本位主要介绍STM32H750外设之…

AI学习指南数学工具篇-凸优化在支持逻辑回归中的应用

AI学习指南数学工具篇-凸优化在支持逻辑回归中的应用 一、引言 在人工智能领域,逻辑回归是一种常见的分类算法,它通过学习样本数据的特征和标签之间的关系,来进行分类预测。而在逻辑回归算法中,凸优化是一种重要的数学工具&…

如何开展人工智能项目呢?

1.分析问题,确定输入和输出 比如:中英翻译,输入: 苹果 输出: apple 确定了输入和输出后,要想办法将输入和输出抽象成一些数字,因为计算机只能为你处理数字。比如说,你输入一朵花&am…

栈 队列

目录 1.1栈的基本概念 1.1.1栈的定义 1.1.2栈的基本操作 1.2栈的顺序存储结构 1.2.1构造原理 1.2.2基本算法 1.3栈的链式存储结构 1.3.1构造原理 1.3.2基本算法 2.1队列的基本概念 2.1.1队列的定义 2.1.2队列的基本运算 2.2队列的顺序存储结构 2.2.1构造原理 2.2.1基…

SAP项目中的国际化团队

参与过国际项目或者管理过国际化的团队的同仁是不是有同样的感受,管理SAP项目中的国际化团队是一项复杂的任务,需要考虑多种因素,包括跨文化沟通、时区管理、项目协调等。今天我就根据我多个国际项目的经验来给大家分享一下如何很好的管理国际…

深刻解析 volatile 关键字和线程本地存储ThreadLocal

1.volatile关键字在Java多线程编程中的重要性 在多线程编程中,volatile关键字扮演着至关重要的角色,它确保了变量在多个线程间的可见性,并且能防止指令重排序,从而达到线程安全的目的。 1.1 保证多线程环境下变量的可见性 在Ja…

CRLF注入漏洞

1.CRLF注入漏洞原理 Nginx会将 $uri进行解码,导致传入%0a%0d即可引入换行符,造成CRLF注入漏洞。 执行xss语句 2.漏洞扩展 CRLF 指的是回车符(CR,ASCII 13,\r,%0d) 和换行符(LF,ASCII 10,\n&am…

Java中的事件驱动编程:增强应用的互动性和响应性

事件驱动编程是一种编程范式,其中程序的流程由外部事件决定,如用户操作、系统消息或其他程序的输入。在Java中,事件驱动编程广泛应用于图形用户界面(GUI)开发、网络编程和组件交互。本文将探讨Java中的事件驱动编程基础…

FTP协议——LightFTP安装(Linux)

1、简介 LightFTP是一个轻量级的FTP(File Transfer Protocol,文件传输协议)客户端软件。FTP是一种用于在网络上传输文件的标准协议,允许用户通过TCP/IP网络(如互联网)在计算机之间进行文件传输。 2、步骤…

在ARM开发板上,栈大小设置为2MB(常用设置)里面存放的数据

系列文章目录 在ARM开发板上,栈大小设置为2MB(常用设置)里面存放的数据 在ARM开发板上,栈大小设置为2MB(常用设置)里面存放的数据 系列文章目录 在ARM开发板上,栈(Stack)…

Flutter 中的 LimitedBox 小部件:全面指南

Flutter 中的 LimitedBox 小部件:全面指南 Flutter 是一个功能强大的 UI 框架,它提供了大量的小部件来帮助开发者构建美观且响应式的用户界面。在 Flutter 的布局小部件中,LimitedBox 是一个不太常见但非常有用的组件,它可以用来…

Keras深度学习框架第二十五讲:使用KerasNLP预训练Transformer模型

1、KerasNPL预训练Transformer模型概念 使用KerasNLP来预训练一个Transformer模型涉及多个步骤。由于Keras本身并不直接提供NLP的预训练模型或工具集,我们通常需要结合像TensorFlow Hub、Hugging Face的Transformers库或自定义的Keras层来实现。 以下是一个简化的…

Thingsboard规则链:Message Type Filter节点详解

一、Message Type Filter节点概述 二、具体作用 三、使用教程 四、源码浅析 五、应用场景与案例 智能家居自动化 工业设备监控 智慧城市基础设施管理 六、结语 在物联网(IoT)领域,数据处理与自动化流程的实现是构建智能系统的关键。作…

创新实训2024.05.28日志:记忆化机制、基于MTPE与CoT技术的混合LLM对话机制

1. 带有记忆的会话 1.1. 查询会话历史记录 在利用大模型自身能力进行对话与解答时,最好对用户当前会话的历史记录进行还原,大模型能够更好地联系上下文进行解答。 在langchain chat chat的chat函数中,通过实现langchain框架提供的ChatMemo…

【设计模式】创建型-工厂方法模式

前言 工厂方法模式是一种经典的创建型设计模式,它提供了一种灵活的方式来创建对象实例。通过本文,我们将深入探讨工厂方法模式的概念、结构和应用。 一、什么是工厂方法模式 工厂方法模式是一种创建型设计模式,旨在解决对象的创建过程和客…

Spring MVC的请求流程

Spring MVC(Model-View-Controller)是一种基于Java的实现了MVC设计模式的轻量级Web框架。它通过一套注解,可以快速地搭建一个可扩展、易维护的Web应用程序。下面是Spring MVC处理请求的基本流程: 用户发起请求:用户通过…

Parquet使用指南:一个超越CSV、提升数据处理效率的存储格式

前言 在大数据时代,数据存储和处理的效率越来越重要。同时,我们在工作中处理的数据也越来越多,从excel格式到csv格式,从文件文档传输到直接从数据库提取,数据单位也从K到M再到G。 当数据量达到了G以上,几…

ROS | 自动导航

保存&加载地图: image:地图文件 resolution:地图分辨率(珊格地图) origin:地图左下标 第三个参数是偏转角度 加载创建好的yaml文件: 年轻人第一次导航: 全局规划器: 代价地图设置参数&#…

K-means聚类模型入门介绍

K-means聚类是一种无监督学习方法,广泛应用于数据挖掘、机器学习和模式识别等领域,用于将数据集划分为K个簇(cluster),其中每个簇的数据具有相似的特征。其基本思想是通过迭代寻找使簇内点间距离平方和最小的簇划分方式…