Python28-11 CatBoost梯度提升算法

图片

CatBoost(Categorical Boosting)是由Yandex(一家俄罗斯互联网企业,旗下的搜索引擎曾在俄国内拥有逾60%的市场占有率,同时也提供其他互联网产品和服务)开发的一种基于梯度提升的机器学习算法。CatBoost特别擅长处理类别特征,并且能够有效地避免过拟合和数据泄露问题。CatBoost的全称是“Categorical Boosting”,它的设计初衷是为了在处理包含大量类别特征的数据时表现得更好。

CatBoost的特点

  1. 处理类别特征:CatBoost可以直接处理类别特征而不需要进行额外的编码(如one-hot编码)。

  2. 避免过拟合:CatBoost采用了一种新的处理类别特征的方法,有效地减少了过拟合。

  3. 高效性:CatBoost在训练速度和预测速度方面都表现优异。

  4. 支持CPU和GPU训练:CatBoost既可以在CPU上运行,也可以利用GPU进行加速训练。

  5. 自动处理缺失值:CatBoost可以自动处理缺失值,无需额外的预处理步骤。

CatBoost的核心原理

CatBoost的核心原理基于梯度提升决策树(GBDT),但在处理类别特征和避免过拟合方面进行了创新。以下是一些关键的技术点:

  1. 类别特征处理

    • CatBoost引入了一个称为“均值编码”的方法,基于类别的均值计算新特征。

    • 使用一种称为“目标编码”的技术,将类别特征转化为数值特征时,通过使用目标值的平均值来减少数据泄露的风险。

    • 在训练过程中,通过使用统计信息对数据进行处理,避免直接使用目标变量进行编码。

  2. 有序提升(Ordered Boosting)

    • 为了防止数据泄露和过拟合,CatBoost在训练时对数据进行了有序处理。

    • 有序提升通过在训练过程中随机打乱数据,并确保模型在某一时刻只看到过去的数据,而不会使用未来的信息进行决策。

  3. 计算优化

    • CatBoost通过预计算和缓存的方式加速了特征的计算过程。

    • 支持CPU和GPU训练,能够在大规模数据集上表现出色。

CatBoost的基本使用

以下是一个使用CatBoost进行分类任务的基本示例,我们使用Auto MPG(Miles Per Gallon)数据集,它是一个经典的回归问题数据集,常用于机器学习和统计分析。该数据集记录了不同型号汽车的燃油效率(即每加仑燃油行驶的英里数)以及其他多个相关特征。

数据集特征:

  • mpg: 每加仑燃油行驶的英里数(目标变量)。

  • cylinders: 气缸数量,表示发动机的气缸数。

  • displacement: 发动机排量(立方英寸)。

  • horsepower: 发动机功率(马力)。

  • weight: 车辆重量(磅)。

  • acceleration: 0到60英里每小时的加速度时间(秒)。

  • model_year: 车辆生产年份。

  • origin: 车辆产地(1=美国,2=欧洲,3=日本)。

数据集前几行:

    mpg  cylinders  displacement  horsepower  weight  acceleration  model_year  origin
0  18.0          8         307.0       130.0  3504.0          12.0          70       1
1  15.0          8         350.0       165.0  3693.0          11.5          70       1
2  18.0          8         318.0       150.0  3436.0          11.0          70       1
3  16.0          8         304.0       150.0  3433.0          12.0          70       1
4  17.0          8         302.0       140.0  3449.0          10.5          70       1

代码示例:

