【机器学习】CatBoost 模型实践:回归与分类的全流程解析

一. 引言

本篇博客首发于掘金 https://juejin.cn/post/7441027173430018067。
PS:转载自己的文章也算原创吧。

在机器学习领域,CatBoost 是一款强大的梯度提升框架,特别适合处理带有类别特征的数据。本篇博客以脱敏后的保险数据集为例,展示如何利用 CatBoost 完成分类和回归任务,并以可视化的方式解析特征重要性与结果。

我们将完成以下任务:

  1. 回归任务:预测保险索赔金额。
  2. 分类任务:判断保险案件是否需要调查。
  3. 可视化分析:利用散点图与分割线展示结果。

二. CatBoost 模型简介

CatBoost 是由俄罗斯搜索巨头 Yandex 于 2017 年开源的机器学习库,其名称来源于 “Category” 和 “Boosting” 的组合,旨在高效处理类别特征的梯度提升算法。与其他模型(如 XGBoost 和 LightGBM)相比,CatBoost 具有以下优势:

  • 支持类别特征:无需对类别特征进行独热编码,直接处理类别数据,避免数据膨胀。
  • 对缺失值的鲁棒性:无需特殊预处理即可直接处理缺失值。
  • 防止过拟合:内置多种正则化手段,减少梯度偏差和预测偏移,提高模型的准确性和泛化能力。
  • 对称树结构:采用对称决策树(Oblivious Trees),在每个层级使用相同的特征和分割点,提升训练和预测效率。

三. 实战项目环境与数据准备

本项目使用了脱敏后的保险数据集,包含以下特征:

  • 类别特征:险种代码、出险原因、医疗责任类别等。
  • 数值特征:基本保额、索赔金额等。
  • 标签:是否需要调查(分类任务)。

所有数据均已脱敏,支持迁移至其他表格数据集。

因为不好分享,所以后续第七节补充了一个基于sklearn "California Housing"数据集的流程代码与说明。


四. 回归任务:预测保险索赔金额

数据预处理

在回归任务中,我们根据特征预测索赔金额。以下是数据清洗与预处理的关键步骤:

  1. 过滤无效数据:移除缺失或非法值的记录。
  2. 特征转换:将类别特征转为字符串类型。
  3. 分割数据集:按 80% 和 20% 的比例划分训练集与测试集。

4.1 模型训练与评估

我们使用 CatBoost 进行回归建模,模型参数包括:

  • 学习率:0.02
  • 深度:8
  • 迭代次数:10,000(支持提前停止)

以下是模型的关键代码:

from catboost import CatBoostRegressor# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=10000,learning_rate=0.02,depth=8,eval_metric='RMSE',early_stopping_rounds=1500,random_seed=42
)# 训练模型
cat_regressor.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

4.2 特征重要性分析

特征重要性是衡量特征对模型预测贡献程度的指标,可以帮助我们更好地理解模型。

# 获取特征重要性
feature_importances = cat_regressor.get_feature_importance()
feature_names = X_train.columns# 可视化特征重要性
import matplotlib.pyplot as plt
importance_df = pd.DataFrame({'Feature': feature_names,'Importance': feature_importances
}).sort_values(by='Importance', ascending=True)plt.figure(figsize=(10, 6))
plt.barh(importance_df['Feature'], importance_df['Importance'], color='salmon')
plt.xlabel('特征重要性')
plt.ylabel('特征名称')
plt.title('CatBoost 特征重要性分析')
plt.show()

结果展示
在这里插入图片描述


4.3 模型评估

我们可以均方误差 (MSE) 以及 平均绝对误差 (MAE) 来评估模型在测试集上的回归性能,同时展示模型的学习曲线:

# 获取训练和测试集的 RMSE
evals_result = cat_regressor.get_evals_result()
train_rmse = evals_result['learn']['RMSE']
test_rmse = evals_result['validation']['RMSE']# 绘制 RMSE 曲线
plt.figure(figsize=(10, 6))
plt.plot(train_rmse, label='训练集 RMSE')
plt.plot(test_rmse, label='测试集 RMSE')
plt.title('训练与测试集的 RMSE 学习曲线')
plt.xlabel('迭代次数')
plt.ylabel('RMSE')
plt.legend()
plt.show()

五. 分类任务:判别是否调查

5.1 数据标注与模型选择

分类任务以 是否调查 作为标签(1 表示需要调查,0 表示无需调查),特征包括所有数值和类别字段。

