Python28-11 CatBoost梯度提升算法

图片

CatBoost(Categorical Boosting)是由Yandex(一家俄罗斯互联网企业,旗下的搜索引擎曾在俄国内拥有逾60%的市场占有率,同时也提供其他互联网产品和服务)开发的一种基于梯度提升的机器学习算法。CatBoost特别擅长处理类别特征,并且能够有效地避免过拟合和数据泄露问题。CatBoost的全称是“Categorical Boosting”,它的设计初衷是为了在处理包含大量类别特征的数据时表现得更好。

CatBoost的特点

  1. 处理类别特征:CatBoost可以直接处理类别特征而不需要进行额外的编码(如one-hot编码)。

  2. 避免过拟合:CatBoost采用了一种新的处理类别特征的方法,有效地减少了过拟合。

  3. 高效性:CatBoost在训练速度和预测速度方面都表现优异。

  4. 支持CPU和GPU训练:CatBoost既可以在CPU上运行,也可以利用GPU进行加速训练。

  5. 自动处理缺失值:CatBoost可以自动处理缺失值,无需额外的预处理步骤。

CatBoost的核心原理

CatBoost的核心原理基于梯度提升决策树(GBDT),但在处理类别特征和避免过拟合方面进行了创新。以下是一些关键的技术点:

  1. 类别特征处理

    • CatBoost引入了一个称为“均值编码”的方法,基于类别的均值计算新特征。

    • 使用一种称为“目标编码”的技术,将类别特征转化为数值特征时,通过使用目标值的平均值来减少数据泄露的风险。

    • 在训练过程中,通过使用统计信息对数据进行处理,避免直接使用目标变量进行编码。

  2. 有序提升(Ordered Boosting)

    • 为了防止数据泄露和过拟合,CatBoost在训练时对数据进行了有序处理。

    • 有序提升通过在训练过程中随机打乱数据,并确保模型在某一时刻只看到过去的数据,而不会使用未来的信息进行决策。

  3. 计算优化

    • CatBoost通过预计算和缓存的方式加速了特征的计算过程。

    • 支持CPU和GPU训练,能够在大规模数据集上表现出色。

CatBoost的基本使用

以下是一个使用CatBoost进行分类任务的基本示例,我们使用Auto MPG(Miles Per Gallon)数据集,它是一个经典的回归问题数据集,常用于机器学习和统计分析。该数据集记录了不同型号汽车的燃油效率(即每加仑燃油行驶的英里数)以及其他多个相关特征。

数据集特征:

  • mpg: 每加仑燃油行驶的英里数(目标变量)。

  • cylinders: 气缸数量,表示发动机的气缸数。

  • displacement: 发动机排量(立方英寸)。

  • horsepower: 发动机功率(马力)。

  • weight: 车辆重量(磅)。

  • acceleration: 0到60英里每小时的加速度时间(秒)。

  • model_year: 车辆生产年份。

  • origin: 车辆产地(1=美国,2=欧洲,3=日本)。

数据集前几行:

    mpg  cylinders  displacement  horsepower  weight  acceleration  model_year  origin
0  18.0          8         307.0       130.0  3504.0          12.0          70       1
1  15.0          8         350.0       165.0  3693.0          11.5          70       1
2  18.0          8         318.0       150.0  3436.0          11.0          70       1
3  16.0          8         304.0       150.0  3433.0          12.0          70       1
4  17.0          8         302.0       140.0  3449.0          10.5          70       1

代码示例:

