Python28-10 LightGBM对乳腺癌数据集进行分类

图片

LightGBM(Light Gradient Boosting Machine)是一个梯度提升框架,由微软开发。它用于机器学习中的分类、回归和排序等任务,特别适合处理大规模数据和高维特征。LightGBM的核心是梯度提升决策树(GBDT)算法,但它在此基础上做了多种优化,使其在速度和内存使用方面优于传统的GBDT实现。LightGBM的数学原理基于梯度提升决策树(Gradient Boosting Decision Trees, GBDT),而GBDT本身是一个集成学习算法。

1. 梯度提升决策树(GBDT)

GBDT是通过逐步构建多个决策树,并通过每棵树来纠正前一棵树的错误。具体而言:

  • 初始模型:通常从一个简单的模型开始,比如预测所有样本的平均值。

  • 残差计算:计算当前模型的残差(即预测值与真实值之间的差异)。

  • 决策树训练:训练一棵新的决策树来拟合残差。

  • 模型更新:将新决策树的预测值加到当前模型中。

每次迭代的公式为:

其中:

  • 是第 次迭代的模型。

  • 是学习率,控制每棵树对最终模型的贡献。

  •  是第 棵决策树的预测值。

2. LightGBM的改进

LightGBM在GBDT的基础上做了许多优化,以提高训练速度和效率:

(1) 基于直方图的算法

传统的GBDT每次分裂节点时都要对所有特征值进行排序,时间复杂度较高。LightGBM通过将连续特征值离散化成直方图的形式来加速分裂过程。具体步骤如下:

  • 构建直方图:将特征值划分成多个bin(桶)。

  • 统计每个bin中的样本数量和目标值的总和

  • 选择最佳分裂点:在直方图的基础上计算各个分裂点的增益,选择增益最大的分裂点。

(2) 带深度限制的叶子增长策略

LightGBM采用了一个叫做“Leaf-wise”的策略,而不是传统的“Level-wise”策略:

  • Level-wise:按层次生长,每层增加一层深度。

  • Leaf-wise:每次选择增益最大的叶子进行分裂。

Leaf-wise策略能够更好地减少损失,但容易导致不平衡的树结构。为此,LightGBM引入了最大深度限制,以防止过拟合。

3. 数学公式与损失函数

LightGBM的目标是最小化损失函数,一般情况下使用平方误差或交叉熵损失。以平方误差为例,损失函数为:

其中:

  • 是第 个样本的真实值。

  • 是模型的预测值。

在每次迭代中,通过计算损失函数的负梯度作为残差,来训练下一棵决策树。这一过程可以看作是用梯度下降法来优化模型参数。

LightGBM通过多种优化方法(如基于直方图的分裂、Leaf-wise的增长策略等),在保留GBDT强大预测能力的同时,提高了训练速度和效率。这些优化方法的数学原理相对复杂,但总体思路仍然是基于梯度提升决策树的框架。

LightGBM的Python实例

我们选择乳腺癌数据集,乳腺癌数据集包含了569个样本,每个样本都有30个特征。这些特征主要描述了细胞核的特性(如半径、质地、周长、面积等),目标变量是样本是否为恶性肿瘤(标记为1)或良性肿瘤(标记为0)。

图片