为了完成分类任务,我们选用 CatBoostClassifier。模型参数类似于回归模型,分类评估指标包括准确率、混淆矩阵和分类报告。


5.2 训练结果与模型评估

训练结果显示,分类准确率达 94.0%。以下是模型的分类报告:

分类报告 (训练集):precision    recall  f1-score   support0       0.96      0.98      0.97     130871       0.74      0.57      0.64      1354accuracy                           0.94     14441macro avg       0.85      0.77      0.80     14441
weighted avg       0.94      0.94      0.94     14441
5.3 代码示例
from catboost import CatBoostClassifier# 初始化分类器
cat_classifier = CatBoostClassifier(iterations=1000,learning_rate=0.02,depth=8,eval_metric='Accuracy',early_stopping_rounds=150,random_seed=42
)# 模型训练
cat_classifier.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

六. 可视化分析

为更直观地理解模型,我们利用散点图和分割线对预测结果进行展示:

  • 散点图:展示实际金额与预测金额的分布。
  • 分割线:通过 KMeans 聚类划分四个金额档次。

以下代码生成散点图与分割线:

# 使用 KMeans 聚类生成分割线
from sklearn.cluster import KMeanskmeans = KMeans(n_clusters=4, random_state=42)
df['cluster'] = kmeans.fit_predict(df[['预测金额']])# 绘制散点图
plt.figure(figsize=(12, 8))
plt.scatter(df['预测金额'], df['是否调查'], c=df['cluster'], cmap='tab10')
plt.title("预测金额与是否调查的散点图")
plt.xlabel("预测金额")
plt.ylabel("是否调查")
plt.colorbar(label='Cluster')
plt.show()

散点图展示

在这里插入图片描述


七. 补充学习

7.1 基础数据集

California Housing 数据集包含加利福尼亚州 20,640 个街区的人口、住房和收入信息。目标是预测每个街区的房价中位数 MedHouseVal

数据特征

  1. MedInc:街区的收入中位数。
  2. HouseAge:街区住房的平均年龄。
  3. AveRooms:每个街区的平均房间数。
  4. AveBedrms:每个街区的平均卧室数。
  5. Population:街区的总人口。
  6. AveOccup:每户的平均人数。
  7. Latitude:街区的纬度。
  8. Longitude:街区的经度。

7.2 实践步骤

7.2.1 导入数据与预处理

我们使用 Scikit-learn 加载数据并进行预处理。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd# 加载 California Housing 数据集
data = fetch_california_housing(as_frame=True)
df = data.frame# 特征和目标变量
X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

训练集大小: (16512, 8), 测试集大小: (4128, 8)


7.2.2 训练 CatBoost 回归模型

使用 CatBoost 对房价进行预测。

from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=1000,learning_rate=0.1,depth=6,eval_metric="RMSE",random_seed=42,verbose=100
)# 模型训练
cat_regressor.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=100, early_stopping_rounds=50)# 模型预测
y_pred_train = cat_regressor.predict(X_train)
y_pred_test = cat_regressor.predict(X_test)# 模型评估
mse_train = mean_squared_error(y_train, y_pred_train)
mse_test = mean_squared_error(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)print(f"训练集均方误差 (MSE): {mse_train}")
print(f"测试集均方误差 (MSE): {mse_test}")
print(f"测试集平均绝对误差 (MAE): {mae_test}")

输出如下:

0:	learn: 1.0934740	test: 1.0841841	best: 1.0841841 (0)	total: 1.24s	remaining: 20m 38s
100:	learn: 0.4867395	test: 0.5154868	best: 0.5154868 (100)	total: 1.54s	remaining: 13.7s
200:	learn: 0.4320149	test: 0.4798269	best: 0.4798269 (200)	total: 1.8s	remaining: 7.18s
300:	learn: 0.4020581	test: 0.4657293	best: 0.4657293 (300)	total: 2.07s	remaining: 4.8s
400:	learn: 0.3803801	test: 0.4582868	best: 0.4582868 (400)	total: 2.35s	remaining: 3.5s
500:	learn: 0.3633580	test: 0.4534430	best: 0.4534430 (500)	total: 2.61s	remaining: 2.6s
600:	learn: 0.3488402	test: 0.4491723	best: 0.4491723 (600)	total: 2.89s	remaining: 1.92s
700:	learn: 0.3358611	test: 0.4461323	best: 0.4461323 (700)	total: 3.17s	remaining: 1.35s
800:	learn: 0.3234759	test: 0.4431320	best: 0.4431320 (800)	total: 3.44s	remaining: 854ms
900:	learn: 0.3126821	test: 0.4403978	best: 0.4403978 (900)	total: 3.71s	remaining: 407ms
999:	learn: 0.3025414	test: 0.4386906	best: 0.4386902 (998)	total: 3.97s	remaining: 0usbestTest = 0.438690174
bestIteration = 998Shrink model to first 999 iterations.
训练集均方误差 (MSE): 0.09158491090576551
测试集均方误差 (MSE): 0.19244906768098075
测试集平均绝对误差 (MAE): 0.28701415230111493