import pandas as pd  # 导入Pandas库,用于数据处理
import numpy as np  # 导入Numpy库,用于数值计算
from sklearn.model_selection import train_test_split  # 从sklearn库导入train_test_split,用于划分数据集
from sklearn.metrics import mean_squared_error, mean_absolute_error  # 导入均方误差和平均绝对误差,用于评估模型性能
from catboost import CatBoostRegressor  # 导入CatBoost库中的CatBoostRegressor,用于回归任务
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于绘图
import seaborn as sns  # 导入Seaborn库,用于绘制统计图# 设置随机种子以便结果复现
np.random.seed(42)# 从UCI机器学习库加载Auto MPG数据集
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
column_names = ['mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration', 'model_year', 'origin']
data = pd.read_csv(url, names=column_names, na_values='?', comment='\t', sep=' ', skipinitialspace=True)# 查看数据集的前几行
print(data.head())# 处理缺失值
data = data.dropna()# 特征和目标变量
X = data.drop('mpg', axis=1)  # 特征变量
y = data['mpg']  # 目标变量# 将类别特征转换为字符串类型(CatBoost可以直接处理类别特征)
X['cylinders'] = X['cylinders'].astype(str)
X['model_year'] = X['model_year'].astype(str)
X['origin'] = X['origin'].astype(str)# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定义CatBoost回归器
model = CatBoostRegressor(iterations=1000,  # 迭代次数learning_rate=0.1,  # 学习率depth=6,  # 决策树深度loss_function='RMSE',  # 损失函数verbose=100  # 输出训练过程信息
)# 训练模型
model.fit(X_train, y_train, eval_set=(X_test, y_test), early_stopping_rounds=50)# 进行预测
y_pred = model.predict(X_test)# 评估模型性能
mse = mean_squared_error(y_test, y_pred)  # 计算均方误差
mae = mean_absolute_error(y_test, y_pred)  # 计算平均绝对误差# 打印模型的评估结果
print(f'Mean Squared Error (MSE): {mse:.4f}')
print(f'Mean Absolute Error (MAE): {mae:.4f}')# 绘制真实值与预测值的对比图
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.5)  # 绘制散点图
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], '--k')  # 绘制对角线
plt.xlabel('True Values')  # X轴标签
plt.ylabel('Predictions')  # Y轴标签
plt.title('True Values vs Predictions')  # 图标题
plt.show()# 特征重要性可视化
feature_importances = model.get_feature_importance()  # 获取特征重要性
feature_names = X.columns  # 获取特征名称plt.figure(figsize=(10, 6))
sns.barplot(x=feature_importances, y=feature_names)  # 绘制特征重要性条形图
plt.title('Feature Importances')  # 图标题
plt.show()# 输出
'''
mpg  cylinders  displacement  horsepower  weight  acceleration  \
0  18.0          8         307.0       130.0  3504.0          12.0   
1  15.0          8         350.0       165.0  3693.0          11.5   
2  18.0          8         318.0       150.0  3436.0          11.0   
3  16.0          8         304.0       150.0  3433.0          12.0   
4  17.0          8         302.0       140.0  3449.0          10.5   model_year  origin  
0          70       1  
1          70       1  
2          70       1  
3          70       1  
4          70       1  
0: learn: 7.3598113 test: 6.6405869 best: 6.6405869 (0) total: 1.7ms remaining: 1.69s
100: learn: 1.5990203 test: 2.3207830 best: 2.3207666 (94) total: 132ms remaining: 1.17s
200: learn: 1.0613606 test: 2.2319632 best: 2.2284239 (183) total: 272ms remaining: 1.08s
Stopped by overfitting detector  (50 iterations wait)bestTest = 2.21453232
bestIteration = 238Shrink model to first 239 iterations.
Mean Squared Error (MSE): 4.9042
Mean Absolute Error (MAE): 1.6381<Figure size 1000x600 with 1 Axes>
<Figure size 1000x600 with 1 Axes>
'''

Mean Squared Error (MSE): 均方误差,表示预测值与实际值之间的平均平方差异。值越小,模型性能越好,在这里MSE的值是4.9042。

Mean Absolute Error (MAE): 平均绝对误差,表示预测值与实际值之间的平均绝对差异。值越小,模型性能越好,在这里MAE的值是1.6381。

图片

  1. 散点图:图中的每个点表示一个测试样本。横坐标表示该样本的真实值(MPG),纵坐标表示模型的预测值(MPG)。

  2. 对角线:图中的黑色虚线是45度对角线,表示理想情况下的预测结果,即预测值等于真实值。

  3. 点的分布:

    • 靠近对角线:表示模型的预测值与真实值非常接近,预测准确。

    • 远离对角线:表示预测值与真实值有较大差距,预测不准确。