import lightgbm as lgb  # 导入LightGBM库
import pandas as pd  # 导入Pandas库,用于数据处理
import numpy as np  # 导入Numpy库,用于数值计算
from sklearn.datasets import load_breast_cancer  # 从sklearn库导入乳腺癌数据集
from sklearn.model_selection import train_test_split  # 导入train_test_split,用于划分数据集
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix  # 导入评估指标
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于绘图
import seaborn as sns  # 导入Seaborn库,用于绘制热力图# 加载乳腺癌数据集
data = load_breast_cancer()
# 将数据集转换为Pandas DataFrame格式
X = pd.DataFrame(data.data, columns=data.feature_names)
# 将目标变量转换为Pandas Series格式
y = pd.Series(data.target)# 划分训练集和测试集,80%用于训练,20%用于测试
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建LightGBM数据集,用于训练
train_data = lgb.Dataset(X_train, label=y_train)
# 创建LightGBM数据集,用于测试
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)# 设置LightGBM模型参数
params = {'objective': 'binary',  # 二分类任务'metric': 'auc',  # 评估指标使用AUC'boosting_type': 'gbdt',  # 使用梯度提升决策树'num_leaves': 31,  # 每棵树的最大叶子数'learning_rate': 0.05,  # 学习率'feature_fraction': 0.9  # 每次迭代使用90%的特征
}# 设置早停回调,若验证集上AUC在10轮内没有提升,则停止训练
callbacks = [lgb.early_stopping(stopping_rounds=10)]# 训练LightGBM模型
bst = lgb.train(params, train_data, num_boost_round=100, valid_sets=[test_data], callbacks=callbacks)# 使用训练好的模型进行预测
y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)
# 将预测概率转换为二进制分类结果(0或1)
y_pred_binary = (y_pred > 0.5).astype(int)# 评估模型的准确率
accuracy = accuracy_score(y_test, y_pred_binary)
# 评估模型的ROC AUC值
roc_auc = roc_auc_score(y_test, y_pred)# 打印模型的准确率和ROC AUC值
print(f'Accuracy: {accuracy:.4f}')
print(f'ROC AUC: {roc_auc:.4f}')# 计算混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred_binary)
# 使用Seaborn绘制混淆矩阵的热力图
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=data.target_names, yticklabels=data.target_names)
# 设置混淆矩阵的标签
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
# 显示混淆矩阵图
plt.show()# 绘制特征重要性图,按分裂次数排序
lgb.plot_importance(bst, max_num_features=10, importance_type='split')
plt.title('Feature Importance (split)')
# 显示特征重要性图
plt.show()# 绘制特征重要性图,按信息增益排序
lgb.plot_importance(bst, max_num_features=10, importance_type='gain')
plt.title('Feature Importance (gain)')
# 显示特征重要性图
plt.show()# 输出:
'''
Accuracy: 0.6228
ROC AUC: 0.9689
'''

可视化结果:

图片

  • 混淆矩阵的左上角(第一个象限)显示模型预测为“malignant”且实际也是“malignant”的样本数。在这个矩阵中,该值为0,表示没有正确预测出恶性肿瘤。

  • 右上角(第二个象限)显示模型预测为“benign”但实际是“malignant”的样本数。该值为43,表示有43个恶性肿瘤被错误分类为良性肿瘤。

  • 左下角(第三个象限)显示模型预测为“malignant”但实际是“benign”的样本数。该值为0,表示没有良性肿瘤被错误分类为恶性肿瘤。

  • 右下角(第四个象限)显示模型预测为“benign”且实际也是“benign”的样本数。该值为71,表示有71个良性肿瘤被正确分类为良性肿瘤。

从这个混淆矩阵来看,模型的性能非常不理想。模型无法正确分类恶性肿瘤样本,而能正确分类的只是良性肿瘤。可能需要调整模型参数、增加特征工程或者尝试不同的模型来提高分类性能。

图片

从上图中可知,模型在进行分类时对不同特征的依赖程度如下:

特征重要性最高的前三个特征(按分裂次数)

  1. mean_radius(平均半径)

    • 分裂次数最多,重要性最高。

    • 在决策树中,它是被最频繁用来分裂数据的特征。

  2. worst_texture(最差质地)

    • 分裂次数为3,说明这个特征也是模型中重要的特征之一。

  3. mean_concave_points(平均凹点数)

    • 分裂次数为3,表明这个特征对模型的分类也有显著影响。

图片

根据特征重要性图(按信息增益排序)来看,模型在进行分类时对不同特征的依赖程度如下:

特征重要性最高的前三个特征(按信息增益排序)

  1. mean_concave_points(平均凹点数)

    • 信息增益最高,为588.950。

    • 这是模型中最重要的特征,对模型决策贡献最大。

  2. worst_radius(最差半径)

    • 信息增益为48.979,显示其在模型中的重要性。

    • 这是第二重要的特征。

  3. worst_concave_points(最差凹点数)

    • 信息增益为29.814,表明该特征对模型分类有显著影响。

两种特征重要性的对比

分裂次数(split)重要性

