基于Python的机器学习系列(18):梯度提升分类(Gradient Boosting Classification)

简介

        梯度提升(Gradient Boosting)是一种集成学习方法,通过逐步添加新的预测器来改进模型。在回归问题中,我们使用梯度来最小化残差。在分类问题中,我们可以利用梯度提升来进行二分类或多分类任务。与回归不同,分类问题需要使用如softmax这样的概率模型来处理类别标签。

梯度提升分类的工作原理

        梯度提升分类的基本步骤与回归类似,但在分类任务中,我们使用概率模型来处理预测结果:

  1. 初始化模型:选择一个初始预测器,这里使用DummyClassifier来作为第一个模型。
  2. 计算梯度:计算每个样本的梯度,梯度是当前预测值与真实标签之间的差异。
  3. 训练新预测器:用计算得到的梯度作为目标,训练一个新的分类器。
  4. 更新模型:将新预测器的结果加到现有模型中。
  5. 重复步骤:重复上述步骤,逐步添加更多的预测器以改进模型的分类能力。

二分类示例

        在二分类任务中,梯度提升分类器的工作流程如下:

  1. 预测概率:通过softmax将预测值转换为概率。
  2. 更新模型:利用当前的梯度来训练下一个分类器。

代码示例

        下面的代码示例展示了如何实现一个梯度提升分类器,包括支持二分类和多分类任务:

from sklearn.tree import DecisionTreeRegressor
from sklearn.dummy import DummyRegressor, DummyClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits, load_breast_cancer
import numpy as npclass GradientBoosting:def __init__(self, S=5, learning_rate=1, max_depth=1, min_samples_split=2, regression=True, tol=1e-4):self.S = Sself.learning_rate = learning_rateself.max_depth = max_depthself.min_samples_split = min_samples_splitself.regression = regression# 初始化回归树tree_params = {'max_depth': self.max_depth, 'min_samples_split': self.min_samples_split}self.models = [DecisionTreeRegressor(**tree_params) for _ in range(S)]if regression:# 回归模型的初始模型self.models.insert(0, DummyRegressor(strategy='mean'))else:# 分类模型的初始模型self.models.insert(0, DummyClassifier(strategy='most_frequent'))def grad(self, y, h):return y - hdef fit(self, X, y):# 训练第一个模型self.models[0].fit(X, y)for i in range(self.S):# 预测yhat = self.predict(X, self.models[:i+1], with_argmax=False)# 计算梯度gradient = self.grad(y, yhat)# 训练下一个模型self.models[i+1].fit(X, gradient)def predict(self, X, models=None, with_argmax=True):if models is None:models = self.modelsh0 = models[0].predict(X)boosting = sum(self.learning_rate * model.predict(X) for model in models[1:])yhat = h0 + boostingif not self.regression:# 使用softmax转换为概率yhat = np.exp(yhat) / np.sum(np.exp(yhat), axis=1, keepdims=True)if with_argmax:yhat = np.argmax(yhat, axis=1)return yhat# 示例:使用乳腺癌数据集进行二分类
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 创建和训练梯度提升分类器
gb = GradientBoosting(S=50, learning_rate=0.1, regression=False)
gb.fit(X_train, y_train)# 预测并计算准确率
y_pred = gb.predict(X_test)
from sklearn.metrics import accuracy_score
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')

总结

        梯度提升分类器通过逐步减少分类错误来提高模型的性能。这种方法在处理分类任务时,能够有效提高预测准确率。与回归任务类似,分类任务中的梯度提升也能通过逐步添加预测器来优化模型。通过调整学习率和模型参数,我们可以进一步提高模型的表现。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/52975.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

react中修改组件样式的几种方法

使用自定义类名className&#xff0c;引入样式文件进行样式覆盖 import React from react; import { Button } from antd;const MyComponent () > {return (<Button className"custom-button">点击我</Button>); };export default MyComponent;.cus…

访问win10共享文件夹:用户或密码不正确 以及 未授予用户在此计算机上的请求登录类型

因为安装的是神州网信政府版&#xff0c;该版本通常包含更严格的安全策略和访问控制&#xff0c;设置了共享文件夹后&#xff0c;访问共享文件夹时出现错误。 1、首先报错&#xff1a;用户或密码不正确 将》网络访问&#xff1a;本地账户的共享和安全模型&#xff0c;修改为&a…

开源通用验证码识别OCR —— DdddOcr 源码赏析(二)

文章目录 前言DdddOcr分类识别调用识别功能classification 函数源码classification 函数源码解读1. 分类功能不支持目标检测2. 转换为Image对象3. 根据模型配置调整图片尺寸和色彩模式4. 图像数据转换为浮点数据并归一化5. 图像数据预处理6. 运行模型&#xff0c;返回预测结果 …

Python测试之测试覆盖率统计

本篇承接上一篇 Python测试框架之—— pytest介绍与示例&#xff0c;在此基础上介绍如何基于pytest进行测试的覆盖率统计。 要在使用 pytest 进行测试时检测代码覆盖率&#xff0c;可以使用 pytest-cov 插件。这个插件是基于 coverage.py&#xff0c;它能帮助你了解哪些代码部…

人工智能和机器学习5 (复旦大学计算机科学与技术实践工作站)语言模型相关的技术和应用、通过OpenAI库,调用千问大模型,并进行反复询问等功能加强

