如果你还不了解GBDT,不妨看看这篇文章

作者:Freemanzxp

简介:中科大研二在读,目前在微软亚洲研究院实习,主要研究方向是机器学习。

原文:https://blog.csdn.net/zpalyq110/article/details/79527653

Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial

本文已授权,未经原作者允许,不得二次转载


写在前面: 去年学习 GBDT 之初,为了加强对算法的理解,整理了一篇笔记形式的文章,发出去之后发现阅读量越来越多,渐渐也有了评论,评论中大多指出来了笔者理解或者编辑的错误,故重新编辑一版文章,内容更加翔实,并且在 GitHub 上实现了和本文一致的 GBDT 简易版(包括回归、二分类、多分类以及可视化),供大家交流探讨。感谢各位的点赞和评论,希望继续指出错误~Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial


简介:

GBDT 的全称是 Gradient Boosting Decision Tree,梯度提升树,在传统机器学习算法中,GBDT 算得上 TOP3 的算法。想要理解 GBDT 的真正意义,那就必须理解 GBDT 中的 Gradient Boosting 和 Decision Tree 分别是什么?

1. Decision Tree:CART回归树

首先,GBDT 使用的决策树是 CART 回归树,无论是处理回归问题还是二分类以及多分类,GBDT 使用的决策树通通都是都是 CART 回归树。

为什么不用 CART 分类树呢?因为 GBDT 每次迭代要拟合的是梯度值,是连续值所以要用回归树。

对于回归树算法来说最重要的是寻找最佳的划分点,那么回归树中的可划分点包含了所有特征的所有可取的值。在分类树中最佳划分点的判别标准是熵或者基尼系数,都是用纯度来衡量的,但是在回归树中的样本标签是连续数值,所以再使用熵之类的指标不再合适,取而代之的是平方误差,它能很好的评判拟合程度。


回归树生成算法:

输入:训练数据集 D:

输出:回归树 f(x).

在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:

(1)选择最优切分变量 j 与切分点 s,求解

640?wx_fmt=png

遍历变量 j,对固定的切分变量 j 扫描切分点 s,选择使得上式达到最小值的对 (j,s).

(2)用选定的对 (j,s) 划分区域并决定相应的输出值:

640?wx_fmt=png

(3)继续对两个子区域调用步骤(1)和(2),直至满足停止条件。

(4)将输入空间划分为 M 个区域 

640?wx_fmt=png

,生成决策树:

640?wx_fmt=png

2. Gradient Boosting:拟合负梯度

梯度提升树(Grandient Boosting)是提升树(Boosting Tree)的一种改进算法,所以在讲梯度提升树之前先来说一下提升树。


先来个通俗理解:

假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。

如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。

最后将每次拟合的岁数加起来便是模型输出的结果。


提升树算法:640?wx_fmt=png640?wx_fmt=png

640?wx_fmt=png

 (b)拟合残差640?wx_fmt=png学习一个回归树,得到640?wx_fmt=png

640?wx_fmt=png

640?wx_fmt=png


上面伪代码中的残差是什么?

640?wx_fmt=png

损失函数是

640?wx_fmt=png

我们本轮迭代的目标是找到一个弱学习器

640?wx_fmt=png

最小化让本轮的损失

640?wx_fmt=png

当采用平方损失函数时

640?wx_fmt=png

这里,

640?wx_fmt=png

是当前模型拟合数据的残差(residual)

所以,对于提升树来说只需要简单地拟合当前模型的残差。

回到我们上面讲的那个通俗易懂的例子中,第一次迭代的残差是10岁,第二 次残差4岁……


当损失函数是平方损失和指数损失函数时,梯度提升树每一步优化是很简单的,但是对于一般损失函数而言,往往每一步优化起来不那么容易,针对这一问题,Freidman 提出了梯度提升树算法,这是利用最速下降的近似方法,其关键是利用损失函数的负梯度作为提升树算法中的残差的近似值。

那么负梯度长什么样呢?

第 t 轮的第 i 个样本的损失函数的负梯度为:

640?wx_fmt=png

此时不同的损失函数将会得到不同的负梯度,如果选择平方损失

640?wx_fmt=png

负梯度为

640?wx_fmt=png

此时我们发现 GBDT 的负梯度就是残差,所以说对于回归问题,我们要拟合的就是残差

log(loss)本文以回归问题为例进行讲解

3. GBDT算法原理

上面两节分别将 Decision Tree 和 Gradient Boosting 介绍完了,下面将这两部分组合在一起就是我们的 GBDT 了。


GBDT算法:

640?wx_fmt=png

(2)对640?wx_fmt=png有:640?wx_fmt=png,计算负梯度,即残差

640?wx_fmt=png

640?wx_fmt=png

作为下棵树的训练数据,得到一颗新的回归树

640?wx_fmt=png

其对应的叶子节点区域为640?wx_fmt=png。其

中 J 为回归树 t 的叶子节点的个数。640?wx_fmt=png计算最佳拟合值

