【Python机器学习】决策树集成——梯度提升回归树

理论知识:        

        梯度提升回归树通过合并多个决策树来构建一个更为强大的模型。虽然名字里有“回归”,但这个模型既能用于回归,也能用于分类。与随机森林方法不同,梯度提升采用连续的方式构造树,每棵树都试图纠正前一棵树的错误。默认情况下,梯度提升回归树中没有随机化,而是用到了强预剪枝。梯度提升树通常使用深度很小(1-5之间),这样的模型占用内存小,预测速度也更快。

        梯度提升背后的主要思想是合并许多简单的模型(弱学习器),比如深度较小的树。每棵树只能对部分数据做出比较好的预测,因此添加的树越来越多,可以不断迭代来提高性能。

        梯度提升树通常对参数设置非常敏感,但如果参数设置正确的话,模型精度会更高。

        除了预剪枝和集成树的数量外,梯度提升的另一个重要参数是learning_rate(学习率),用于控制每棵树纠正前一棵树的错误的强度。较高的学习率意味着每棵树都可以做出较强的修正,这样的模型更为复杂。通过增大n_estimators来向集成中添加更多树,也可以增加模型的复杂度,因为模型有更多机会来纠正训练集上的错误。

        默认参数上:树的数量为100、最大深度为3,学习率为0.1

示例:

以乳腺癌数据集为例,用分类模型:

from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as npplt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)
print('训练集精度:{:.3f}'.format(gbrt.score(X_train,y_train)))
print('测试集精度:{:.3f}'.format(gbrt.score(X_test,y_test)))

 

由于训练集精度达到100%,所以很可能存在过拟合,为了降低过拟合,可以限制最大深度来加强预剪枝,也可以降低学习率:


gbrt_md1=GradientBoostingClassifier(random_state=0,max_depth=1)
gbrt_md1.fit(X_train,y_train)
print('max_depth=1训练集精度:{:.3f}'.format(gbrt_md1.score(X_train,y_train)))
print('max_depth=1测试集精度:{:.3f}'.format(gbrt_md1.score(X_test,y_test)))gbrt_lr001=GradientBoostingClassifier(random_state=0,learning_rate=0.01)
gbrt_lr001.fit(X_train,y_train)
print('learning_rate=0.01训练集精度:{:.3f}'.format(gbrt_lr001.score(X_train,y_train)))
print('learning_rate=0.01测试集精度:{:.3f}'.format(gbrt_lr001.score(X_test,y_test)))

 

可以看到,两种方法都降低了训练集精度,而减小树的最大深度显著提升了模型性能。

特征重要性可视化: 

 

import mglearn.plots
from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as npdef plot_importances(model):n_feature=cancer.data.shape[1]plt.barh(range(n_feature),model.feature_importances_,align='center')plt.yticks(np.arange(n_feature),cancer.feature_names)plt.xlabel('特征重要性')plt.ylabel('特征')plt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)plot_importances(gbrt)
plt.show()

可以看到,梯度提升树的特征重要性与随机森林有些类似,但梯度提升树完全忽略了某些特征。

常用的方法是先尝试随机森林,因为它的鲁棒性很好,如果随机森林的效果好但预测时间太长,或者学习模型精度在小数点后两位的提高也很重要,那么切换成梯度提升树通常比较有用。

优缺点:

梯度提升树是监督学习中最强大也最常用的模型之一,它的主要缺点是需要仔细调参,而且训练时间会比较长;优点是不需要对数据进行缩放就可以表现的很好,而且也适用于二元特征和连续特征同时存在的数据集。与其他基于树的模型相同,梯度提升树通常也不适用于高纬稀疏数据。

梯度提升树的主要参数是树的数量n_estimators和学习率learning_rate。这两个参数高度相关,因为learning_rate越低,就需要更多树来构建具有相似复杂度的模型,随机森林的n_estimators值总是越大越好,但梯度提升树不同,增大n_estimators会导致模型更加复杂,进而可能导致过拟合,通常的做法是根据时间和内存的预算选择合适的n_estimators,然后会不同的learning_rate进行遍历。