分裂次数(split)重要性显示了每个特征在决策树中被用于分裂的频率。它的用途和意义包括:

  1. 模型解释性:通过查看哪些特征被频繁使用,可以了解模型在决策过程中更依赖哪些特征。

  2. 特征选择:高分裂次数的特征通常对模型性能有重要影响,可以作为重要特征保留。

  3. 特征工程:可以基于这些特征进行进一步的数据处理和特征工程,提高模型的性能。

信息增益(gain)重要性

信息增益(gain)重要性显示了每个特征在决策树中分裂时带来的信息增益。它的用途和意义包括:

  1. 模型优化:高信息增益的特征对模型的预测性能贡献更大,优化这些特征可以显著提升模型性能。

  2. 特征选择:信息增益高的特征可以作为模型的核心特征进行保留,而信息增益低的特征可能对模型贡献较小,可以考虑去除。

  3. 模型解释性:帮助理解哪些特征对模型的决策影响最大,尤其是在关键节点上的决策。

改进后的代码:

import lightgbm as lgb  # 导入LightGBM库
import pandas as pd  # 导入Pandas库,用于数据处理
import numpy as np  # 导入Numpy库,用于数值计算
from sklearn.datasets import load_breast_cancer  # 从sklearn库导入乳腺癌数据集
from sklearn.model_selection import train_test_split, GridSearchCV  # 导入train_test_split和GridSearchCV,用于划分数据集和交叉验证
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix  # 导入评估指标
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于绘图
import seaborn as sns  # 导入Seaborn库,用于绘制热力图# 加载乳腺癌数据集
data = load_breast_cancer()
# 将数据集转换为Pandas DataFrame格式
X = pd.DataFrame(data.data, columns=data.feature_names)
# 将目标变量转换为Pandas Series格式
y = pd.Series(data.target)# 划分训练集和测试集,80%用于训练,20%用于测试
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建LightGBM数据集,用于训练
train_data = lgb.Dataset(X_train, label=y_train)
# 创建LightGBM数据集,用于测试
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)# 设置初始参数
params = {'objective': 'binary',  # 二分类任务'metric': 'auc',  # 评估指标使用AUC'boosting_type': 'gbdt',  # 使用梯度提升决策树'num_leaves': 31,  # 每棵树的最大叶子数'learning_rate': 0.01,  # 降低学习率'feature_fraction': 0.8,  # 每次分裂使用80%的特征'bagging_fraction': 0.8,  # 每次训练使用80%的数据'bagging_freq': 5,  # 每5次迭代进行一次重采样'lambda_l1': 0.1,  # L1正则化'lambda_l2': 0.1  # L2正则化
}# 设置早停回调,若验证集上AUC在50轮内没有提升,则停止训练
callbacks = [lgb.early_stopping(stopping_rounds=50)]# 使用交叉验证来选择最佳参数
grid_params = {'num_leaves': [31, 41, 51],  # 叶子节点数的候选值'learning_rate': [0.01, 0.05, 0.1],  # 学习率的候选值'n_estimators': [100, 200, 500]  # 迭代次数的候选值
}# 创建LGBMClassifier对象
gbm = lgb.LGBMClassifier(**params)# 使用GridSearchCV进行交叉验证
grid = GridSearchCV(gbm, grid_params, scoring='roc_auc', cv=5)
# 训练模型并寻找最佳参数
grid.fit(X_train, y_train)# 输出最佳参数
print(f'Best parameters found by grid search are: {grid.best_params_}')# 使用最佳参数训练模型
best_params = grid.best_params_
# 将最佳参数与初始参数合并,用于训练模型
bst = lgb.train({**params, **best_params}, train_data, num_boost_round=500, valid_sets=[test_data], callbacks=callbacks)# 预测
y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)
# 将预测概率转换为二进制分类结果(0或1)
y_pred_binary = (y_pred > 0.5).astype(int)# 评估模型的准确率
accuracy = accuracy_score(y_test, y_pred_binary)
# 评估模型的ROC AUC值
roc_auc = roc_auc_score(y_test, y_pred)# 打印模型的准确率和ROC AUC值
print(f'Accuracy: {accuracy:.4f}')
print(f'ROC AUC: {roc_auc:.4f}')# 计算混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred_binary)
# 使用Seaborn绘制混淆矩阵的热力图
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=data.target_names, yticklabels=data.target_names)
# 设置混淆矩阵的标签
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
# 显示混淆矩阵图
plt.show()# 绘制特征重要性图,按分裂次数排序
lgb.plot_importance(bst, max_num_features=10, importance_type='split')
plt.title('Feature Importance (split)')
# 显示特征重要性图
plt.show()# 绘制特征重要性图,按信息增益排序
lgb.plot_importance(bst, max_num_features=10, importance_type='gain')
plt.title('Feature Importance (gain)')
# 显示特征重要性图
plt.show()# 输出
'''
Accuracy: 0.9737
ROC AUC: 0.9951
'''