通过图中的点可以看到大部分点都集中在对角线附近,这表明模型的预测性能良好,但也有一些点离对角线较远,表示这些样本的预测值与真实值存在一定的差距。

图片

  1. 条形图:每个条形表示一个特征在模型中的重要性。条形越长,表示该特征对模型预测的贡献越大。

  2. 特征名称:在Y轴上列出了所有特征的名称。

  3. 特征重要性值:在X轴上显示了每个特征的相对重要性值。

从图中可以看到:

  1. model_year:在所有特征中最重要,表示汽车的生产年份对预测燃油效率有很大的影响。

  2. weight:汽车的重量是第二重要的特征,对燃油效率也有显著影响。

  3. displacement 和 horsepower:发动机的排量和功率对燃油效率也有较大的贡献。

在实例中,我们使用CatBoost处理Auto MPG数据集,其主要目的是构建一个回归模型,以预测汽车的燃油效率(即每加仑燃油行驶的英里数,MPG)。

以上内容总结自网络,如有帮助欢迎转发,我们下次再见!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43423.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是ThingsKit物联网平台?

在信息化时代的浪潮中&#xff0c;物联网&#xff08;IoT&#xff09;作为新一代信息技术的核心&#xff0c;已经逐渐渗透到我们生活的方方面面。而在这个大背景下&#xff0c;Thingskit物联网平台以其独特的技术优势和应用场景&#xff0c;成为了物联网领域的一颗璀璨明星。本…

3.flink架构

目录 概述 概述 Flink是一个分布式的带有状态管理的计算框架&#xff0c;为了执行流应用程序&#xff0c;可以和 Hadoop YARN 、K8s 进行整合&#xff0c;当然也可以是一个 standalone 。 官方地址&#xff1a;速递 k8s 是未来的一种趋势&#xff0c;对资源管控能力强。

Windows 控制中心在哪里打开,七种方法教会你

在 Windows 操作系统中&#xff0c;控制中心的概念可能稍有些混淆&#xff0c;因为 Windows 通常使用“控制面板”这一术语来指代用于配置系统设置和更改硬件及软件设置的中心区域。 不过&#xff0c;随着 Windows 的更新&#xff0c;微软也在逐步将一些设置迁移到“设置”应用…

关于Linux的操作作业!24道题

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

MySQL数据库基本操作-DDL和DML

1. DDL解释 DDL(Data Definition Language)&#xff0c;数据定义语言&#xff0c;该语言部分包括以下内容&#xff1a; 对数据库的常用操作对表结构的常用操作修改表结构 2. 对数据库的常用操作 功能SQL查看所有的数据库show databases&#xff1b;查看有印象的数据库show d…

16 - Python语言进阶

Python语言进阶 数据结构和算法 算法&#xff1a;解决问题的方法和步骤 评价算法的好坏&#xff1a;渐近时间复杂度和渐近空间复杂度。 渐近时间复杂度的大O标记&#xff1a; - 常量时间复杂度 - 布隆过滤器 / 哈希存储 - 对数时间复杂度 - 折半查找&#xff08;二分查找&am…

使用耳机壳UV树脂制作私模定制耳塞的价格如何呢?

使用耳机壳UV树脂制作私模定制耳塞的价格如何呢&#xff1f; 耳机壳UV树脂制作私模定制耳塞的价格因多个因素而异&#xff0c;如材料、工艺、设计、定制复杂度等。 根据我目前所了解到的信息&#xff0c;使用UV树脂制作私模定制耳塞的价格可能在数百元至数千元不等。具体价格…

LVS+Nginx高可用集群---Nginx进阶与实战

1.Nginx中解决跨域问题 两个站点的域名不一样&#xff0c;就会有一个跨域问题。 跨域问题&#xff1a;了解同源策略&#xff1a;协议&#xff0c;域名&#xff0c;端口号都相同&#xff0c;只要有一个不相同那么就是非同源。 CORS全称Cross-Origin Resource Sharing&#xff…

