sklearn中多分类和多标签分类评估方法总结

一、任务区分

  1. 多分类分类任务:在多分类任务中,每个样本只能被分配到一个类别中。换句话说,每个样本只有一个正确的标签。例如,将图像分为不同的物体类别,如猫、狗、汽车等。

  2. 多标签分类任务:在多标签分类任务中,每个样本可以被分配到一个或多个类别中。换句话说,每个样本可以有多个正确的标签。例如,在图像标注任务中,一张图像可能同时包含猫和狗,因此它可以同时被分配到 "猫" 和 "狗" 这两个标签。

二、sklearn评估方式

 1、多分类(multiclass)任务

多分类任务的标签有两种,一种是原始标签,例如[0,1,2],另一种是独热编码的形式,[[1,0,0],[0,1,0],[0,0,1]]

经过模型分类之后的结果一般是各类的预测分数

(1)准确率(Accuracy):是分类正确的样本数与总样本数之比,是最简单的评估方法,但在类别不平衡的情况下可能会失效。

(2)混淆矩阵(Confusion Matrix):是一个 N×N 的矩阵(N 为类别数量),将真实类别与预测类别的对应关系表示出来。基于混淆矩阵可以计算精确率、召回率、F1 分数等指标。

(3)精确率(Precision)召回率(Recall):精确率表示被分类器正确分类的正样本数量与分类器预测为正样本的样本数量之比;召回率表示被分类器正确分类的正样本数量与数据集中所有正样本数量之比。

(4)F1 分数:精确率和召回率的调和平均数,综合考虑了分类器的准确性和完整性。

(5)ROC 曲线和AUC(Area Under the Curve):对于二分类任务,可以绘制ROC曲线,以真正例率(True Positive Rate)作为纵轴,假正例率(False Positive Rate)作为横轴。AUC表示ROC曲线下的面积,是一个评估分类器性能的常用指标。对于多分类任务,通常使用微平均(micro-average)或宏平均(macro-average)来计算AUC。

- 引入模块,并自己定义一下模型输出

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import numpy as np
import matplotlib.pyplot as plt
import torch# 示例真实标签和预测结果
true_labels = np.array([0, 1, 2, 1, 0, 2, 2, 1, 0, 1])
print("true label",true_labels)
# 生成随机数据作为概率值,实际应用中需要替换为模型的预测概率值
model_output = torch.randn(len(true_labels), 3)
print("model output",model_output)
# 获得最大类别的index
_, predicted_labels = torch.max(model_output, 1)
print("predicted label",predicted_labels)

 示例数据如下:

- 进行模型评估

注意,计算roc_auc时需要将输出概率归一化,否则会报错

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

 准确率等的计算用的是模型输出后最大类别的index,而计算roc_auc直接用模型输出的概率,但需要归一化。

# 准确率
accuracy = accuracy_score(true_labels, predicted_labels)
print("Accuracy:", accuracy)
# 混淆矩阵
conf_matrix = confusion_matrix(true_labels, predicted_labels)
print("Confusion Matrix:\n", conf_matrix)
# 分类报告
class_report = classification_report(true_labels, predicted_labels)
print("Classification Report:\n", class_report)
# 精确率
precision = precision_score(true_labels, predicted_labels, average='macro')
print("Precision:", precision)
# 召回率
recall = recall_score(true_labels, predicted_labels, average='macro')
print("Recall:", recall)
# F1 分数
f1 = f1_score(true_labels, predicted_labels, average='macro')
print("F1 Score:", f1)
# ROC AUC
# 计算ROC需要将模型输出概率归一化
prob_new = torch.nn.functional.softmax(model_output, dim=1)
print(prob_new)
roc_auc = roc_auc_score(true_labels, prob_new, average='macro', multi_class='ovo')
print("ROC AUC Score:", roc_auc)

结果:

 

- 独热编码

但如果是对原始的标签数据进行了独热编码,那么在进行准确率等的计算的时候,需要将输出也转化为与独热编码类似的形式,然后再使用sklearn的函数进行计算

from sklearn.preprocessing import label_binarize
# 进行独热编码
true_one_hot = label_binarize(true_labels, classes=np.arange(3))
# 获取每行最大值的索引
max_indices = torch.argmax(model_output, dim=1)
# 创建一个与模型输出相同形状的零张量
predicted_labels = torch.zeros_like(model_output)
# 将每行最大值的位置设为1
predicted_labels[torch.arange(len(max_indices)), max_indices] = 1
print("predicted labels",predicted_labels)accuracy = accuracy_score(true_one_hot, predicted_labels)
print("Accuracy:", accuracy)
roc_auc = roc_auc_score(true_one_hot, model_output, average='macro')
print("ROC AUC Score:", roc_auc)

结果如下:

总之,无论是采用原始标签的形式,还是独热编码的形式,在计算accuracy,recall,precision,F1-score的时候,都需要将模型输出转化为0,1且与真实标签维度一致的格式,而在计算roc的时候,若是独热编码的真实标签,则可以直接用模型输出,但如果不是,就需要归一化概率。