7.2.3 可视化预测结果

展示预测值与实际值的对比,以及模型的特征重要性。

实际值与预测值对比
import matplotlib.pyplot as plt# 对比测试集的预测值和实际值
plt.figure(figsize=(10, 6))
plt.scatter(range(len(y_test)), y_test, color="blue", label="真实值", alpha=0.6)
plt.scatter(range(len(y_pred_test)), y_pred_test, color="red", label="预测值", alpha=0.6)
plt.title("真实房价与预测房价对比")
plt.xlabel("样本索引")
plt.ylabel("房价中位数")
plt.legend()
plt.show()

特征重要性分析
# 特征重要性可视化
feature_importances = cat_regressor.get_feature_importance()
feature_names = data.feature_namesplt.figure(figsize=(10, 6))
plt.barh(feature_names, feature_importances, color="skyblue")
plt.title("CatBoost 特征重要性")
plt.xlabel("重要性得分")
plt.ylabel("特征名称")
plt.show()

在这里插入图片描述


7.3 数据结果

  • 模型评估结果:
    • 训练集均方误差 (MSE): 0.09158491090576551
    • 测试集均方误差 (MSE): 0.19244906768098075
    • 测试集平均绝对误差 (MAE): 0.28701415230111493
  • 特征重要性解读:
    根据特征重要性分析,MedInc(收入中位数)对预测房价的影响最大,而经纬度特征(Latitude 和 Longitude)也提供了显著的信息。

八. 总结

通过本项目,我们完成了基于 CatBoost 的回归与分类建模,并展示了预测结果的可视化。CatBoost 的强大功能和易用性使其在处理类别特征和缺失值的数据中表现优异。

希望本篇博客能为大家带来启发,助力实际项目的落地实现。如果对您有所帮助,也欢迎点赞与分享😊。

源码已上传到:https://github.com/YYForReal/ML-DL-RL-Learning/blob/main/ML-Learning/Catboost/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62492.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

lua download

https://www.lua.org/ https://www.lua.org/versions.html#5.4

Java 中的 ResponseBodyEmitter:详解与实战

Java 中的 ResponseBodyEmitter:详解与实战 前言 在开发高并发应用或处理长时间任务时,服务端需要向客户端实时推送数据,而不是一次性将所有结果返回。Spring 提供了一种优雅的解决方案:ResponseBodyEmitter。它适用于需要逐步发…

基于大数据python 房屋价格数据分析预测可视化系统(源码+LW+部署讲解+数据库+ppt)