另一个重要参数是max_depth,用来降低每棵树的复杂度,一般不超过5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/610631.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

功能分享【电商API接口】:商品采集正确使用方法!

相信很多做过电商的人,曾经有在淘宝、京东、天猫、拼多多、1688等平台上卖过自己的产品,但是每换一个平台,商品要重新上传,这浪费了很多没有必要的时间。 为此我们开发了商品采集API功能,一键完成上架商品&#xff0c…

CHS_01.1.5+操作系统引导

CHS_01.1.5操作系统引导 操作系统的引导一个新的磁盘安装操作系统后操作系统引导(开机过程) 操作系统的引导 我们会学习操作系统的引导 那你可能看见这个词的时候会觉得莫名其妙不明 绝地 什么是操作系统的引导呢 简单来说就是当你在开机的时候 如何让…

频率的高低与辐射强度有关系吗?

摘要: 频率的高低和辐射强度之间存在一定的关系。 一般而言,频率越高,辐射强度越大,即电磁辐射的能量越大。这是因为电磁波的能量与其频率成正比。在电磁波谱中,如X光和伽玛射线具有高频率和强辐射强度,可以破坏构成 .…

番外篇 中国古代的操 作系统

番外篇中国古代的操作系统 在古代中国,仿佛已经存在一套古老而神秘的操作系统机制。 这个东方国度中,有一位名叫小李子的忙碌人物,他的工作就如同是执行各种指令的“人肉CPU”。 这个国家还有一个特殊的人物,即皇帝,他…

1.4.1机器学习——梯度下降+α学习率大小判定