2、多标签(multilabel)分类任务

对于多标签分类,初始的真实标签要用到独热编码

 针对模型输出的概率分数,需要设定一个阈值,大于阈值的标记为1,低于阈值的标记为0,如下代码所示:

import torch
# 示例模型输出
model_output = torch.tensor([[0.8, 0.3, 0.9],[0.2, 0.7, 0.4],[0.9, 0.1, 0.3]])
# 设置阈值
threshold = 0.5
# 将概率分数转换为0-1结果
predicted_labels = (model_output > threshold).float()
print(predicted_labels)

然后直接使用sklearn的函数进行评估

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss, jaccard_score, coverage_error, average_precision_score, roc_auc_score
import numpy as np# 示例标签和预测结果
true_labels = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0]])
predicted_labels = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]])
# 准确率
accuracy = accuracy_score(true_labels, predicted_labels)
print("Accuracy:", accuracy)
# 精确率
precision = precision_score(true_labels, predicted_labels, average='micro')
print("Precision:", precision)
# 召回率
recall = recall_score(true_labels, predicted_labels, average='micro')
print("Recall:", recall)
# F1 分数
f1 = f1_score(true_labels, predicted_labels, average='micro')
print("F1 Score:", f1)
# 平均准确率
average_precision = average_precision_score(true_labels, predicted_labels, average='micro')
print("Average Precision:", average_precision)
# ROC AUC
roc_auc = roc_auc_score(true_labels, predicted_labels, average='micro')
print("ROC AUC Score:", roc_auc)

得到的结果如下:

详细对于多标签分类的指标解释可以参考下面的文章:

sklearn中多标签分类场景下的常见的模型评估指标_51CTO博客_sklearn模型评估icon-default.png?t=N7T8https://blog.51cto.com/liguodong/4290183总的来说,要根据自己的任务和目标来制定合适的评估指标,因为评估指标是实验结果的体现。


都看到了这里了,给个小心心♥呗~ 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/12488.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

助力数字农林业发展服务香榧智慧种植,基于YOLOv5全系列【n/s/m/l/x】参数模型开发构建香榧种植场景下香榧果实检测识别系统

作为一个生在北方但在南方居住多年的人,居然头一次听过香榧(fei)这种作物,而且这个字还不会念,查了以后才知道读音(fei),三声,这着实引起了我的好奇心,我相信…

STM32使用ADC单/多通道检测数据

文章目录 1. STM32单片机ADC功能详解 2. AD单通道 2.1 初始化 2.2 ADC.c 2.3 ADC.h 2.4 main.c 3. AD多通道 3.1 ADC.c 3.2 ADC.h 3.3 main.c 3.4 完整工程文件 1. STM32单片机ADC功能详解 STM32单片机ADC功能详解 2. AD单通道 这个代码实现通过ADC功能采集三脚电…

【Vue2】关于response返回数据的错误小记

关于Vue2中response返回数据的一个错误小记 如图&#xff0c;在这里返回的时候&#xff0c;后端是通过List< String >返回的&#xff0c;response接收到的实际上是一个Array数组&#xff0c;但是赋值给searchedTaskList的时候&#xff0c;需要在.then包括的范围里面赋值给…

【SpringBoot】 什么是springboot(二)?springboot操作mybatisPlus、swagger、thymeleaf模板

文章目录 SpringBoot第三章1、整合mybatsPlus1-234-67-10问题 2、整合pageHelper分页3、MP代码生成器1、编写yml文件2、导入依赖3、创建mp代码生成器4、生成代码5、编写配置类扫描mapper类6、编写控制器类 4、swagger1、什么是swagger2、作用3、发展历程4、一个简单的swagger项…

ElastiCache Serverless for Redis应用场景和性能成本分析

一. 前言 传统基于实例节点的 Redis 缓存架构中&#xff0c;扩展性是一个重要影响因素。在很多场景中&#xff0c;例如广告投放、电商交易、游戏对战&#xff0c;流量是经常变化的。无论是主从还是集群模式&#xff0c;当大流量进入时&#xff0c;Redis 处理能力达到上限&…

“打工搬砖记”中吃什么的轮盘功能实现(二)

文章目录 打工搬砖记转盘主要的逻辑实现转盘的素材小结 打工搬砖记 先来一个吃什么轮盘的预览图&#xff0c;这轮盘文案加字呈圆形铺出来&#xff0c;开始后旋转到指定的选项处停下来。 已上线小程序“打工人搬砖记”&#xff0c;可以扫码进行预览观看。 转盘主要的逻辑实现…

如何使用Docker安装并运行Nexus容器结合内网穿透实现远程管理本地仓库

前言 作者简介&#xff1a; 懒大王敲代码&#xff0c;计算机专业应届生 今天给大家聊聊如何使用Docker安装并运行Nexus容器结合内网穿透实现远程管理本地仓库&#xff0c;希望大家能觉得实用&#xff01; 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496…