import pandas as pd  # 导入Pandas库,用于数据处理
import numpy as np  # 导入Numpy库,用于数值计算
from sklearn.model_selection import train_test_split  # 从sklearn库导入train_test_split,用于划分数据集
from sklearn.metrics import mean_squared_error, mean_absolute_error  # 导入均方误差和平均绝对误差,用于评估模型性能
from catboost import CatBoostRegressor  # 导入CatBoost库中的CatBoostRegressor,用于回归任务
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于绘图
import seaborn as sns  # 导入Seaborn库,用于绘制统计图# 设置随机种子以便结果复现
np.random.seed(42)# 从UCI机器学习库加载Auto MPG数据集
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
column_names = ['mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration', 'model_year', 'origin']
data = pd.read_csv(url, names=column_names, na_values='?', comment='\t', sep=' ', skipinitialspace=True)# 查看数据集的前几行
print(data.head())# 处理缺失值
data = data.dropna()# 特征和目标变量
X = data.drop('mpg', axis=1)  # 特征变量
y = data['mpg']  # 目标变量# 将类别特征转换为字符串类型(CatBoost可以直接处理类别特征)
X['cylinders'] = X['cylinders'].astype(str)
X['model_year'] = X['model_year'].astype(str)
X['origin'] = X['origin'].astype(str)# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定义CatBoost回归器
model = CatBoostRegressor(iterations=1000,  # 迭代次数learning_rate=0.1,  # 学习率depth=6,  # 决策树深度loss_function='RMSE',  # 损失函数verbose=100  # 输出训练过程信息
)# 训练模型
model.fit(X_train, y_train, eval_set=(X_test, y_test), early_stopping_rounds=50)# 进行预测
y_pred = model.predict(X_test)# 评估模型性能
mse = mean_squared_error(y_test, y_pred)  # 计算均方误差
mae = mean_absolute_error(y_test, y_pred)  # 计算平均绝对误差# 打印模型的评估结果
print(f'Mean Squared Error (MSE): {mse:.4f}')
print(f'Mean Absolute Error (MAE): {mae:.4f}')# 绘制真实值与预测值的对比图
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.5)  # 绘制散点图
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], '--k')  # 绘制对角线
plt.xlabel('True Values')  # X轴标签
plt.ylabel('Predictions')  # Y轴标签
plt.title('True Values vs Predictions')  # 图标题
plt.show()# 特征重要性可视化
feature_importances = model.get_feature_importance()  # 获取特征重要性
feature_names = X.columns  # 获取特征名称plt.figure(figsize=(10, 6))
sns.barplot(x=feature_importances, y=feature_names)  # 绘制特征重要性条形图
plt.title('Feature Importances')  # 图标题
plt.show()# 输出
'''
mpg  cylinders  displacement  horsepower  weight  acceleration  \
0  18.0          8         307.0       130.0  3504.0          12.0   
1  15.0          8         350.0       165.0  3693.0          11.5   
2  18.0          8         318.0       150.0  3436.0          11.0   
3  16.0          8         304.0       150.0  3433.0          12.0   
4  17.0          8         302.0       140.0  3449.0          10.5   model_year  origin  
0          70       1  
1          70       1  
2          70       1  
3          70       1  
4          70       1  
0: learn: 7.3598113 test: 6.6405869 best: 6.6405869 (0) total: 1.7ms remaining: 1.69s
100: learn: 1.5990203 test: 2.3207830 best: 2.3207666 (94) total: 132ms remaining: 1.17s
200: learn: 1.0613606 test: 2.2319632 best: 2.2284239 (183) total: 272ms remaining: 1.08s
Stopped by overfitting detector  (50 iterations wait)bestTest = 2.21453232
bestIteration = 238Shrink model to first 239 iterations.
Mean Squared Error (MSE): 4.9042
Mean Absolute Error (MAE): 1.6381<Figure size 1000x600 with 1 Axes>
<Figure size 1000x600 with 1 Axes>
'''

Mean Squared Error (MSE): 均方误差,表示预测值与实际值之间的平均平方差异。值越小,模型性能越好,在这里MSE的值是4.9042。

Mean Absolute Error (MAE): 平均绝对误差,表示预测值与实际值之间的平均绝对差异。值越小,模型性能越好,在这里MAE的值是1.6381。

图片

  1. 散点图:图中的每个点表示一个测试样本。横坐标表示该样本的真实值(MPG),纵坐标表示模型的预测值(MPG)。

  2. 对角线:图中的黑色虚线是45度对角线,表示理想情况下的预测结果,即预测值等于真实值。

  3. 点的分布:

    • 靠近对角线:表示模型的预测值与真实值非常接近,预测准确。

    • 远离对角线:表示预测值与真实值有较大差距,预测不准确。