挑战杯 opencv python 深度学习垃圾图像分类系统

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; opencv python 深度学习垃圾分类系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 这是一个较为新颖的竞…

昇思25天学习打卡营第13天|应用实践之ResNet50迁移学习

基本介绍 今日的应用实践的模型是计算机实践领域中十分出名的模型----ResNet模型。ResNet是一种残差网络结构&#xff0c;它通过引入“残差学习”的概念来解决随着网络深度增加时训练困难的问题&#xff0c;从而能够训练更深的网络结构。现很多网络极深的模型或多或少都受此影响…

数据链路层(超详细)

引言 数据链路层是计算机网络协议栈中的第二层&#xff0c;位于物理层之上&#xff0c;负责在相邻节点之间的可靠数据传输。数据链路层使用的信道主要有两种类型&#xff1a;点对点信道和广播信道。点对点信道是指一对一的通信方式&#xff0c;而广播信道则是一对多的通信方式…

风险评估:Tomcat的安全配置,Tomcat安全基线检查加固

「作者简介」&#xff1a;冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础著作 《网络安全自学教程》&#xff0c;适合基础薄弱的同学系统化的学习网络安全&#xff0c;用最短的时间掌握最核心的技术。 这一章节我们需…

grafana数据展示

目录 一、安装步骤 二、如何添加喜欢的界面 三、自动添加注册客户端主机 一、安装步骤 启动成功后 可以查看端口3000是否启动 如果启动了就在浏览器输入IP地址&#xff1a;3000 账号密码默认是admin 然后点击 log in 第一次会让你修改密码 根据自定义密码然后就能登录到界面…

高职物联网实训室

一、高职物联网实训室建设背景 随着《中华人民共和国国民经济和社会发展第十四个五年规划和2035年远景目标纲要》的发布&#xff0c;中国正式步入加速数字化转型的新时代。在数字化浪潮中&#xff0c;物联网技术作为连接物理世界与数字世界的桥梁&#xff0c;其重要性日益凸显…

Golang | Leetcode Golang题解之第224题基本计算器

题目&#xff1a; 题解&#xff1a; func calculate(s string) (ans int) {ops : []int{1}sign : 1n : len(s)for i : 0; i < n; {switch s[i] {case :icase :sign ops[len(ops)-1]icase -:sign -ops[len(ops)-1]icase (:ops append(ops, sign)icase ):ops ops[:len(o…

2024年有多少程序员转行了?

疫情后大环境下行&#xff0c;各行各业的就业情况都是一言难尽。互联网行业更是极不稳定&#xff0c;频频爆出裁员的消息。大家都说2024年程序员的就业很难&#xff0c;都很焦虑。 在许多人眼里&#xff0c;程序员可能是一群背着电脑、进入高大上写字楼的职业&#xff0c;他们…

Datawhale AI 夏令营 机器学习挑战赛

一、赛事背景 在当今科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;技术正以前所未有的深度和广度渗透到科研领域&#xff0c;特别是在化学及药物研发中展现出了巨大潜力。精准预测分子性质有助于高效筛选出具有优异性能的候选药物。以PROTACs为例&#x…

Hi3861 OpenHarmony嵌入式应用入门--MQTT

MQTT 是机器对机器(M2M)/物联网(IoT)连接协议。它被设计为一个极其轻量级的发布/订阅消息传输 协议。对于需要较小代码占用空间和/或网络带宽非常宝贵的远程连接非常有用&#xff0c;是专为受限设备和低带宽、 高延迟或不可靠的网络而设计。这些原则也使该协议成为新兴的“机器…

AutoMQ 生态集成 Kafdrop-ui

Kafdrop [1] 是一个为 Kafka 设计的简洁、直观且功能强大的Web UI 工具。它允许开发者和管理员轻松地查看和管理 Kafka 集群的关键元数据&#xff0c;包括主题、分区、消费者组以及他们的偏移量等。通过提供一个用户友好的界面&#xff0c;Kafdrop 大大简化了 Kafka 集群的监控…