机器学习 | 计算分类算法的ROC和AUC曲线以随机森林为例

受试者工作特征(ROC)曲线和曲线下面积(AUC)是常用的分类算法评价指标,本文将讨论如何计算随机森林分类器的ROC 和 AUC。

ROC 和 AUC是量化二分类区分阳性和阴性类别能力的度量。ROC曲线是针对不同分类阈值的真阳性率(TPR)对假阳性率(FPR)的图。TPR是真阳性与阳性示例总数的比率,而FPR是假阳性与阴性示例总数的比率。AUC是ROC曲线下面积,范围为0.0至1.0,值越高表示分类器性能越好。

具体步骤

1.导入所需模块

from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_curve, roc_auc_score 
from sklearn.datasets import load_breast_cancer 
from sklearn.model_selection import train_test_split 
import matplotlib.pyplot as plt

这里我们导入所需的模块,包括分别来自sklearn.ensemble和sklearn.metrics模块的RandomForestClassifier和roc_curve函数。我们还从sklearn.datasets模块导入load_breast_cancer函数来加载乳腺癌数据集,并从sklearn.model_selection模块导入train_test_split函数来将数据集拆分为训练集和测试集。最后,我们从matplotlib库中导入pyplot模块来绘制ROC曲线。

2.加载并拆分数据集

加载数据集并分离特征和目标值,然后拆分训练和测试数据集。

df = load_breast_cancer(as_frame=True) 
df = df.frame x = df.drop('target',axis=1) 
y = df[['target']] 
# Split the dataset into training and test sets 
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

3.训练随机森林分类器

# Train a Random Forest classifier 
rf = RandomForestClassifier(n_estimators=5, max_depth=2) 
rf.fit(X_train, y_train)

在这里,我们使用RandomForestClassifier函数训练一个随机森林分类器,其中包含5个估计量和最大深度2。我们使用拟合方法将分类器拟合到训练数据。

4.获取测试集的预测类概率

# Get predicted class probabilities for the test set 
y_pred_prob = rf.predict_proba(X_test)[:, 1] 

在这里,我们使用随机森林分类器的predict_proba方法来获得测试集的预测类概率。该方法返回一个形状数组(n_samples,n_classes),其中n_samples是测试集中的样本数,n_classes是问题中的类数。因为我们使用的是二元分类器,所以n_classes等于2,我们感兴趣的是正类的概率。 这是数组的第二列。因此,我们使用 [:,1] 索引来获得正类概率的一维数组。

5.计算不同分类阈值的假阳性率(FPR)和真阳性率(TPR)

# Compute the false positive rate (FPR) 
# and true positive rate (TPR) for different classification thresholds 
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob, pos_label=1)

在这里,我们使用sklearn.metrics模块中的roc_curve函数来计算不同分类阈值的假阳性率(FPR)和真阳性率(TPR)。该函数将测试集的真标签(y_test)和阳性类的预测类概率(y_pred_prob)作为输入。它返回三个数组:fpr,其包含不同阈值的FPR值; tpr,其包含不同阈值的TPR值;以及thresholds,其包含阈值。

6.计算ROC AUC评分

# Compute the ROC AUC score 
roc_auc = roc_auc_score(y_test, y_pred_prob) 
roc_auc

输出

0.9787264420331239

这里我们使用sklearn.metrics模块中的roc_auc_score函数来计算ROC AUC分数。该函数将测试集的真标签(y_test)和阳性类的预测类概率(y_pred_prob)作为输入。它返回表示ROC曲线下面积的标量值。

7.绘制ROC曲线

# Plot the ROC curve 
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc) 
# roc curve for tpr = fpr 
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier') 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('ROC Curve') 
plt.legend(loc="lower right") 
plt.show()

