深入解析GBDT二分类算法(附代码实现)

目录:

  1. GBDT分类算法简介

  2. GBDT二分类算法

  • 2.1 逻辑回归的对数损失函数

  • 2.2 GBDT二分类原理

  • GBDT二分类算法实例

  • 手撕GBDT二分类算法

  • 4.1 用Python3实现GBDT二分类算法

  • 4.2 用sklearn实现GBDT二分类算法

  • GBDT分类任务常见的损失函数

  • 总结

  • Reference

本文的主要内容概览:

1 GBDT分类算法简介

GBDT无论用于分类还是回归,一直使用的是CART回归树。GBDT不会因为我们所选择的任务是分类任务就选用分类树,这里的核心原因是GBDT每轮的训练是在上一轮训练模型的负梯度值基础之上训练的。这就要求每轮迭代的时候,真实标签减去弱分类器的输出结果是有意义的,即残差是有意义的。如果选用的弱分类器是分类树,类别相减是没有意义的。对于这样的问题,可以采用两种方法来解决:

  • 采用指数损失函数,这样GBDT就退化成了Adaboost,能够解决分类的问题;

  • 使用类似于逻辑回归的对数似然损失函数,如此可以通过结果的概率值与真实概率值的差距当做残差来拟合;

下面我们就通过二分类问题,去看看GBDT究竟是如何做分类的。

2 GBDT二分类算法

2.1 逻辑回归的对数损失函数

逻辑回归的预测函数为:

 

函数 的值有特殊的含义,它表示结果取  的概率,因此对于输入  分类结果为类别  和类别  的概率分别为:

 

 

下面我们根据上式,推导出逻辑回归的对数损失函数  。上式综合起来可以写成:

然后取似然函数为:

因为  和  在同一  处取得极值,因此我们接着取对数似然函数为:

最大似然估计就是求使 取最大值时的  。这里对  取相反数,可以使用梯度下降法求解,求得的  就是要求的最佳参数:

2.2 GBDT二分类原理

逻辑回归单个样本  的损失函数可以表达为:

其中,  是逻辑回归预测的结果。假设第  步迭代之后当前学习器为  ,将  替换为 带入上式之后,可将损失函数写为:

其中,第  棵树对应的响应值为(损失函数的负梯度,即伪残差):

对于生成的决策树,计算各个叶子节点的最佳残差拟合值为:

由于上式没有闭式解(closed form solution),我们一般使用近似值代替:

补充近似值代替过程:
假设仅有一个样本:

令  ,则 

求一阶导:

求二阶导:

对于   的泰勒二阶展开式:

  取极值时,上述二阶表达式中的为:

GBDT二分类算法完整的过程如下:

(1)初始化第一个弱学习器  :

其中,  是训练样本中  的比例,利用先验信息来初始化学习器。

(2)对于建立  棵分类回归树  :

a)对  ,计算第  棵树对应的响应值(损失函数的负梯度,即伪残差):

b)对于 ,利用CART回归树拟合数据  ,得到第  棵回归树,其对应的叶子节点区域为  ,其中  ,且  为第棵回归树叶子节点的个数。

c)对于 个叶子节点区域 ,计算出最佳拟合值:

d)更新强学习器  :

(3)得到最终的强学习器  的表达式:

从以上过程中可知,除了由损失函数引起的负梯度计算和叶子节点的最佳残差拟合值的计算不同,二元GBDT分类和GBDT回归算法过程基本相似。那么二元GBDT是如何做分类呢?

将逻辑回归的公式进行整理,我们可以得到  ,其中,也就是将给定输入   预测为正样本的概率。逻辑回归用一个线性模型去拟合Y=1|x这个事件的对数几率(odds)。二元GBDT分类算法和逻辑回归思想一样,用一系列的梯度提升树去拟合这个对数几率,其分类模型可以表达为:

 

3 GBDT二分类算法实例

3.1 数据集介绍