前言 在这个日新月异的AI时代&#xff0c;自然语言处理&#xff08;NLP&#xff09;技术正以前所未有的速度改变着我们的生活方式和工作模式。作为这一领域的佼佼者&#xff0c;OpenAI不仅以其强大的GPT系列模型引领风骚&#xff0c;还通过其开放的API接口&#xff0c;让全球开…

Gamma软件处理D-InSAR获取形变步骤

1. 数据准备 获取数据 目标&#xff1a;通常你需要至少两张SAR图像&#xff1a;一个作为基准图像&#xff08;reference image&#xff09;&#xff0c;另一个作为目标图像&#xff08;secondary image&#xff09;。这些图像应在不同时间拍摄&#xff0c;且成像条件要尽可能…

哈工大-操作系统L30

文件使用磁盘的实现 fd文件描述符 buf内存缓冲区 count读写字符的个数 file->inode获得inode file_write写文件 inode映射表 读写的内存缓冲区buf,file字符流的位置200-212,根据inode提供的索引号找到块号,根据buf形成请求队列&#xff0c;再放入电梯队列 fseek调整读…

Jenkins安装使用详解,jenkins实现企业级CICD流程

文章目录 一、资料1、官方文档 二、环境准备1、安装jdk172、安装maven3、安装git4、安装gitlab5、准备我们的springboot项目6、安装jenkins7、安装docker8、安装k8s&#xff08;可选&#xff0c;部署节点&#xff09;9、安装Harbor10、准备带有jdk环境的基础镜像 三、jenkins实…

二叉树的最大深度(LeetCode)

题目 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 解题 # 定义二叉树节点的类 class TreeNode:def __init__(self, val0, leftNone, rightNone):self.val valself.left leftself.right right# …

力扣1235.规划兼职工作

力扣1235.规划兼职工作 动态规划 二分 将所有工作按照结束时间排序f[i]表示前i个工作可获取的最大收益状态转移&#xff1a;取第i个工作&#xff0c;f[i] profit[i] f[j]&#xff0c;其中j为结束时间小于i的开始时间的最大数不取第i个工作&#xff0c;f[i] f[i-1]可以通过二…

低代码开发平台:重塑未来软件开发格局的关键力量

低代码开发平台正以前所未有的速度改变着软件开发的面貌&#xff0c;通过最小化手动编码&#xff0c;让用户能够迅速构建应用程序。随着企业对敏捷性和创新能力的追求日益增强&#xff0c;这类平台的需求激增。展望未来&#xff0c;技术进步与市场动态将引领低代码开发进入新的…

【C++】探索inline关键字:用法、限制与示例代码

文章目录 前言相关性质用法优点限制和注意事项inline 函数的定义位置inline 和类成员函数inline 和 constexpr 前言 我们知道&#xff1a;对于C、C&#xff0c;在编译时遇到函数调用时&#xff0c;编译器会生成一个函数调用的代码&#xff0c;这包括跳转到函数的地址和处理返回…

大阪OSAKA分子泵TG710MTG730TG1130TD7111TG2810TD3211TG3413手侧接线图

大阪OSAKA分子泵TG710MTG730TG1130TD7111TG2810TD3211TG3413手侧接线图

window下kafka3启动多个

准备工作 我们先安装好kafka&#xff0c;并保证启动成功&#xff0c;可参考文章Windows下安装Kafka3-CSDN博客 复制kafka安装文件 kafka3已经内置了zookeeper&#xff0c;所以直接复制就行了 修改zookeeper配置文件 这里我们修改zookeeper配置文件&#xff0c;主要是快照地址…

【MyBatis】MyBatis的一级缓存和二级缓存简介

目录 1、一级缓存 1.1 我们在一个 sqlSession 中&#xff0c;对 User 表根据id进行两次查询&#xff0c;查看他们发出sql语句的情况。 1.2 同样是对user表进行两次查询&#xff0c;只不过两次查询之间进行了一次update操作。 1.3 一级缓存查询过程 1.4 Mybatis与Spring整…

switch语句和while循环

switch语句和while循环 switch语句break的用法default的用法switch语句中的case和default的顺序问题 while语句while语句的执行流程while语句的具体例子 switch语句 switch 语句是⼀种特殊形式的 if…else 结构&#xff0c;用于判断条件有多个结果的情况。它把多重 的 else if…

Pandas 8-数据筛选过滤

1. 基于条件筛选 1.1 单条件筛选 可以使用布尔索引来筛选满足特定条件的数据。 import pandas as pd # 创建一个DataFrame data { Name: [Alice, Bob, Charlie, David], Age: [24, 27, 22, 32], City: [New York, Los Angeles, Chicago, Houston], Score: [85…

滚动视图ScrollView

activity_scroll_view.xml <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_pare…

面试官让简述一下elasticsearch

当面试官要求你简述 Elasticsearch 时,你可以从以下几个方面来介绍: 1. 简介 Elasticsearch 是一个分布式的、RESTful 风格的搜索和分析引擎,基于 Lucene 构建。它能够处理海量数据,提供近乎实时的全文搜索功能,并且可以轻松扩展到数百台服务器及 PB 级结构化或非结构化…

【Python系列】 Python 中的枚举使用

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…