【Python机器学习】决策树集成——梯度提升回归树

理论知识:        

        梯度提升回归树通过合并多个决策树来构建一个更为强大的模型。虽然名字里有“回归”,但这个模型既能用于回归,也能用于分类。与随机森林方法不同,梯度提升采用连续的方式构造树,每棵树都试图纠正前一棵树的错误。默认情况下,梯度提升回归树中没有随机化,而是用到了强预剪枝。梯度提升树通常使用深度很小(1-5之间),这样的模型占用内存小,预测速度也更快。

        梯度提升背后的主要思想是合并许多简单的模型(弱学习器),比如深度较小的树。每棵树只能对部分数据做出比较好的预测,因此添加的树越来越多,可以不断迭代来提高性能。

        梯度提升树通常对参数设置非常敏感,但如果参数设置正确的话,模型精度会更高。

        除了预剪枝和集成树的数量外,梯度提升的另一个重要参数是learning_rate(学习率),用于控制每棵树纠正前一棵树的错误的强度。较高的学习率意味着每棵树都可以做出较强的修正,这样的模型更为复杂。通过增大n_estimators来向集成中添加更多树,也可以增加模型的复杂度,因为模型有更多机会来纠正训练集上的错误。

        默认参数上:树的数量为100、最大深度为3,学习率为0.1

示例:

以乳腺癌数据集为例,用分类模型:

from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as npplt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)
print('训练集精度:{:.3f}'.format(gbrt.score(X_train,y_train)))
print('测试集精度:{:.3f}'.format(gbrt.score(X_test,y_test)))

 

由于训练集精度达到100%,所以很可能存在过拟合,为了降低过拟合,可以限制最大深度来加强预剪枝,也可以降低学习率:


gbrt_md1=GradientBoostingClassifier(random_state=0,max_depth=1)
gbrt_md1.fit(X_train,y_train)
print('max_depth=1训练集精度:{:.3f}'.format(gbrt_md1.score(X_train,y_train)))
print('max_depth=1测试集精度:{:.3f}'.format(gbrt_md1.score(X_test,y_test)))gbrt_lr001=GradientBoostingClassifier(random_state=0,learning_rate=0.01)
gbrt_lr001.fit(X_train,y_train)
print('learning_rate=0.01训练集精度:{:.3f}'.format(gbrt_lr001.score(X_train,y_train)))
print('learning_rate=0.01测试集精度:{:.3f}'.format(gbrt_lr001.score(X_test,y_test)))

 

可以看到,两种方法都降低了训练集精度,而减小树的最大深度显著提升了模型性能。

特征重要性可视化: 

 

import mglearn.plots
from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as npdef plot_importances(model):n_feature=cancer.data.shape[1]plt.barh(range(n_feature),model.feature_importances_,align='center')plt.yticks(np.arange(n_feature),cancer.feature_names)plt.xlabel('特征重要性')plt.ylabel('特征')plt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)plot_importances(gbrt)
plt.show()

可以看到,梯度提升树的特征重要性与随机森林有些类似,但梯度提升树完全忽略了某些特征。

常用的方法是先尝试随机森林,因为它的鲁棒性很好,如果随机森林的效果好但预测时间太长,或者学习模型精度在小数点后两位的提高也很重要,那么切换成梯度提升树通常比较有用。

优缺点:

梯度提升树是监督学习中最强大也最常用的模型之一,它的主要缺点是需要仔细调参,而且训练时间会比较长;优点是不需要对数据进行缩放就可以表现的很好,而且也适用于二元特征和连续特征同时存在的数据集。与其他基于树的模型相同,梯度提升树通常也不适用于高纬稀疏数据。

梯度提升树的主要参数是树的数量n_estimators和学习率learning_rate。这两个参数高度相关,因为learning_rate越低,就需要更多树来构建具有相似复杂度的模型,随机森林的n_estimators值总是越大越好,但梯度提升树不同,增大n_estimators会导致模型更加复杂,进而可能导致过拟合,通常的做法是根据时间和内存的预算选择合适的n_estimators,然后会不同的learning_rate进行遍历。

另一个重要参数是max_depth,用来降低每棵树的复杂度,一般不超过5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/610631.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Verilog】期末复习——设计11011序列检测器电路

系列文章 数值(整数,实数,字符串)与数据类型(wire、reg、mem、parameter) 运算符 数据流建模 行为级建模 结构化建模 组合电路的设计和时序电路的设计 有限状态机的定义和分类 期末复习——数字逻辑电路分…

nginx反向代理websocket 60秒自动断开处理

参考: https://blog.csdn.net/zqh123zqh/article/details/112795608 添加: proxy_read_timeout 600s; 例如: server {listen 80;server_name carrefourzone.senguo.cc;#error_page 502 /static/502.html;location /static/ {root /home/ch…

功能分享【电商API接口】:商品采集正确使用方法!

相信很多做过电商的人,曾经有在淘宝、京东、天猫、拼多多、1688等平台上卖过自己的产品,但是每换一个平台,商品要重新上传,这浪费了很多没有必要的时间。 为此我们开发了商品采集API功能,一键完成上架商品&#xff0c…

2024年江苏省职业院校技能大赛高职学生组软件测试任务四 单元测试题目

2024年江苏省职业院校技能大赛高职学生组软件测试任务四 单元测试 任务要求 题目1:任意输入2个正整数值分别存入x、y中,据此完成下述分析:若x≤0或y≤0,则提示:“输入不符合要求。”;若2值相同&#xff0…

CHS_01.1.5+操作系统引导