在这里插入图片描述
这里我们使用pyplot模块的plot函数来绘制ROC曲线。我们在x轴上传递FPR值,在y轴上传递TPR值。我们还将ROC AUC评分作为面积添加到图中。我们绘制虚线来表示随机分类器,其具有从(0,0)到(1,1)的直线的ROC曲线。我们为图添加轴标签和标题,以及显示ROC AUC得分和随机分类器线的图例。

说明:
ROC曲线是对于不同分类阈值,y轴上的真阳性率(TPR)对x轴上的假阳性率(FPR)的图。ROC曲线显示了分类器在不同阈值下区分阳性和阴性类别的能力。一个完美的分类器的TPR为1,FPR为0,对应于图的左上角。另一方面,随机分类器将具有从(0,0)到(1,1)的直线的ROC曲线,这是图中的虚线。ROC曲线越接近左上角,分类器的性能越好。

ROC曲线可用于选择分类器的最佳阈值,这取决于TPR和FPR之间的权衡。接近1的阈值将具有较低的FPR但较高的TPR,而接近0的阈值将具有较高的FPR但较低的TPR。

8.绘制预测类概率

# Plot the predicted class probabilities 
plt.hist(y_pred_prob, bins=10) 
plt.xlim(0, 1) 
plt.title('Histogram of predicted probabilities') 
plt.xlabel('Predicted probability of Setosa') 
plt.ylabel('Frequency') 
plt.show() 

在这里插入图片描述

多分类的ROC曲线示例

这里使用sklearn.datasets的iris数据集,它有3个类。ROC曲线可用于二分类,因此,这里我们将使用来自sklearn.multiclass的OneVsRestClassifier和Random forest作为分类器,绘制ROC曲线。

from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_curve, roc_auc_score 
from sklearn.datasets import load_iris 
from sklearn.multiclass import OneVsRestClassifier 
from sklearn.model_selection import train_test_split 
import matplotlib.pyplot as plt # Load the iris dataset 
iris = load_iris() # Split the dataset into training and test sets 
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.5, random_state=23) # Train a Random Forest classifier 
clf = OneVsRestClassifier(RandomForestClassifier()) # fit model 
clf.fit(X_train, y_train) # Get predicted class probabilities for the test set 
y_pred_prob = clf.predict_proba(X_test) # Compute the ROC AUC score 
roc_auc = roc_auc_score(y_test, y_pred_prob, multi_class='ovr') 
print('ROC AUC Score :',roc_auc) # roc curve for Multi classes 
colors = ['orange','red','green'] 
for i in range(len(iris.target_names)):	 fpr, tpr, thresh = roc_curve(y_test, y_pred_prob[:,i], pos_label=i) plt.plot(fpr, tpr, linestyle='--',color=colors[i], label=iris.target_names[i]+' vs Rest') 
# roc curve for tpr = fpr 
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier') 
plt.title('Multiclass (Iris) ROC curve') 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive rate') 
plt.legend() 
plt.show()

输出

ROC AUC Score : 0.9795855072463767

在这里插入图片描述

总结

总之,计算随机森林分类器的ROC AUC分数在Python中是一个简单的过程。sklearn.metrics模块提供了计算ROC曲线、ROC AUC评分和PR曲线的函数。ROC曲线和PR曲线是评估二值分类器性能的有用工具,它们可以帮助基于不同评估指标之间的权衡来选择分类器的最佳阈值。

PR(precision-recall)曲线是二元分类问题的另一个评估指标。PR曲线是针对不同分类阈值的精确度(y轴)对召回率(x轴)的图。精确度被定义为真阳性的数量除以真阳性加假阳性的数量,而召回率被定义为真阳性的数量除以真阳性加假阴性的数量。PR曲线显示了分类器在最小化误报的同时预测阳性类别的能力。

与ROC曲线相比,PR曲线更适合不平衡数据集,其中阳性类别中的样本数量远小于阴性类别中的样本数量。当假阳性和假阴性的成本不同时,PR曲线也很有用,因为它可以帮助基于精确度-召回率权衡为分类器选择最佳阈值。