!!!!!!!!! 很对人不知道选题怎么选 不清楚自己适合做哪块内容 都可以免费来问我 避免后期給自己答辩找麻烦 增加难度(部分学校只有一次答辩机会 没弄好就延迟…

pikachu文件上传漏洞通关详解

声明:文章只是起演示作用,所有涉及的网站和内容,仅供大家学习交流,如有任何违法行为,均和本人无关,切勿触碰法律底线 目录 概念:什么是文件上传漏洞一、客户端check二、MIME type三、getimagesi…

lua-cjson 例子

apt install -y lua-cjson 安装 编辑 tmp.lua cjson require "cjson" p 666 d "23.42" payload{"d":[{"pres":..(p)..,"temp":"..(d).."}]} print("payload " .. payload) j cjson.decode(payloa…

Monitor 显示器软件开发设计入门二

基础篇--显示驱动方案输出接口介绍 写在前面:首先申明,这篇文章是写给那些初入显示器软件行业的入门者,或是对显示器没有基本知识的小白人员。如您是行业大咖大神,可以绕行,可看后期进阶文章。 上篇介绍了输入接口及相…

像素流送api ue多人访问需要什么显卡服务器

关于像素流送UE推流,在之前的文章里其实小芹和大家聊过很多,不过今天偶然搜索发现还是有很多小伙伴,在搜索像素流送相关的问题,搜索引擎给的提示有这些。当然这些都是比较短的词汇,可能每个人真正遇到的问题和想获取的…

【21-30期】Java技术深度剖析:从分库分表到微服务的核心问题解析

🚀 作者 :“码上有前” 🚀 文章简介 :Java 🚀 欢迎小伙伴们 点赞👍、收藏⭐、留言💬 文章题目:Java技术深度剖析:从分库分表到微服务的核心问题解析 摘要: 本…

C#里怎么样List类进行拷贝?

演示的代码: using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; namespace ConsoleApp1 { internal class Program { static void Main(string[] args) …

Flutter 权限申请

这篇文章是基于permission_handler 10.2.0版本写的 前言 在App开发过程中我们经常要用到各种权限,我是用的是permission_handler包来实现权限控制的。 pub地址:https://pub.dev/packages/permission_handler permission_handler 权限列表 变量 Androi…

C#:时间与时间戳的转换

1、将 DateTime 转换为 Unix 时间戳(秒) public static long DateTimeToUnixTimestamp(DateTime dateTime) {// 定义UTC纪元时间DateTime epochStart new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);// 计算从UTC纪元时间到指定时间的总秒数Tim…

【Spring】Spring IOCDI:架构旋律中的“依赖交响”与“控制华章”

前言 🌟🌟本期讲解关于Spring IOC&DI的详细介绍~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 🎆那么…

在CentOS7上更换为阿里云源

在CentOS 7上更换为阿里云YUM源可以通过以下步骤进行: 备份当前的YUM源配置文件 sudo mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup 下载阿里云的YUM源配置文件 sudo curl -o /etc/yum.repos.d/CentOS-Base.repo http://mirr…

【小白学机器学习34】基础统计2种方法:用numpy的方法np().mean()等进行统计,pd.DataFrame.groupby() 分组统计

目录 1 用 numpy 快速求数组的各种统计量:mean, var, std 1.1 数据准备 1.2 直接用np的公式求解 1.3 注意问题 1.4 用print() 输出内容,显示效果 2 为了验证公式的背后的理解,下面是详细的展开公式的求法 2.1 均值mean的详细 2.2 方差…

Oracle 19C Data Guard 单实例1+1部署(Duplicate方式)

环境描述 项目主库备库操作系统CentOS 7.9CentOS 7.9数据库版本Oracle 19.3.0.0Oracle 19.3.0.0ORACLE_UNQNAMEhishisdgIP地址10.172.1.10110.172.1.102Hostnamehisdb01hisdb02SIDhishisdb_namehishisdb_unique_namehishisdg 说明 主库和备库建议采用相同服务器配置。主库和…

ubuntu20配置mysql注意事项

目录 一、mysql安装 二、初始化配置密码 三、配置文件的位置 四、常用的mysql命令 五、踩坑以及解决方法 一、mysql安装 1.更新apt源 sudo apt update 2.安装mysql服务 sudo apt-get install mysql-server 3.初始化配置 sudo mysql_secure_installation 4.配置项 VALI…

开展网络安全成熟度评估:业务分析师的工具和技术

想象一下,您坐在飞机驾驶舱内。起飞前,您需要确保所有系统(从发动机到导航工具)均正常运行。现在,将您的业务视为飞机,将网络安全视为飞行前必须检查的系统。就像飞行员依赖检查表一样,业务分析师使用网络安全成熟度评估来评估组织对网络威胁的准备程度。这些评估可帮助…

MySql(面试题理解B+树原理 实操加大白话)

数据的定位 通过磁道和扇区定位到数据的位置 扇区为512字节 黄色地方数据位置为2磁道3扇区 黑色地方数据位置为1磁道1扇区 通过磁道和扇区还有偏移量定位到数据的位置 比如这里有一张表 由id、name、no、address组成id为主键 列占有大小(字节) id int …

目标检测,图像分割,超分辨率重建

目标检测和图像分割 目标检测和图像分割是计算机视觉中的两个不同任务,它们的输出形式也有所不同。下面我将分别介绍这两个任务的输出。图像分割又可以分为:语义分割、实例分割、全景分割。 语义分割(Semantic Segmentation)&…

K8s内存溢出问题剖析:排查与解决方案

文章目录 一、背景二、排查方案:1. 可能是数据量超出了限制的大小,检查数据目录大小2. 查看是否是内存溢出2.1 排查数据量(查看数据目录大小是否超过limit限制)2.2 查看pod详情发现问题 三、解决过程 一、背景 做redis压测过程中…