训练集如下表所示,一组数据的特征有年龄和体重,把身高大于1.5米作为分类边界,身高大于1.5米的令标签为1,身高小于等于1.5米的令标签为0,共有4组数据。

测试数据如下表所示,只有一组数据,年龄为25、体重为65,我们用在训练集上训练好的GBDT模型预测该组数据的身高是否大于1.5米?

3.2 模型训练阶段

参数设置:

  • 学习率:learning_rate = 0.1

  • 迭代次数:n_trees = 5

  • 树的深度:max_depth = 3

1)初始化弱学习器:

2)对于建立棵分类回归树:

由于我们设置了迭代次数:n_trees=5,这就是设置了。

首先计算负梯度,根据上文损失函数为对数损失时,负梯度(即伪残差、近似残差)为:

我们知道梯度提升类算法,其关键是利用损失函数的负梯度的值作为回归问题提升树算法中的残差的近似值,拟合一个回归树。这里,为了称呼方便,我们把负梯度叫做残差。

现将残差的计算结果列表如下:

此时将残差作为样本的标签来训练弱学习器  ,即下表数据:

接着寻找回归树的最佳划分节点,遍历每个特征的每个可能取值。从年龄特征值为开始,到体重特征为结束,分别计算分裂后两组数据的平方损失(Square Error), 为左节点的平方损失,  为右节点的平方损失,找到使平方损失和  最小的那个划分节点,即为最佳划分节点。

例如:以年龄为划分节点,将小于的样本划分为到左节点,大于等于7的样本划分为右节点。左节点包括 ,右节点包括样本 , , , ,所有可能的划分情况如下表所示:

以上划分点的总平方损失最小为,有两个划分点:年龄和体重,所以随机选一个作为划分点,这里我们选年龄。现在我们的第一棵树长这个样子:

我们设置的参数中树的深度max_depth=3,现在树的深度只有,需要再进行一次划分,这次划分要对左右两个节点分别进行划分,但是我们在生成树的时候,设置了三个树继续生长的条件:

  • 深度没有到达最大。树的深度设置为3,意思是需要生长成3层;

  • 点样本数 >= min_samples_split;

  • 此节点上的样本的标签值不一样。如果值一样说明已经划分得很好了,不需要再分;(本程序满足这个条件,因此树只有2层)

最终我们的第一棵回归树长下面这个样子:

此时我们的树满足了设置,还需要做一件事情,给这棵树的每个叶子节点分别赋一个参数,来拟合残差。

根据上述划分结果,为了方便表示,规定从左到右为第个叶子结点,其计算值过程如下:

 

 

此时的第一棵树长下面这个样子:

接着更新强学习器,需要用到学习率:learning_rate=0.1,用lr表示。更新公式为:

为什么要用学习率呢?这是Shrinkage的思想,如果每次都全部加上拟合值  ,即学习率为,很容易一步学到位导致GBDT过拟合。

重复此步骤,直到  结束,最后生成棵树。

下面将展示每棵树最终的结构,这些图都是我GitHub上的代码生成的,感兴趣的同学可以去运行一下代码。

https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning/GBDT_GradientBoostingBinaryClassifier

第一棵树:

第二棵树:

第三棵树:

第四棵树:

第五棵树:

3)得到最后的强学习器:

3.3 模型预测阶段

  •  

  • 在  中,测试样本的年龄为,大于划分节点岁,所以被预测为。

  • 在  中,测试样本的年龄为,大于划分节点岁,所以被预测为。

  • 在  中,测试样本的年龄为,大于划分节点岁,所以被预测为。

  • 在  中,测试样本的年龄为,大于划分节点岁,所以被预测为。

  • 在  中,测试样本的年龄为,大于划分节点岁,所以被预测为。

最终预测结果为:

4 手撕GBDT二分类算法

本篇文章所有数据集和代码均在GitHub中,地址:https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning

4.1 用Python3实现GBDT二分类算法

需要的Python库:

pandas、PIL、pydotplus、matplotlib
br