openlayer实现ImageStatic扩展支持平铺Wrapx

地图平铺&#xff08;Tiling&#xff09;是地图服务中常见的技术&#xff0c;用于将大尺寸的地图数据分割成许多小块&#xff08;瓦片&#xff09;&#xff0c;便于高效加载和展示。这种技术特别适用于网络环境&#xff0c;因为它允许浏览器只加载当前视图窗口内所需的地图瓦片…

IT行业现状与未来趋势分析

IT行业现状与未来趋势显示出持续的活力和变革&#xff0c;以下是上大学网&#xff08;www.sdaxue.com&#xff09;关于IT行业现状与未来趋势分析&#xff0c;供大家参考。 当前现状&#xff1a; 市场需求持续增长&#xff1a;随着信息时代的深入发展&#xff0c;各行各业对信息…

LLM Agent智能体综述(超详细)

前言 &#x1f3c6;&#x1f3c6;&#x1f3c6;在上一篇文章中&#xff0c;我们介绍了如何部署MetaGPT到本地&#xff0c;获取OpenAI API Key并配置其开发环境&#xff0c;并通过一个开发小组的多Agent案例感受了智能体的强大&#xff0c;在本文中&#xff0c;我们将对AI Agent…

5G消息和5G阅信的释义与区别 | 赛邮科普

5G消息和5G阅信的释义与区别 | 赛邮科普 在 5G 技术全面普及的当下&#xff0c;历史悠久的短信服务也迎来了前所未有的变革。5G 阅信和 5G 消息就是应运而生的两种短信形态&#xff0c;为企业和消费者带来更加丰富的功能和更加优质的体验。 这两个产品名字和形态都比较接近&am…

618速递丨各平台内卷严重,这些行业能否率先炸场?

根据最新发布的《中国网络视听发展研究报告&#xff08;2024&#xff09;》显示&#xff0c;71.2%的受访用户因为看短视频和直播进行网上购物&#xff0c;超40%的用户认为短视频和直播是他们的主要消费渠道&#xff0c;内容消费正成为各大电商争夺的关键赛道。 今年618&#x…

信创厂商选择要点

信创厂商选择要点 信创项目推进&#xff0c;不可避免的要与众多信创厂商打交道。选择靠谱的供应商&#xff0c;合理避坑&#xff0c;是信创项目成败的关键因素。个人认为技术突破能力、产品服务能力、生态建设能力、平滑迁移能力是评估一个信创厂商是否合格的重要标准。 技术…

【iOS】——RunLoop学习

文章目录 一、RunLoop简介1.RunLoop介绍2.RunLoop功能3.RunLoop使用场景4.Run Loop 与线程5.RunLoop源代码和模型图 二、RunLoop Mode1.CFRunLoopModeRef2.RunLoop Mode的五种模式3.RunLoop Mode使用 三、RunLoop Source1.CFRunLoopSourceRefsourc0&#xff1a;source1: 2.CFRu…

Vue中使用$t(‘xxx‘)实现中英文切换;

&#xff08;原文链接&#xff09; 介绍 {{$t(key)}} &#xff1a;是VueI18n插件提供的函数&#xff0c;主要用于根据当前语言环境返回对应的翻译文本&#xff0c;以便在页面上显示多语言内容。 key&#xff1a;作为参数传递给函数$t()的字符串&#xff0c;用于指定需要翻译的…

基于springboot+vue+Mysql的在线BLOG网

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

虾皮选品:Shopee首季盈利2.4亿;TikTok美区电商权限要求降低

2024年5月14号&#xff0c;跨境电商日报&#xff1a; 1.Ozon已成功回款 2.TikTok降低美区达人开通电商权限要求 3.Shopee首季盈利2.4亿 4.6月1日起&#xff0c;亚马逊退货处理费收取标准更新 5.欧盟委员会对从中国台湾地区和越南进口的不锈钢冷轧产品征收反补贴和反倾销税…

在数据库中使用存储过程插入单组/多组数据

存储过程可以插入单组数据&#xff0c;也可以以字符串的形式插入多组数据&#xff0c;将字符串中的信息拆分成插入的数据。 首先建立一个简单的数据库 create database student; use student;选中数据库之后建立一张学生表 create table stu(uid int primary key,uname varc…

wordpress 访问文章内容页 notfound

解决&#xff1a; 程序对应的伪静态规则文件.htaccess是空的 网站根目录下要有 .htaccess 文件&#xff0c;然后将下面的代码复制进去。 <ifmodule mod_rewrite.c>RewriteEngine OnRewriteBase /RewriteRule ^index\.php$ - [L]RewriteCond %{REQUEST_FILENAME} !-fRew…

python模拟QQ聊天的代码

以下是一个简单的Python模拟QQ聊天的代码示例&#xff1a; python # 导入QQ消息包 import tqq # 创建QQ客户端对象 client tqq.TQQClient() # 连接QQ服务器 client.connect("你的QQ号码", "你的QQ密码") # 创建一个QQ会话对象 session client.session() …