1.4.1梯度下降 4.1、梯度下降的概念 ※【总结一句话】:系统通过自动的调节参数w和b的值,得到最小的损失函数值J。 如下:是梯度下降的概念图。 我们有一个损失函数 J(w,b),包含两个参数w和b(你可以想象成J(w,b) w*x…

服务器配置SSL证书到nginx基于Fdfs存储服务器或者直接阿里云绑定SSL

1.如果用FDFS存储服务器内置nginx设置SSL证书 1.验证当前nginx是否存在 http_ssl_modulehttp_ssl_module模块 如果存在直接配置就行 server {listen 80 default backlog2048;listen 443 ssl; server_name 域名; ssl_certificate /usr/local/nginx_fdfs/ssl/xxxx.top.crt; ssl…

Unity报错:[SteamVR] Not Initialized (109)的解决方法

问题描述 使用HTC vive 头像进行SteamVR插件的示例场景进行测试,发现头显场景无法跳转到运行场景(Unity 项目可以运行,仅出现警告)。 具体如下: [SteamVR] Not Initialized (109) [SteamVR] Initialization failed…

OpenHarmony自定义Launcher

前言 OpenHarmony源码版本:4.0release 开发板:DAYU / rk3568 DevEco Studio版本:4.0.0.600 自定义效果: 一、Launcher源码下载 Launcher源码地址:https://gitee.com/openharmony/applications_launcher 切换分支为OpenHarmony-4.0-Release,并下载源码 二、Launcher源…

2024.1.9 基于 Jedis 通过 Java 客户端连接 Redis 服务器

目录 引言 RESP 协议 Redis 通信过程 实现步骤 步骤一 步骤二 步骤三 步骤四 引言 在 Redis 命令行客户端中手敲命令并不是我们日常开发中的主要形式而更多的时候是使用 Redis 的 API 来实现定制化的 Redis 客户端程序,进而操作 Redis 服务器即使用程序来操…

mysql生成到当前时间的时间序列,报表按时间补0

生成本月每日的时间序列 SELECT DATE_FORMAT(date_add( CONCAT(YEAR(Date(curdate())),‘-0’,MONTH(Date(curdate())),‘-’,‘01’), INTERVAL ( cast( help_topic_id AS signed) ) DAY ) ,‘%Y-%m-%d’ ) FROM mysql.help_topic WHERE help_topic_id < DAY ( curdate( ) …

利用Type类来获得字段名称(Unity C#中的反射)

使用Type类以前需要引用反射的命名空间&#xff1a; using System.Reflection; 以下是完整代码&#xff1a; public class ReflectionDemo : MonoBehaviour {void Start(){A a new A();B b new B();A[] abArraynew A[] { a, b };foreach(A v in abArray){Type t v.GetTyp…

SaaS先驱Salesforce发展史

Salesforce是云计算和SaaS领域的先驱&#xff0c;大致经过5个不同发展阶段 第一个阶段&#xff1a;SaaS CRM发展初期 Salesforce成立时间是1999年&#xff0c;其SaaS业务的Idea的灵感起源于IaaS巨头亚马逊。初期标榜的竞品Siebel早期投入高、很难上手、功能过于复杂、实用性不强…

使用Excel批量给数据添加单引号和逗号

表格制作过程如下&#xff1a; A2表格暂时为空&#xff0c;模板建立完成以后&#xff0c;用来放置原始数据&#xff1b; 在B2表格内输入公式&#xff1a; ""&A2&""&"," 敲击回车&#xff1b; 解释&#xff1a; B2表格的公式&q…

WebRTC实现1对1音视频通信原理

什么是 WebRTC &#xff1f; WebRTC&#xff08;Web Real-Time Communication&#xff09;是 Google于2010以6829万美元从 Global IP Solutions 公司购买&#xff0c;并于2024年01月10日将其开源&#xff0c;旨在建立一个互联网浏览器间的实时通信的平台&#xff0c;让 WebRTC…

计算机毕业设计---ssm实验室设备管理系统

项目介绍 ssm实验室设备管理系统。前台jsplayuieasyui等框架渲染数据、后台java语言搭配ssm(spring、springmvc、mybatis、maven) 数据库mysql8.0。该系统主要分三种角色&#xff1a;管理员、教师、学生。主要功能学校实验设备的借、还、修以及实验课程的发布等等&#xff1b;…

再不收藏就晚了,Axure RP Pro 各版本大集合

Axure RP Pro下载链接 https://pan.baidu.com/s/1hRJRY6t0ZONKhdwvykAc3g?pwd0531 1.鼠标右击【Axure RP Pro9.0】压缩包&#xff08;win11及以上系统需先点击“显示更多选项”&#xff09;选择【解压到 Axure RP Pro9.0】。 2.打开解压后的文件夹&#xff0c;鼠标右击【Axu…

2024啦,致敬最可爱的技术人!!

大家可以关注我的公众号和视频号“架构随笔录”。 ​作为一个开源爱好者&#xff0c;我花费了大概1整天的时间去整理了国内外主流的互联网公司在Java后端领域的开源输出成果&#xff0c;顿时感悟太多&#xff0c;总是觉得这些贡献开源的技术人及对应技术公司确实太不容易了&am…

golang并发安全-select

前面说了golang的channel&#xff0c; 今天我们看看golang select 是怎么实现的。 数据结构 type scase struct {c *hchan // chanelem unsafe.Pointer // 数据 } select 非默认的case 中都是处理channel 的 接受和发送&#xff0c;所有scase 结构体中c是用来存储…

前端monorepo大仓权限设计的思考与实现

一、背景 前端 monorepo 在试行大仓研发流程过程中&#xff0c;已经包含了多个业务域的应用、共享组件库、工具函数等多种静态资源&#xff0c;在实现包括代码共享、依赖管理的便捷性以及更好的团队协作的时候&#xff0c;也面临大仓代码文件权限的问题。如何让不同业务域的研…

12. SSM整合

1.新建一个maven项目,添加web支持 创建项目 设定项目名 右键添加框架支持: 添加web应用支持: 完成后目录结构: 2.添加jar包依赖 <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0…