分类模型评估:混淆矩阵与ROC曲线

  • 1.混淆矩阵
  • 2.ROC曲线 & AUC指标

理解混淆矩阵和ROC曲线之前,先区分几个概念。对于分类问题,不论是多分类还是二分类,对于某个关注类来说,都可以看成是二分类问题,当前的这个关注类为正类,所有其他非关注类为负类。因为样本的真实值有正负两类,而模型的预测值也有正负两类,因此样本的真实值和模型的预测值之间产生了下面4种组合:

  • 真正例(True Positives/TP):在所有真实值为正类的样本中,模型预测值也为正类的样本数。
  • 假正例(False Positives/FP):在所有真实值为负类的样本中,模型预测值为正类的样本数。
  • 真负例(True Negatives/TN):所有真实值为负类的样本中,模型预测值也为负类的样本数。
  • 假负例(False Negatives/FN):所有真实值为正类的样本中,模型预测值为负类的样本数。

从上面几个定义可以知道:
1)样本总数 = TP+FP+TN+FN
2)所有真实值为正类的样本总数 = TP+FN
3)所有真实值为负类的样本总数 = TN+FP

1.混淆矩阵

使用sklearn自带的鸢尾花数据集,数据集里鸢尾花包含3个分类。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 自带的数据集分类准确率为1,为了后面更好的基于混淆矩阵验证相关指标的计算,为训练集添加均值0,标准差2的高斯噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 特征值归一化到区间[-1,1]
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 创建逻辑回归模型、训练并预测
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取模型混淆矩阵、分类报告、准确率
print(f"混淆矩阵:\n{confusion_matrix(y_test, y_pred)}")
print(f"分类报告:\n{classification_report(y_test, y_pred)}")
print(f"准确率:\n{accuracy_score(y_test, y_pred)}")

output:
在这里插入图片描述

混淆矩阵中,横向表示真实值,纵向表示预测值。比如第一个位置7,表示实际类别为0且预测类别为0的样本有7个。基于此混淆矩阵,可以衍生下面相关指标:
1)准确率(accuracy):准确率表示模型对一个样本类别预测正确的可能性,是相对整体来说的。计算方式为所有预测正确的样本(斜对角线之和)/ 样本总数,本例中accuracy=(7+4+8)/30=0.63…
2)精确率(precision):精确率是针对某个具体关注类来说的,精确率关注的是,对于所有预测值为该类的样本中,真实值也属于该类的样本所占的比例,计算公式为 T P T P + F P \frac{TP}{TP+FP} TP+FPTP。比如对于类别0,模型预测的该类样本数=(7+0+2)=9,而真实值为该类的样本数为7,那么类别0的precision=7/9=0.77…。精确率反映了“模型找的对不对”
3)召回率(recall):同样召回率也是针对某个具体关注类来说的,关注的是所有真实值为该类的样本中,模型能正确预测为该类的样本所占的比例,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP。还是拿类别0来说,真实值为0的样本总数=(7+3+0)=10,模型能正确预测为该类的样本总数为7,所以类别0的召回率=7/10=0.7。召回率代表了“模型找的全不全”
4)F1-score:F1分数是精确率与召回率的调和平均数,计算公式为 2 ∗ p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l \frac{2*precision*recall}{precision+recall} precision+recall2precisionrecall。对于类别0的F1-score= 2 ∗ 0.78 ∗ 0.7 0.78 + 0.7 \frac{2*0.78*0.7}{0.78+0.7} 0.78+0.720.780.7=0.737…,F1分数用来表示模型在关注的类上识别正类的综合表现,最大值1表示分类效果最好完全正确,最小值0表示分类效果最差完全错误。

分类报告直接提供了每个分类下的精确率、召回率、F1分数等指标。最下面两行的macro avgweighted avg分别表示对每个指标的算术平均和加权平均,最后一列的support表示对应的样本数量。

2.ROC曲线 & AUC指标

ROC:Receiver Operating Characteristic。
AUC:Area Under the [ROC] Curve,ROC曲线下的面积。

ROC曲线的绘制中需要用到两个指标:

  • 真正率(True Positive Rate/TPR):在所有真实类别为正类的样本中,模型正确识别为正类的样本所占的比例,也就是把正类样本识别成正类样本的概率。反映了模型识别正类的能力,可以看成是模型在识别正类样本时的收获能力,计算公式 T P T P + F N \frac{TP}{TP+FN} TP+FNTP
  • 假正率(False Positive Rate/FPR):在所有真实类别为负类的样本中,模型错误识别为正类的样本所占的比例,即把负类样本识别为正类样本的概率。反映了模型识别为正类样本时的错误程度,可以理解成模型在识别正类样本时付出的代价,计算公式 F P T N + F P \frac{FP}{TN+FP} TN+FPFP

大多数分类模型都是通过计算出每个样本属于正类的概率,和属于正类的概率阈值进行比较来对样本进行分类的。正类的概率>=阈值,判定为正类,反之判定为负类。

ROC曲线是由不同概率阈值下真正率(y轴)和假正率(x轴)对应的一系列点所构成的曲线,x轴从左到右判定为正类的概率阈值从1到0逐渐递减。ROC曲线用来描述二分类模型预测效果,对于多分类问题,是将关注类视为正类,其他类视为负类。