640?wx_fmt=png

 (d)更新强学习器

640?wx_fmt=png

(3)得到最终学习器

640?wx_fmt=png

4. 实例详解

本人用 python 以及 pandas 库实现 GBDT 的简易版本,在下面的例子中用到的数据都在 github 可以找到,大家可以结合代码和下面的例子进行理解,欢迎 star~  

Github:https://github.com/Freemanzxp/GBDT_Simple_Tutorial


数据介绍:

如下表所示:一组数据,特征为年龄、体重,身高为标签值。共有5条数据,前四条为训练样本,最后一条为要预测的样本。

640?wx_fmt=png

训练阶段:


参数设置:

  • 学习率:learning_rate=0.1

  • 迭代次数:n_trees=5

  • 树的深度:max_depth=3


1.初始化弱学习器:

640?wx_fmt=png

损失函数为平方损失,因为平方损失函数是一个凸函数,直接求导,倒数等于零,得到 c。

640?wx_fmt=png

令导数等于0

640?wx_fmt=png

所以初始化时,c取值为所有训练样本标签值的均值。

c=(1.1+1.3+1.7+1.8)/4=1.475,此时得到初始学习器640?wx_fmt=png640?wx_fmt=png


2.对迭代轮数m=1,2,…,M:

由于我们设置了迭代次数:n_trees=5,这里的 M=5640?wx_fmt=png的差值640?wx_fmt=png

640?wx_fmt=png

此时将残差作为样本的真实值来训练弱学习器640?wx_fmt=png,即下表数据

640?wx_fmt=png

接着,寻找回归树的最佳划分节点,遍历每个特征的每个可能取值。从年龄特征的5开始,到体重特征的 70 结束,分别计算分裂后两组数据的平方损失(Square Error),640?wx_fmt=png 左节点平方损失,640?wx_fmt=png 右节点平方损失,找到使平方损失和640?wx_fmt=png 最小的那个划分节点,即为最佳划分节点。

例如:以年龄 7 为划分节点,将小于 7 的样本划分为到左节点,大于等于 7 的样本划分为右节点。左节点包括 x0,右节点包括样本640?wx_fmt=png640?wx_fmt=png,所有可能划分情况如下表所示:

640?wx_fmt=png

以上划分点是的总平方损失最小为0.025有两个划分点:年龄21和体重60,所以随机选一个作为划分点,这里我们选 年龄21

现在我们的第一棵树长这个样子:

640?wx_fmt=png

我们设置的参数中树的深度 max_depth=3,现在树的深度只有 2,需要再进行一次划分,这次划分要对左右两个节点分别进行划分:


对于左节点,只含有 0,1 两个样本,根据下表我们选择 年龄7 划分

640?wx_fmt=png