其中pydotplus库会自动调用Graphviz,所以需要去Graphviz官网下载graphviz-2.38.msi安装,再将安装目录下的bin添加到系统环境变量,最后重启计算机。

由于用Python3实现GBDT二分类算法代码量比较多,我这里就不列出详细代码了,感兴趣的同学可以去GitHub中看一下,地址:https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning/GBDT_GradientBoostingBinaryClassifier

4.2 用sklearn实现GBDT二分类算法

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier'''
调参:
loss:损失函数。有deviance和exponential两种。deviance是采用对数似然,exponential是指数损失,后者相当于AdaBoost。
n_estimators:最大弱学习器个数,默认是100,调参时要注意过拟合或欠拟合,一般和learning_rate一起考虑。
learning_rate:步长,即每个弱学习器的权重缩减系数,默认为0.1,取值范围0-1,当取值为1时,相当于权重不缩减。较小的learning_rate相当于更多的迭代次数。
subsample:子采样,默认为1,取值范围(0,1],当取值为1时,相当于没有采样。小于1时,即进行采样,按比例采样得到的样本去构建弱学习器。这样做可以防止过拟合,但是值不能太低,会造成高方差。
init:初始化弱学习器。不使用的话就是第一轮迭代构建的弱学习器.如果没有先验的话就可以不用管由于GBDT使用CART回归决策树。以下参数用于调优弱学习器,主要都是为了防止过拟合
max_feature:树分裂时考虑的最大特征数,默认为None,也就是考虑所有特征。可以取值有:log2,auto,sqrt
max_depth:CART最大深度,默认为None
min_sample_split:划分节点时需要保留的样本数。当某节点的样本数小于某个值时,就当做叶子节点,不允许再分裂。默认是2
min_sample_leaf:叶子节点最少样本数。如果某个叶子节点数量少于某个值,会同它的兄弟节点一起被剪枝。默认是1
min_weight_fraction_leaf:叶子节点最小的样本权重和。如果小于某个值,会同它的兄弟节点一起被剪枝。一般用于权重变化的样本。默认是0
min_leaf_nodes:最大叶子节点数
'''gbdt = GradientBoostingClassifier(loss='deviance', learning_rate=0.1, n_estimators=5, subsample=1, min_samples_split=2, min_samples_leaf=1, max_depth=3, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False)train_feat = np.array([[1, 5, 20],[2, 7, 30],[3, 21, 70],[4, 30, 60],])
train_label = np.array([[0], [0], [1], [1]]).ravel()test_feat = np.array([[5, 25, 65]])
test_label = np.array([[1]])
print(train_feat.shape, train_label.shape, test_feat.shape, test_label.shape)gbdt.fit(train_feat, train_label)
pred = gbdt.predict(test_feat)total_err = 0
for i in range(pred.shape[0]):print(pred[i], test_label[i])err = (pred[i] - test_label[i]) / test_label[i]total_err += err * err
print(total_err / pred.shape[0])

用sklearn中的GBDT库实现GBDT二分类算法的难点在于如何更好的调节下列参数:

用sklearn实现GBDT二分类算法的GitHub地址:https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning/GBDT_Classification_sklearn

5 GBDT分类任务常见损失函数

对于GBDT分类算法,其损失函数一般有对数损失函数和指数损失函数两种:

(1)如果是指数损失函数,则损失函数表达式为:

其负梯度计算和叶子节点的最佳负梯度拟合可以参看Adaboost算法过程。

(2)如果是对数损失函数,分为二元分类和多元分类两种,本文主要介绍了GBDT二元分类的损失函数。

6 总结

在本文中,我们首先简单介绍了如何把GBDT回归算法变成分类算法的思路;然后从逻辑回归的对数损失函数推导出GBDT的二分类算法原理;其次不仅用Python3实现GBDT二分类算法,还用sklearn实现GBDT二分类算法;最后介绍了GBDT分类任务中常见的损失函数。GBDT可以完美的解决二分类任务,那么它对多分类任务是否有效呢?如果有效,GBDT是如何做多分类呢?这些问题都需要我们不停的探索和挖掘GBDT的深层原理。让我们期待一下GBDT在多分类任务中的表现吧!

 