ROC曲线的具体绘制过程可以理解为:

  1. 对于测试集中的每个样本,利用分类器预测其为正类的概率值。
  2. 将这些概率值按照从大到小的顺序排列,作为阈值。
  3. 对于每个阈值,分别计算真正率和假正率,对应坐标轴上的一个点。
  4. 连接这些点。

从真正率和假正率的计算,可以看出曲线越往右,判定为正类的概率阈值越低,那么就有更多的样本被归类到正类当中,因为分母是不变的,分子(TP/FP)随着正类样本增多都会逐渐增大,因此ROC的曲线走势应该是一个从(0, 0)到(1, 1)逐渐上升的曲线。

同时,因为x轴代表了在识别正类时付出的代价,y轴代表了在识别正类时的收获,因此当x值越小,y值越大,即曲线越靠近左上角(0, 1),说明模型的分类效果越好。

而AUC,是ROC曲线下的面积,它衡量的是模型在所有概率阈值下识别正类时“收获”与“代价”的比重,因此AUC值越大越好,值域范围[0, 1]。
AUC=0.5:模型不具有分类效果,相当于盲猜。
AUC<0.5:分类效果最差,不如盲猜。
AUC>0.5:有一定的分类效果,值越接近1分类效果越好。

下面还是以鸢尾花的数据集为例,通过一个demo对ROC和AUC进行计算和绘制。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc# 获取特征值与目标值
data = load_iris()
X, y = data['data'], data['target']# 仅使用两个类别:0 & 1
X = X[y != 2]
y = y[y != 2]# 训练集添加噪声
np.random.seed(42)
noise = np.random.normal(0, 2, (len(X), len(X[0])))
X += noise# 归一化
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 划分数据集、创建逻辑回归模型、训练
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
model = LogisticRegression(multi_class='multinomial', max_iter=1000)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 获取每个样本预测为正类样本的概率,[:, 0]是负类样本的概率
y_pred_prob = model.predict_proba(X_test)[:, 1]# 计算FPR、TPR和AUC值
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)# 绘制ROC曲线
plt.figure()
plt.plot(fpr, tpr, color='green', lw=1, label=f'ROC Curve (AUC={roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='red', lw=1, linestyle='--')
plt.xlim([0.0, 1])
plt.ylim([0.0, 1.05])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

output:
在这里插入图片描述
绿线代表ROC曲线,红线相当于盲猜,绿线在红线上方距离红线越远模型分类效果越好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/770662.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras实战演绎机器学习 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff01; 介绍 通过 Keras&#xff0c;您可以编写自定…

位段详细解释

结构体位段的使用原则 在C语言中&#xff0c;结构体&#xff08;Struct&#xff09;是一种复合数据类型&#xff0c;它允许我们将多个不同类型的数据项组合成一个单一的实体。位段&#xff08;Bit Field&#xff09;是结构体中的一个特殊成员&#xff0c;它允许我们只取结构体…

常用中间件redis,kafka及其测试方法

常用消息中间件及其测试方法 一、中间件的使用场景引入中间件的目的一般有两个&#xff1a;1、提升性能常用的中间件&#xff1a;1) 高速缓存&#xff1a;redis2) 全文检索&#xff1a;ES3) 存日志&#xff1a;ELK架构4) 流量削峰&#xff1a;kafka 2、提升可用性产品架构中高可…

Spring Cloud 网关Gateway + 配置中心

网关 网络的接口&#xff0c;负责请求的路由、转发、身份校验 路由&#xff1a;告诉请求去哪找 转发&#xff1a;请求找不到直接带请求过去 路由及转发 判断前端请求的规则就这么配 当前情况下只需要访问8080端口 就可以完成对全部微服务的访问 路由属性 登录校验 没必要在每…

sonar+gitlab提交阻断 增量扫描

通过本文&#xff0c;您将可以学习到 sonarqube、git\gitlab、shell、sonar-scanner、sonarlint 一、前言 sonarqube 是一款开源的静态代码扫描工具。 实际生产应用中&#xff0c;sonarqube 如何落地&#xff0c;需要考虑以下四个维度&#xff1a; 1、规则的来源 现在规则的…

java一和零(力扣Leetcode474)

一和零 力扣原题 给定一个二进制字符串数组 strs 和两个整数 m 和 n&#xff0c;请你找出并返回 strs 的最大子集的长度&#xff0c;该子集中最多有 m 个 0 和 n 个 1。 示例 1&#xff1a; 输入&#xff1a;strs [“10”, “0001”, “111001”, “1”, “0”], m 5, n …

【msyql】mysqldump: 未找到命令...

使用mysqldump备份数据库出现错误提示&#xff1a; mysqldump: 未找到命令... 执行的命令如下&#xff1a; mysqldump -uroot -proot --databases db_user > /home/backups/databackup.sql 解决方法 确认mysql是否安装 查看mysql版本 mysql --version 查找mysql安装路…

php反序列化刷题1

[SWPUCTF 2021 新生赛]ez_unserialize 查看源代码想到robots协议 看这个代码比较简单 直接让adminadmin passwdctf就行了 poc <?php class wllm {public $admin;public $passwd; }$p new wllm(); $p->admin "admin"; $p->passwd "ctf"; ec…

极光笔记|极光消息推送服务的云原生实践

摘要 极光始终秉承“以开发者为中心”的战略导向&#xff0c;极光推送&#xff08;JPush&#xff09;是国内领先的消息推送服务。极光推送&#xff08;JPush&#xff09;本质上是一种软件付费应用程序&#xff0c;结合当前主流云厂商基础施设&#xff0c;逐渐演进成了云上SaaS…

Java后端设置服务器允许跨域

文章目录 1、实现2、一些问题关于各项请求头的作用关于预检请求 3、一些补充4、疑问点 1、实现 以下通过servlet的Filter给所有响应的header加了一些跨域相关的数据&#xff0c;以实现允许跨域。 import org.springframework.context.annotation.Configuration; import org.s…

数据可视化基础与应用-04-seaborn库从入门到精通01-02

总结 本系列是数据可视化基础与应用的第04篇seaborn&#xff0c;是seaborn从入门到精通系列第1-2篇。本系列的目的是可以完整的完成seaborn从入门到精通。主要介绍基于seaborn实现数据可视化。 参考 参考:数据可视化-seaborn seaborn从入门到精通01-seaborn介绍与load_datas…

RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能

RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能 文章目录 RabbitMQ3.x之二_RabbitMQ所有端口说明及开启后台管理功能1. RabbitMQ端口说明2. 开启Rabbitmq后台管理功能1. 查看rabbitmq已安装的插件2. 开启rabbitmq后台管理平台插件3. 开启插件后&#xff0c;再次查看插…

RSTP环路避免实验(华为)

思科设备参考&#xff1a;RSTP环路避免实验&#xff08;思科&#xff09; 一&#xff0c;技术简介 RSTP (Rapid Spanning Tree Protocol) 是从STP发展而来 • RSTP标准版本为IEEE802.1w • RSTP具备STP的所有功能&#xff0c;可以兼容STP运行 • RSTP和STP有所不同 减少了…

Tomcat下载安装以及配置

一、Tomcat介绍 二、Tomcat下载安装 进入tomcat官网&#xff0c;https://tomcat.apache.org/ 1、选择需要下载的版本&#xff0c;点击下载 下载路径一定要记住&#xff0c;并且路径中尽量不要有中文 8、9、10都可以&#xff0c;本博文以8为例 2、将下载后的安装包解压到指定位…

linux-开发板移植MQTT

将源码复制到共享文件夹 链接&#xff1a;https://pan.baidu.com/s/1kvvO-HhDMDXkQ_wlNtyW_A?pwd332i 提取码&#xff1a;332i 以下步骤教程里都写了&#xff0c;我这里边进行&#xff0c;方便大家对照 pc端 1.进入mqtt_lib, 解压open压缩包 2.按照教程复制这一句并运行&…

服务端应用多级缓存架构方案

服务端应用多级缓存架构方案 场景 20w的QPS的场景下&#xff0c;服务端架构应如何设计&#xff1f; 常规解决方案 可使用分布式缓存来抗&#xff0c;比如redis集群&#xff0c;6主6从&#xff0c;主提供读写&#xff0c;从作为备&#xff0c;不提供读写服务。1台平均抗3w并…

【算法专题--双指针算法】leecode-15.三数之和(medium)、leecode-18. 四数之和(medium)

&#x1f341;你好&#xff0c;我是 RO-BERRY &#x1f4d7; 致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 &#x1f384;感谢你的陪伴与支持 &#xff0c;故事既有了开头&#xff0c;就要画上一个完美的句号&#xff0c;让我们一起加油 目录 前言1. 三数之和2. 解法&…

选择最佳图像处理工具OpenCV、JAI、ImageJ、Thumbnailator和Graphics2D

文章目录 1、前言2、 图像处理工具效果对比2.1 Graphics2D实现2.2 Thumbnailator实现2.3 ImageJ实现2.4 JAI&#xff08;Java Advanced Imaging&#xff09;实现2.5 OpenCV实现 3、图像处理工具结果 1、前言 SVD(stable video diffusion)开放了图生视频的API&#xff0c;但是限…

Ubuntu deb文件 安装 MySQL

更新系统软件依赖 sudo apt update && sudo apt upgrade下载安装包 输入命令查看Ubuntu系统版本 lsb_release -a2. 网站下载对应版本的安装包 下载地址. 解压安装 mkdir /home/mysqlcd /home/mysqltar -xvf mysql-server_8.0.36-1ubuntu20.04_amd64.deb-bundle.tar# …

【考研数学】张宇最新全年学习包

考研数学冲高分必备&#xff0c;张宇老师肯定榜上有名&#xff01; 考研数学&#xff0c;其实就像一场没有硝烟的战斗。基础题是常规武器&#xff0c;中难题就是重型火炮&#xff0c;而压轴题呢&#xff0c;那就是核弹级别的存在&#xff01;考研的战场&#xff0c;关键就在那…