重要的是要注意,ROC AUC不应该是用于评估分类器性能的唯一度量。其他指标,如精确度、召回率和F1分数,也可能有用,具体取决于应用程序的具体要求。此外,重要的是要考虑数据中正面和负面示例的总体分布以及不平衡类对评估指标的潜在影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW座舱照明测控系统

用LabVIEW开发飞机座舱照明测控系统。系统通过集成可靠的硬件与软件技术,提高了测试的效率和自动化水平,确保了飞行安全性和舒适性。体现了系统的设计思路、主要组成部分、工作原理及实际应用效果。 项目背景 飞机座舱照明系统是航空电子系统中至关重要…

【Spring Boot教程:从入门到精通】掌握Spring Boot开发技巧与窍门(三)-配置git环境和项目创建

主要介绍了如何创建一个Springboot项目以及运行Springboot项目访问内部的html页面!!! 文章目录 前言 配置git环境 创建项目 ​编辑 在SpringBoot中解决跨域问题 配置Vue 安装Nodejs 安装vue/cli 启动vue自带的图形化项目管理界面 总结 前言 …

谷粒商城实战笔记-63-商品服务-API-品牌管理-OSS获取服务端签名

文章目录 一,创建第三方服务模块thrid-party1,创建一个名为gulimall-third-party的模块2,nacos上创建third-party命名空间,用来管理这个服务的所有配置3,配置pom文件4,配置文件5,单元测试6&…

oracle登录报“ORA-27101: shared memory realm does not exist”

oracle登录报“ORA-27101: shared memory realm does not exist” 问题: 1、使用ip:1521/服务名方式连库报错" ORA-27101: shared memory realm does not exist Linux-x86_64 Error: 2: No such file or directory" 2、sqlplus XX/密码 可以登录数据库 …

【Apache Doris】数据副本问题排查指南

【Apache Doris】数据副本问题排查指南 一、问题现象二、问题定位三、问题处理 本文主要分享Doris中数据副本异常的问题现象、问题定位以及如何处理此类问题。 一、问题现象 问题日志 查询报错 Failed to initialize storage reader, tablet{tablet_id}.xxx.xxx问题说明 查…

c++ 内存管理(newdeletedelete[])

因为在c里面新增了类,所以我们在有时候会用malloc来创建类,但是这种创建只是单纯的开辟空间,没有什么默认构造的。同时free也是free的表面,如果类里面带有指针指向堆区的成员变量就会free不干净。 所以我们c增加了new delete和de…

HTML常用的转义字符——怎么在网页中写“<div></div>”?

一、问题描述 如果需要在网页中写“<div></div>”怎么办呢&#xff1f; 使用转义字符 如果直接写“<div></div>”&#xff0c;编译器会把它翻译为块&#xff0c;类似的&#xff0c;其他的标签也是如此&#xff0c;所以如果要在网页中写类似于“<div…

LeetCode_122(买卖股票的最佳时机)