文章来源:https://mp.weixin.qq.com/s/XLxJ1m7tJs5mGq3WgQYvGw#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480751.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文浅尝 | 动态词嵌入

Citation: Bamler R, Mandt S. Dynamic word embeddings.InInternational Conference on Machine Learning 2017 Jul 17 (pp. 380-389).URL:http://proceedings.mlr.press/v70/bamler17a/bamler17a.pdf动机语言随着时间在不断演化,词语的意思也由于文化的…

滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型

国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共收录 293 篇,录取率不足 20%。其中滴滴共有四篇论文入选 KDD 2018,涵盖 ETA 预测 (预估到达时间) 、智能派单、大规模车流管理等…

Keyword-BERT——问答系统中语义匹配的杀手锏

引子 问&答 是人和人之间非常重要的沟通方式,其关键在于:我们要理解对方的问题,并给出他想要的答案。设想这样一个场景,当你的女朋友or老婆大人在七夕前一晚,含情脉脉地跟你说 亲爱的,七夕快到了&…

阿里P8架构师谈:Docker简介、组成架构、使用步骤、以及生态产品

Docker简介 Docker是DotCloud开源的、可以将任何应用包装在Linux container中运行的工具。 Docker基于Go语言开发,代码托管在Github上,目前超过10000次commit。 基于Docker的沙箱环境可以实现轻型隔离,多个容器间不会相互影响;D…

研讨会 | 知识图谱前沿技术课程暨学术研讨会(武汉大学站)

知识图谱作为大数据时代重要的知识表示方式之一,已经成为人工智能领域的一个重要支撑。4月28日,“武汉大学信息集成与应用实验室”与“复旦大学知识工场实验室”联合举办“知识图谱前沿技术课程暨学术研讨会”,将结合知识图谱学界研究与业界应…

LayerNorm是Transformer的最优解吗?

本文转载自公众号“夕小瑶的卖萌屋”,专业带逛互联网算法圈的神操作 -----》我是传送门 关注后,回复以下口令: 回复【789】 :领取深度学习全栈手册(含NLP、CV海量综述、必刷论文解读) 回复【入群】&#xf…

观点 | 滴滴 AI Labs 负责人叶杰平教授:深度强化学习在滴滴的探索与实践+关于滴滴智能调度的分析和思考+滴滴派单和Uber派单对比

AI 科技评论按:7 月 29 日,YOCSEF TDS《深度强化学习的理论、算法与应用》专题探索报告会于中科院自动化所成功举办,本文为报告会第一场演讲,讲者为滴滴副总裁、AI Labs 负责人叶杰平教授,演讲题为「深度强化学习在滴滴…

消息中间件系列(二):Kafka的原理、基础架构、以及使用场景

一:Kafka简介 Apache Kafka是分布式发布-订阅消息系统,在 kafka官网上对 kafka 的定义:一个分布式发布-订阅消息传递系统。 它最初由LinkedIn公司开发,Linkedin于2010年贡献给了Apache基金会并成为顶级开源项目。Kafka是一种快速、…

丁力 | cnSchema:中⽂知识图谱的普通话

本文转载自公众号:大数据创新学习中心。3月10日下午,复旦大学知识工场联手北京理工大学大数据创新学习中心举办的“知识图谱前沿技术课程暨学术研讨会”上,OpenKG联合发起⼈、海知智能CTO丁力博士分享了以“cnSchema:中⽂知识图谱…

详解ERNIE-Baidu进化史及应用场景

一只小狐狸带你解锁 炼丹术&NLP 秘籍Ernie 1.0ERNIE: Enhanced Representation through Knowledge Integration 是百度在2019年4月的时候,基于BERT模型,做的进一步的优化,在中文的NLP任务上得到了state-of-the-art的结果。它主要的改进是…