对于右节点,只含有 2,3 两个样本,根据下表我们选择 年龄30 划分(也可以选体重70

640?wx_fmt=png

现在我们的第一棵树长这个样子:

640?wx_fmt=png

此时我们的树深度满足了设置,还需要做一件事情,给这每个叶子节点分别赋一个参数 γ,来拟合残差。

640?wx_fmt=png

这里其实和上面初始化学习器是一个道理,平方损失,求导,令导数等于零,化简之后得到每个叶子节点的参数 γ,其实就是标签值的均值。这个地方的标签值不是原始的 y,而是本轮要拟合的标残差 640?wx_fmt=png.

根据上述划分结果,为了方便表示,规定从左到右为第640?wx_fmt=png个叶子结点640?wx_fmt=png640?wx_fmt=png640?wx_fmt=png640?wx_fmt=png

此时的树长这个样子:

640?wx_fmt=png

此时可更新强学习器,需要用到参数学习率:learning_rate=0.1,用 lr 表示。

640?wx_fmt=png

为什么要用学习率呢?这是Shrinkage的思想,如果每次都全部加上(学习率为1)很容易一步学到位导致过拟合。


重复此步骤,直到 640?wx_fmt=png 结束,最后生成5棵树。

下面将展示每棵树最终的结构,这些图都是GitHub上的代码生成的,感兴趣的同学可以去一探究竟

https://github.com/Freemanzxp/GBDT_Simple_Tutorial

第一棵树:

640?wx_fmt=png

第二棵树:
640?wx_fmt=png

第三棵树:
640?wx_fmt=png

第四棵树:
640?wx_fmt=png

第五棵树:
640?wx_fmt=png

4.得到最后的强学习器:

640?wx_fmt=png


5.预测样本5:

640?wx_fmt=png640?wx_fmt=png中,样本4的年龄为25,大于划分节点21岁,又小于30岁,所以被预测为0.2250

640?wx_fmt=png中,样本4的…此处省略…所以被预测为0.2025

为什么是 0.2025

这是根据第二颗树得到的,可以 GitHub 简单运行一下代码

640?wx_fmt=png中,样本4的…此处省略…所以被预测为0.1823

640?wx_fmt=png中,样本4的…此处省略…所以被预测为0.1640

640?wx_fmt=png中,样本4的…此处省略…所以被预测为0.1476

最终预测结果:640?wx_fmt=png


5. 总结

本文章从GBDT算法的原理到实例详解进行了详细描述,但是目前只写了回归问题,GitHub 上的代码也是实现了回归、二分类、多分类以及树的可视化,希望大家继续批评指正,感谢各位的关注。

Github:

https://github.com/Freemanzxp/GBDT_Simple_Tutorial


参考资料

  1. 李航 《统计学习方法》

  2. Friedman J H . Greedy Function Approximation: A Gradient Boosting Machine[J]. The Annals of Statistics, 2001, 29(5):1189-1232.


欢迎关注我的微信公众号--机器学习与计算机视觉,或者扫描下方的二维码,大家一起交流,学习和进步!

640?wx_fmt=jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javascript 面向对象编程小记

虽然平常用jquery用的很熟,但是基本都是面向过程的写法。一个事件一个function,很少有面向对象的写法。今天得写一个日期控件,不得不用上面向对象编程。 刚开始我的想法是: var datepicker function(){return new datepicker.fn.init(); }da…

来了解下计算机视觉的八大应用

2019 第 40 篇,总第 64 篇文章本文大约7000字,建议收藏阅读之前通过三篇文章简单介绍了机器学习常用的几种经典算法,当然也包括了目前很火的 CNNs 算法了:常用机器学习算法汇总比较(上)常用机器学习算法汇总比较(中&am…

CRM系统助家具企业华丽转身

近年来,随着住宅建设规模的扩大,作为住宅主要配套商品的家具将迎来广阔的发展机遇和市场增长空间。 我国家具行业以中小企业为多,并且小型私营家具企业在当中占大比例,管理粗放,实行的大多是家族式、经验式管理&#x…

itchat 保存好友信息以及生成好友头像图片墙

2019 第 41 篇,总第 65 篇文章本文大约 4000 字,阅读大约需要 12 分钟最近简单运用 itchat 这个库来实现一些简单的应用,主要包括以下几个应用:统计保存好友的数量和信息统计和保存关注的公众号数量和信息简单生成好友头像的图片墙…

启动outlook时报错:mapi无法加载信息服务msncon.dll

今天这个Office2010 outlook搞的让人蛋疼,老是说启动outlook时报错:mapi无法加载信息服务msncon.dll。 百度了一下,如下解决方案: 安装路径为D:\NEW Windows7 File\office2010\Office14 在命令行中定位到outlook安装文件夹&#x…

快速入门Pytorch(1)--安装、张量以及梯度

2019 第 42 篇,总第 66 篇文章本文大约 9000 字,建议收藏阅读!这是翻译自官方的入门教程,教程地址如下:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html虽然教程名字是 60 分钟入门&#xff0…

快速入门PyTorch(2)--如何构建一个神经网络

2019 第 43 篇,总第 67 篇文章本文大约 4600 字,阅读大约需要 10 分钟快速入门 PyTorch 教程第二篇,这篇介绍如何构建一个神经网络。上一篇文章:快速入门Pytorch(1)--安装、张量以及梯度本文的目录:3. 神经网络在 PyTo…

程序员的职业素养文摘

第1章 专业主义 第2章 说“不” 第3章 说“是” 第4章 编码 第5章 测试驱动开发 第6章 练习 第7章 验收测试 第8章 测试策略 第9章 时间管理 第10章 预估 第11章 压力 第12章 协作 第13章 团队与项目 第14章 辅导,学徒期与技艺转载于:https://www.cnblogs.com/smile…

快速入门PyTorch(3)--训练一个图片分类器和多 GPUs 训练

2019 第 44 篇,总第 68 篇文章本文大约14000字,建议收藏阅读快速入门 PyTorch 教程前两篇文章:快速入门Pytorch(1)--安装、张量以及梯度快速入门PyTorch(2)--如何构建一个神经网络这是快速入门 PyTorch 的第三篇教程也是最后一篇教程&#xf…

60分钟快速入门 PyTorch

PyTorch 是由 Facebook 开发,基于 Torch 开发,从并不常用的 Lua 语言转为 Python 语言开发的深度学习框架,Torch 是 TensorFlow 开源前非常出名的一个深度学习框架,而 PyTorch 在开源后由于其使用简单,动态计算图的特性…

【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp

模型实现代码&#xff0c;关键是train函数和predict函数&#xff0c;都很容易。 #include <iostream> #include <string> #include <math.h> #include "LogisticRegression.h" using namespace std;LogisticRegression::LogisticRegression(int si…

5月份 Github 上最热的十个 Python 项目,从Debug工具到AI水军、量化交易系统。

2019 年第 46 篇&#xff0c;总第 70 篇文章原文地址&#xff1a;https://medium.mybridge.co/python-open-source-for-the-past-month-v-may-2019-473e9f60c73f5 月份刚刚过去&#xff0c;之前看到了一篇介绍 5 月份的最热机器学习项目&#xff0c;刚好看到 Mybridge AI 博客又…