public int maxProfit(int[] prices) {int ans 0;//int prices[] {7,1,5,3,6,4};for(int i1;i<prices.length;i){ansMath.max(0,prices[i]-prices[i-1]);}return ans;}

Unity DOTS中的world

Unity DOTS中的world 注册销毁逻辑自定义创建逻辑创建world创建system group插入player loopReference DOTS中&#xff0c;world是一组entity的集合。entity的ID在其自身的世界中是唯一的。每个world都拥有一个EntityManager&#xff0c;可以用它来创建、销毁和修改world中的en…

[Spring] MyBatis操作数据库(基础)

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…

Python酷库之旅-第三方库Pandas(045)

目录 一、用法精讲 156、pandas.Series.count方法 156-1、语法 156-2、参数 156-3、功能 156-4、返回值 156-5、说明 156-6、用法 156-6-1、数据准备 156-6-2、代码示例 156-6-3、结果输出 157、pandas.Series.cov方法 157-1、语法 157-2、参数 157-3、功能 15…

分布式系统常见软件架构模式

常见的分布式软件架构 Peer-to-Peer (P2P) PatternAPI Gateway PatternPub-Sub (Publish-Subscribe)Request-Response PatternEvent Sourcing PatternETL (Extract, Transform, Load) PatternBatching PatternStreaming Processing PatternOrchestration Pattern总结 先上个图&…

.h264 .h265 压缩率的直观感受

1.资源文件 https://download.csdn.net/download/twicave/89579327 上面是.264 .265和原始的YUV420文件&#xff0c;各自的大小。 2.转换工具&#xff1a; 2.1 .h264 .h265互转 可以使用ffmpeg工具&#xff1a;Builds - CODEX FFMPEG gyan.dev 命令行参数&#xff1a; …

liteos定时器回调时间过长造成死机问题解决思路

项目需求 原代码是稳定的&#xff0c;现我实现EMQ平台断开连接的时候&#xff0c;把HSL的模拟点位数据采集到网关&#xff0c;然后存入Flash&#xff0c;当EMQ平台连接的时候&#xff0c;把Flash里面的点位数据放在消息队列里面&#xff0c;不影响实时采集。 核心1&#xff1a…

godot新建项目及设置外部编辑器为vscode

一、新建项目 初次打开界面如下所示&#xff0c;点击取消按钮先关闭掉默认弹出的框 点击①新建弹出中间的弹窗②中填入项目的名称 ③中设置项目的存储路径&#xff0c;点击箭头所指浏览按钮&#xff0c;会弹出如下所示窗口 根据图中所示可以选择或新建自己的游戏存储路径&…

鸿蒙(HarmonyOS)自定义Dialog实现时间选择控件

一、操作环境 操作系统: Windows 11 专业版、IDE:DevEco Studio 3.1.1 Release、SDK:HarmonyOS 3.1.0&#xff08;API 9&#xff09; 二、效果图 三、代码 SelectedDateDialog.ets文件/*** 时间选择*/ CustomDialog export struct SelectedDateDialog {State selectedDate:…

Linux系统上安装Redis

百度网盘&#xff1a; 通过网盘分享的文件&#xff1a;redis_linux 链接: https://pan.baidu.com/s/1ZcECygWA15pQWCuiVdjCtg?pwd8888 提取码: 8888 1.把安装包拖拽到/ruanjian/redis/文件夹中&#xff08;自己选择&#xff09; 2.进入压缩包所在文件夹&#xff0c;解压压缩…

ROM修改进阶教程------修改rom 开机自动安装指定apk 自启脚本完整步骤解析

rom修改的初期认识 在解包修改系统分区过程中。很多客户需求刷完rom后自动安装指定apk。这种与内置apk有区别。而且一些极个别apk无法内置。今天对这种修改rom刷入机型后第一次启动后自动安装指定apk的需求做个步骤解析。 在前期博文中我有做过说明。官方系统固件解…

按图搜索新体验:阿里巴巴拍立淘API返回值详解

阿里巴巴拍立淘API是一项基于图片搜索的商品搜索服务&#xff0c;它允许用户通过上传商品图片&#xff0c;系统自动识别图片中的商品信息&#xff0c;并返回与之相关的搜索结果。以下是对阿里巴巴拍立淘API返回值的详细解析&#xff1a; 一、主要返回值内容 商品信息 商品列表…

Java面试题(每日更新)

每日五道&#xff01;学会就去面试&#xff01; 本文的宗旨是为读者朋友们整理一份详实而又权威的面试清单&#xff0c;下面一起进入主题吧。 目录 1.概述 2.Java 基础 2.1 JDK 和 JRE 有什么区别&#xff1f; 2.2 和 equals 的区别是什么&#xff1f; 2.3 两个对象的…