可视化输出结果:

图片

此时模型已经能较为准确地对测试集中的114个细胞样本进行分类。模型正确地将41个恶性肿瘤分类为恶性肿瘤。右上角2表示模型将2个恶性肿瘤错误地分类为良性肿瘤。左下角1表示模型将1个良性肿瘤错误地分类为恶性肿瘤。右下角70表示模型正确地将70个良性肿瘤分类为良性肿瘤。

图片

图片

对比与分析

  • 模型表现

    • 新模型的混淆矩阵显示分类性能显著提升。准确率、精确率、召回率和F1分数都比之前有明显改善。

    • 新模型的准确率为97.37%,比之前的模型(62.3%)高得多,说明调整参数后的模型性能更好。

  • 特征重要性

    • 按分裂次数:新的特征重要性图显示了一些特征,如worst_texturemean_texture被频繁用于分裂,表明这些特征对模型决策有重要影响。

    • 按信息增益:信息增益高的特征如worst_concave_pointsmean_concave_points对模型决策贡献巨大,表明这些特征在关键节点上提供了大量信息。

以上内容总结自网络,如有帮助欢迎转发,我们下次再见!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/43355.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚拟现实3d场景漫游体验实现了“所见即所得”

如今,从实体店铺到工厂企业,再到政府单位,各行各业都已纷纷加入VR数字化升级的行列,相比传统的2D商品展示,三维交互展示成为商企客户交流的主流方式。产品展示、服务介绍、考察洽谈等都可以通过在3D虚拟场景网站中真实…

7月学术会议:7月可投的EI国际会议

随着科技的迅猛发展,学术交流与研讨成为了推动科研进步的重要途径。进入7月,众多高质量的EI国际会议纷纷拉开帷幕,为全球的科研工作者提供了一个展示研究成果、交流学术思想的平台。以下,我们将详细介绍一些在7月可投的EI国际会议…

Chromium编译指南2024 Linux篇-安装官方工具depot_tools(二)

1.引言 在上一节中,我们已经完成了 Git 的安装,并了解了其在 Chromium 编译过程中的重要性。接下来,我们将继续进行环境的配置,首先是安装和配置 Chromium 编译所需的重要工具——depot_tools。 depot_tools 是一组用于获取、管…

你最近想通了什么事情?这10条职场经验帮助你活得更通透

1别总当老好人 记得刚步入职场那会儿,我简直是“老好人”的代名词。 无论是同事的额外任务,还是朋友的小忙,我总是二话不说就接下来,结果自己累得半死,换来的却是别人的理所当然和偶尔的忽视。 直到有一次&#xff…

云计算【第一阶段(27)】DHCP原理与配置以及FTP的介绍