CHS_01.1.5操作系统引导 操作系统的引导一个新的磁盘安装操作系统后操作系统引导(开机过程) 操作系统的引导 我们会学习操作系统的引导 那你可能看见这个词的时候会觉得莫名其妙不明 绝地 什么是操作系统的引导呢 简单来说就是当你在开机的时候 如何让…

消除类游戏(UPC练习)

题目描述 消除类游戏是深受大众欢迎的一种游戏,游戏在一个包含有n行m列的游戏棋盘上进行,棋盘的每一行每一列的方格上放着一个有颜色的棋子,当一行或一列上有连续三个或更多的相同颜色的棋子时,这些棋子都被消除。当有多处可以被…

频率的高低与辐射强度有关系吗?

摘要: 频率的高低和辐射强度之间存在一定的关系。 一般而言,频率越高,辐射强度越大,即电磁辐射的能量越大。这是因为电磁波的能量与其频率成正比。在电磁波谱中,如X光和伽玛射线具有高频率和强辐射强度,可以破坏构成 .…

番外篇 中国古代的操 作系统

番外篇中国古代的操作系统 在古代中国,仿佛已经存在一套古老而神秘的操作系统机制。 这个东方国度中,有一位名叫小李子的忙碌人物,他的工作就如同是执行各种指令的“人肉CPU”。 这个国家还有一个特殊的人物,即皇帝,他…

1.4.1机器学习——梯度下降+α学习率大小判定

1.4.1梯度下降 4.1、梯度下降的概念 ※【总结一句话】:系统通过自动的调节参数w和b的值,得到最小的损失函数值J。 如下:是梯度下降的概念图。 我们有一个损失函数 J(w,b),包含两个参数w和b(你可以想象成J(w,b) w*x…

服务器配置SSL证书到nginx基于Fdfs存储服务器或者直接阿里云绑定SSL

1.如果用FDFS存储服务器内置nginx设置SSL证书 1.验证当前nginx是否存在 http_ssl_modulehttp_ssl_module模块 如果存在直接配置就行 server {listen 80 default backlog2048;listen 443 ssl; server_name 域名; ssl_certificate /usr/local/nginx_fdfs/ssl/xxxx.top.crt; ssl…

leetcode滑动窗口问题总结 Python

目录 一、理论 二、例题 1. 最长无重复字符串 2. 长度最小的子数组 3. 字符串的排列 4. 最小覆盖子串 5. 滑动窗口最大值 一、理论 滑动窗口是一类比较重要的解题思路,一般来说我们面对的都是非定长窗口,所以一般需要定义两个指针 left 和 right&…

Unity报错:[SteamVR] Not Initialized (109)的解决方法

问题描述 使用HTC vive 头像进行SteamVR插件的示例场景进行测试,发现头显场景无法跳转到运行场景(Unity 项目可以运行,仅出现警告)。 具体如下: [SteamVR] Not Initialized (109) [SteamVR] Initialization failed…

面试宝典之消息中间件面试题

RabbitMq: 1、RabbitMQ有啥用处? (1)服务间异步通信 (2)顺序消费 (3)定时任务 (4)请求削峰 2、RabbitMQ有哪些常用的工作模式? 工作模式(Work…

常用接口抓包以及接口测试工具总结

接口 统称为API,程序与程序之间的对接、交接。 接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点,主要是为了检验不同组件(模块)之间数据的传递是否正确,同时接口测试还要测试当前系统与第三方…

OpenHarmony自定义Launcher

前言 OpenHarmony源码版本:4.0release 开发板:DAYU / rk3568 DevEco Studio版本:4.0.0.600 自定义效果: 一、Launcher源码下载 Launcher源码地址:https://gitee.com/openharmony/applications_launcher 切换分支为OpenHarmony-4.0-Release,并下载源码 二、Launcher源…

2024.1.9 基于 Jedis 通过 Java 客户端连接 Redis 服务器

目录 引言 RESP 协议 Redis 通信过程 实现步骤 步骤一 步骤二 步骤三 步骤四 引言 在 Redis 命令行客户端中手敲命令并不是我们日常开发中的主要形式而更多的时候是使用 Redis 的 API 来实现定制化的 Redis 客户端程序,进而操作 Redis 服务器即使用程序来操…

mysql生成到当前时间的时间序列,报表按时间补0

生成本月每日的时间序列 SELECT DATE_FORMAT(date_add( CONCAT(YEAR(Date(curdate())),‘-0’,MONTH(Date(curdate())),‘-’,‘01’), INTERVAL ( cast( help_topic_id AS signed) ) DAY ) ,‘%Y-%m-%d’ ) FROM mysql.help_topic WHERE help_topic_id < DAY ( curdate( ) …

利用Type类来获得字段名称(Unity C#中的反射)

使用Type类以前需要引用反射的命名空间&#xff1a; using System.Reflection; 以下是完整代码&#xff1a; public class ReflectionDemo : MonoBehaviour {void Start(){A a new A();B b new B();A[] abArraynew A[] { a, b };foreach(A v in abArray){Type t v.GetTyp…

SaaS先驱Salesforce发展史

Salesforce是云计算和SaaS领域的先驱&#xff0c;大致经过5个不同发展阶段 第一个阶段&#xff1a;SaaS CRM发展初期 Salesforce成立时间是1999年&#xff0c;其SaaS业务的Idea的灵感起源于IaaS巨头亚马逊。初期标榜的竞品Siebel早期投入高、很难上手、功能过于复杂、实用性不强…

使用Excel批量给数据添加单引号和逗号

表格制作过程如下&#xff1a; A2表格暂时为空&#xff0c;模板建立完成以后&#xff0c;用来放置原始数据&#xff1b; 在B2表格内输入公式&#xff1a; ""&A2&""&"," 敲击回车&#xff1b; 解释&#xff1a; B2表格的公式&q…