通过图中的点可以看到大部分点都集中在对角线附近,这表明模型的预测性能良好,但也有一些点离对角线较远,表示这些样本的预测值与真实值存在一定的差距。

图片

  1. 条形图:每个条形表示一个特征在模型中的重要性。条形越长,表示该特征对模型预测的贡献越大。

  2. 特征名称:在Y轴上列出了所有特征的名称。

  3. 特征重要性值:在X轴上显示了每个特征的相对重要性值。

从图中可以看到:

  1. model_year:在所有特征中最重要,表示汽车的生产年份对预测燃油效率有很大的影响。

  2. weight:汽车的重量是第二重要的特征,对燃油效率也有显著影响。

  3. displacement 和 horsepower:发动机的排量和功率对燃油效率也有较大的贡献。

在实例中,我们使用CatBoost处理Auto MPG数据集,其主要目的是构建一个回归模型,以预测汽车的燃油效率(即每加仑燃油行驶的英里数,MPG)。

以上内容总结自网络,如有帮助欢迎转发,我们下次再见!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43423.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是ThingsKit物联网平台?

在信息化时代的浪潮中&#xff0c;物联网&#xff08;IoT&#xff09;作为新一代信息技术的核心&#xff0c;已经逐渐渗透到我们生活的方方面面。而在这个大背景下&#xff0c;Thingskit物联网平台以其独特的技术优势和应用场景&#xff0c;成为了物联网领域的一颗璀璨明星。本…

AI和人工智能是啥关系?

AI&#xff08;人工智能&#xff09;与通用人工智能&#xff08;AGI&#xff09;是人工智能领域中的两个重要概念&#xff0c;它们在定义、技术基础以及应用领域等方面有所区别。人工智能&#xff08;AI&#xff09;&#xff0c;是指使计算机和其他机器模拟人类智能的技术&…

3.flink架构

目录 概述 概述 Flink是一个分布式的带有状态管理的计算框架&#xff0c;为了执行流应用程序&#xff0c;可以和 Hadoop YARN 、K8s 进行整合&#xff0c;当然也可以是一个 standalone 。 官方地址&#xff1a;速递 k8s 是未来的一种趋势&#xff0c;对资源管控能力强。

Windows 控制中心在哪里打开,七种方法教会你

在 Windows 操作系统中&#xff0c;控制中心的概念可能稍有些混淆&#xff0c;因为 Windows 通常使用“控制面板”这一术语来指代用于配置系统设置和更改硬件及软件设置的中心区域。 不过&#xff0c;随着 Windows 的更新&#xff0c;微软也在逐步将一些设置迁移到“设置”应用…

关于Linux的操作作业!24道题

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

js如何要让一个对象继承另一个对象的原型属性和方法

js如何要让一个对象继承另一个对象的原型属性和方法 1、使用 Object.create() const parent {greet: function() {console.log("Hello from parent!");} };const child Object.create(parent); child.greet(); // 输出: Hello from parent!2、使用 proto 属性 …

【算法】贪婪算法介绍及实现方法

贪婪算法简介 贪婪算法&#xff08;Greedy Algorithm&#xff09;是一种在每一步选择中都采取当前状态下最好或最优&#xff08;即最有利&#xff09;的选择&#xff0c;从而希望导致结果是全局最好或最优的算法。贪婪算法通常用于解决优化问题&#xff0c;如最小化成本、最大…

Tomcat打破双亲委派模型的方式

文章目录 1、前言2、标准的双亲委派模型3、Tomcat的类加载器架构4、Tomcat打破双亲委派模型的方式5、总结 1、前言 双亲委派模型是一种类加载机制&#xff0c;它确保了类加载器层次结构中的父加载器先于子加载器尝试加载类。这种机制有助于防止类的重复加载和类之间的不兼容。…

MySQL数据库基本操作-DDL和DML