一、DHCP工作原理 1.1、DHCP概念 动态主机配置协议 DHCP(Dynamic Host Configuration Protocol,动态主机配置协议,该协议允许服务器向客户端动态分配 IP 地址和配置信息。 DHCP协议支持C/S(客户端/服务器)结构&…

break 和 continue 的区别与用法

break 和 continue 的区别与用法 1、break 语句2、continue 语句3、总结 💖The Begin💖点点关注,收藏不迷路💖 在JAVA中,break 和 continue 是两种常用的控制流语句,它们主要用于在循环结构中改变程序的执行…

Nacos 进阶篇---集群:选举心跳健康检查劳动者(九)

一、引言 本章将是我们第二阶段,开始学习集群模式下,Nacos 是怎么去操作的 ? 本章重点: 在Nacos服务端当中,会去开启健康心跳检查定时任务。如果是在Nacos集群下,大家思考一下,有没有必要所有的…

无人直播系统源码开发:功能~优势~开发方法

自动直播通常是指通过自动化技术来实现实时内容分发的过程,它结合了流媒体技术和人工智能(如机器学习)。以下是自动直播实现的基本步骤: 内容采集:通过摄像头、手机等设备捕捉实时画面,并通过编码将其转换成…

rocketmq主从切换测试

服务器 192.168.1.23 nameserver、broker nameserver、brokerA,brokerB 192.168.1.35 nameserver、broker nameserver、brokerA,brokerB 192.168.1.88 nameserver nameserver 主从切换 关闭master:等待几秒钟23成为新的master slave消费测…

超市收银系统源码

今天给大家分享一套线上线下打通的收银系统,安卓/win双端线下收银台,可DIY、多模板的三端线上小程序商城,除此之外ERP进销存管理、商品管理、会员营销都很完善。 重点是系统支持OEM贴牌独立部署和全开源源码,非常适合一些正在寻找…

南航秋招指南,线上测评和线下考试

南航秋招简介 南航作为国内一流的航空公司,对人才的需求量非常旺盛,每年也有很多专业对口的工作提供给应届毕业生,对于应届毕业生而言,一定要抓住任何一个应聘机会,并且在规定的范围内进行简历的提交,以便…

CSS content 计数器

CSS content 计数器 CSS 计数器通过一个变量来设置,根据规则递增变量。 使用计数器自动编号 CSS 计数器根据规则来递增变量。 CSS 计数器使用到以下几个属性: counter-reset - 创建或者重置计数器,给计算器命名。注意声明计算器不能在自身…

孕产妇(产科)管理信息系统源码 三甲医院产科电子病历系统成品源代码

孕产妇(产科)管理信息系统源码 三甲医院产科电子病历系统成品源代码 医院智慧孕产是一种通过信息化手段,实现孕产期宣教、健康服务的院外延伸,对孕产妇健康管理具有重要意义,是医院智慧服务水平和能力的体现。实行涵盖婚前检查、孕期保健、产后康复的一…

如何把harmonos项目修改为openharmony项目

一开始分不清harmonyos和openharmony,在harmonyos直接下载的开发软件,后面发现不对劲,打脑阔 首先你要安装对应版本的开发软件,鸿蒙开发是由harmonyos和openharmony官网两个的,找到对应的地方下载对应版本的开发软件&…

C#-反射

一、概念 反射(Reflection)在C#中是一种非常重要的特性,它为开发者提供了在运行时获取和操作关于类型、成员、属性、方法等的详细信息的能力。通过反射,开发者可以在程序运行期间动态地创建对象、调用方法、设置属性值以及进行其…

【Java开发实训】day01

目录 1.Java开发步骤 2.目录的三个表达方法 3.Java的三种注释方法 4.文档注释的作用 🌈嗨!我是Filotimo__🌈。很高兴与大家相识,希望我的博客能对你有所帮助。 💡本文由Filotimo__✍️原创,首发于CSDN&…

运维锅总详解数据一致性

本文首先对数据一致性进行简要说明,然后画图分析展示9种数据一致性协议的工作流程,最后给出实现这9种协议的例子。希望对您理解数据一致性有所帮助! 一、数据一致性简介 数据一致性是数据库和分布式系统中的一个关键概念,它确保…

【Mac】Folder Icons for mac(文件夹个性化图标修改软件)软件介绍

软件介绍 Folder Icons for Mac 是一款专为 macOS 设计的应用程序,主要用于个性化和定制你的文件夹图标。以下是它的主要特点和使用方法: 主要特点: 个性化文件夹图标 Folder Icons for Mac 允许用户为 macOS 上的任何文件夹定制图标。你…

怎样优化 PostgreSQL 中对布尔类型数据的查询?

文章目录 一、索引的合理使用1. 常规 B-tree 索引2. 部分索引 二、查询编写技巧1. 避免不必要的类型转换2. 逻辑表达式的优化 三、表结构设计1. 避免过度细分的布尔列2. 规范化与反规范化 四、数据分布与分区1. 数据分布的考虑2. 表分区 五、数据库参数调整1. 相关配置参数2. 定…