解读 | 滴滴主题研究计划:机器学习专题+

解读 | 滴滴主题研究计划:机器学习专题(上篇) 解读 | 滴滴主题研究计划:机器学习专题(上篇) 2018年7月31日 管理员 微信分享 复制页面地址复制成功滴滴主题研究计划 滴滴希望通过开放业务场景,与…

笔记:seafile 7.x 安装和部署摘要

文章目录1. 安装1.1. 注意事项1.2. 企业微信集成并支持自建第三方应用配置1.3. 内置 Office 文件预览配置1.3.1. 安装 Libreoffice 和 UNO 库2. 主要功能2.1. 服务器个性化配置2.2. 管理员面板2.3. seafile 命令行使用教程2.3.1. ubuntu安装2.3.2. init 初始化seafile配置文件夹…

文章合集

Hi 大家好,我是陈睿|mikechen,这是优知学院的所有文章集合,专门整理这个页面,希望会对大家在浏览感兴趣文章的时候,能有更好的帮助! 这些文章的呈现,并不是按照时间轴来排序,无论是新旧文章&…

领域应用 | 阿里发布藏经阁计划,打造 AI 落地最强知识引擎

如果没有知识引擎,人工智能将会怎样?知识引擎可以把数据加工成信息,信息和现有的知识通过推理能够获得新的知识,从而形成庞大的知识网络,像大脑一样支持各种决策。你与智能音箱进行对话,背后就是基于知识引…

ACL2020 | FastBERT:放飞BERT的推理速度

FastBERT 自从BERT问世以来,大多数NLP任务的效果都有了一次质的飞跃。BERT Large在GLUE test上甚至提升了7个点之多。但BERT同时也开启了模型的“做大做深”之路,普通玩家根本训不起,高端玩家虽然训得起但也不一定用得起。 所以BERT之后的发展…

2017年双十一最全面的大数据分析报告在此!+2018年双十一已经开始,厚昌竞价托管教你如何应对流量流失?+2019年双十一大战一触即发:阿里、京东都有哪些套路和玩法

首先说一个众所周知的数据:2017年双十一天猫成交额1682亿。 所以今天,从三个角度带你一起去探索1682亿背后的秘密: 1、全网热度分析:双十一活动在全网的热度变化趋势、关注来源、媒体来源以及关联词分析。 2、各平台对比分析&…

阿里P8架构师谈:大数据架构设计(文章合集)

架构师进阶有一块很重要的内容,就是需要掌握大数据的架构设计,主要涵括: MySQL等关系式数据库,需要掌握数据库的索引、慢SQL、以及长事务的优化等。 需要掌握非关系式数据库(NoSQL)的选型,以及…

论文浅尝 | 利用 RNN 和 CNN 构建基于 FreeBase 的问答系统

Qu Y,Liu J, Kang L, et al. Question Answering over Freebase via Attentive RNN withSimilarity Matrix based CNN[J]. arXiv preprint arXiv:1804.03317, 2018.概述随着近年来知识库的快速发展,基于知识库的问答系统(KBQA )吸引了业界的广…

positional encoding位置编码详解:绝对位置与相对位置编码对比

本文转载自公众号“夕小瑶的卖萌屋”,专业带逛互联网算法圈的神操作 -----》我是传送门 关注后,回复以下口令: 回复【789】 :领取深度学习全栈手册(含NLP、CV海量综述、必刷论文解读) 回复【入群】&#xf…

## 作为多目标优化的多任务学习:寻找帕累托最优解+组合在线学习:实时反馈玩转组合优化-微软研究院+用于组合优化的强化学习:学习策略解决复杂的优化问题

NIPS 2018:作为多目标优化的多任务学习:寻找帕累托最优解多任务学习本质上是一个多目标问题,因为不同任务之间可能产生冲突,需要对其进行取舍。本文明确将多任务学习视为多目标优化问题,以寻求帕累托最优解。而经过实验…