1. DDL解释 DDL(Data Definition Language)&#xff0c;数据定义语言&#xff0c;该语言部分包括以下内容&#xff1a; 对数据库的常用操作对表结构的常用操作修改表结构 2. 对数据库的常用操作 功能SQL查看所有的数据库show databases&#xff1b;查看有印象的数据库show d…

16 - Python语言进阶

Python语言进阶 数据结构和算法 算法&#xff1a;解决问题的方法和步骤 评价算法的好坏&#xff1a;渐近时间复杂度和渐近空间复杂度。 渐近时间复杂度的大O标记&#xff1a; - 常量时间复杂度 - 布隆过滤器 / 哈希存储 - 对数时间复杂度 - 折半查找&#xff08;二分查找&am…

关于TCP的三次握手流程

三次握手流程 第一次握手&#xff1a;客户端向服务端发起建立连接请求&#xff0c;客户端会随机生成一个起始序列号x&#xff0c;客户端向服务端发送的字段包含标志位SYN1&#xff0c;序列号segx。第一次握手后客户端的状态为SYN-SENT。此时服务端的状态为LISTEN 第二次握手&…

使用耳机壳UV树脂制作私模定制耳塞的价格如何呢?

使用耳机壳UV树脂制作私模定制耳塞的价格如何呢&#xff1f; 耳机壳UV树脂制作私模定制耳塞的价格因多个因素而异&#xff0c;如材料、工艺、设计、定制复杂度等。 根据我目前所了解到的信息&#xff0c;使用UV树脂制作私模定制耳塞的价格可能在数百元至数千元不等。具体价格…

LVS+Nginx高可用集群---Nginx进阶与实战

1.Nginx中解决跨域问题 两个站点的域名不一样&#xff0c;就会有一个跨域问题。 跨域问题&#xff1a;了解同源策略&#xff1a;协议&#xff0c;域名&#xff0c;端口号都相同&#xff0c;只要有一个不相同那么就是非同源。 CORS全称Cross-Origin Resource Sharing&#xff…

大模型知识大全1-基础知识【大模型】

文章目录 大模型简介以后的介绍流程基础知识训练流程介绍pre-train对齐和指令微调规模拓展涌现能力 系统学习大模型的记录https://github.com/LLMBook-zh/LLMBook-zh.github.io 大模型简介 历史我就不写了&#xff0c;简单说说大模型的应用和特点。人类使用大模型其实分为两个…

linux高级编程(OSI/UDP(用户数据报))

OSI七层模型&#xff1a; OSI 模型 --> 开放系统互联模型 --> 分为7层&#xff1a; 理想模型 --> 尚未实现 1.应用层 QQ 应用程序的接口 2.表示层 加密解密 gzip 将接收的数据进行解释&#xff…

【shell】—双引号引用变量

文章目录 一、举例—单、双引号引用变量的结果差异二、使用双引号引用变量的场景1、使用双引号—可以防止字符串被分割2、使用双引号—特殊字符变为普通字符3、使用双引号—保存原始命令的输出格式4、使用双引号—具有强约束的单引号变为普通单引号字符5、注意 一、举例—单、双…

挑战杯 opencv python 深度学习垃圾图像分类系统

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; opencv python 深度学习垃圾分类系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 这是一个较为新颖的竞…

昇思25天学习打卡营第13天|应用实践之ResNet50迁移学习

基本介绍 今日的应用实践的模型是计算机实践领域中十分出名的模型----ResNet模型。ResNet是一种残差网络结构&#xff0c;它通过引入“残差学习”的概念来解决随着网络深度增加时训练困难的问题&#xff0c;从而能够训练更深的网络结构。现很多网络极深的模型或多或少都受此影响…

数据链路层(超详细)

引言 数据链路层是计算机网络协议栈中的第二层&#xff0c;位于物理层之上&#xff0c;负责在相邻节点之间的可靠数据传输。数据链路层使用的信道主要有两种类型&#xff1a;点对点信道和广播信道。点对点信道是指一对一的通信方式&#xff0c;而广播信道则是一对多的通信方式…