混淆矩阵细致理解

1、什么是混淆矩阵

混淆矩阵(Confusion Matrix)是深度学习和机器学习领域中的一个重要工具,用于评估分类模型的性能。它提供了一个清晰的视觉方式来展示模型的预测结果与真实标签之间的关系,尤其在分类任务中,帮助我们了解模型的强项和弱点。

混淆矩阵通常是一个二维矩阵,其中包含四个关键的指标:

预测值=正例预测值=反例
真实值=正例TPFN
真实值=反例FPTN

  1. 真正例(True Positives,TP): 模型正确地预测了正类别的样本数量。

  2. 真负例(True Negatives,TN): 模型正确地预测了负类别的样本数量。

  3. 假正例(False Positives,FP): 模型错误地将负类别的样本预测为正类别的数量。

  4. 假负例(False Negatives,FN): 模型错误地将正类别的样本预测为负类别的数量。

你可以选择看我之前写过的这一篇博客。其实很好理解,比如TP,它就是正确的预测了正确的样本,FP就是错误的预测为了正确的样本。同理,TN就是正确的预测了错误的样本,FN就是错误的预测了错误的样本。

2、从混淆矩阵得到分类指标

然后我们来看上图,这里就能够得到一些指标运算的公式:

准确率(Accuracy):它是分类正确的样本数与总样本数的比率。准确率通常用来衡量模型在整个数据集上的性能,但在不平衡类别的情况下可能不太适用。

准确率 = (TP + TN) / (TP + TN + FP + FN)

精确率(Precision):精确率是指被模型正确预测为正类别的样本数占所有预测为正类别的样本数的比率。它用于衡量模型在正类别预测中的准确性。

精确率 = TP / (TP + FP)

召回率(Recall):召回率是指被模型正确预测为正类别的样本数占所有真实正类别的样本数的比率。它用于衡量模型在识别正类别样本中的能力。

召回率 = TP / (TP + FN)

F1 分数:F1 分数是精确率和召回率的调和平均值,用于综合评估模型的性能。它对精确率和召回率都进行了考虑,特别适用于不平衡类别的情况。

F1 分数 = 2 * (精确率 * 召回率) / (精确率 + 召回率)

IoU(Intersection over Union):IoU 用于语义分割等任务,它是真实正类别区域与模型预测正类别区域的交集与并集之比。

IoU = TP / (TP + FP + FN) 

3、使用pytorch构建混淆矩阵

最初要写的目的也是为了回顾一下之前所学的,并且想要在训练过程中能写一个类方便调用。先说一下思路。

首先,这个是针对标签的,我需要一个num_classes,也就是分类数,以便我先创建一个分类数大小的矩阵。

然后在不计算梯度的情况下,我们需要筛选出合适的像素点,这里简单来说就是一行代码:

k = (t >= 0) & (t < n)
  • t是真实类别的张量,其中包含了每个像素的真实类别标签。
  • n是类别总数,表示模型可以进行分类的类别数量。

然后,通过t[k]与p[k]就可以确定正确的像素范围。将每个选定像素的真实类别标签乘以总类别数n,以获得一个在混淆矩阵中的行索引,然后再加上p[k],就是我们混淆矩阵的索引。

inds = n * t[k].to(torch.int64) + p[k]

torch.bincount是用于统计 inds 中每个索引出现的次数。minlength 参数指定了输出张量的长度,这里设置为 n**2,以确保输出张量的长度足够容纳混淆矩阵的所有元素。

以上就是我的思路,这里大家可以自己打印出来看看每个步骤是怎么实现的:

import torchnum_classes = n = 3
mat = torch.zeros((n, n), dtype=torch.int64)true_labels = t = torch.tensor([0, 1, 2, 0, 1, 2])  # 真实标签
predicted_labels = p = torch.tensor([0, 1, 1, 0, 2, 1])  # 预测结果with torch.no_grad():k = (t >= 0) & (t < n)inds = n * t[k].to(torch.int64) + p[k]print(inds)mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
print(mat)

打印出来的值:

tensor([0, 4, 7, 0, 5, 7])
tensor([[2, 0, 0],
            [0, 1, 1],
            [0, 2, 0]])

这个混淆矩阵解释如下:

  • 第一行表示真实标签为类别0的样本,模型将其分为类别0的次数为2次。
  • 第二行表示真实标签为类别1的样本,模型将其中一个分为类别1,另一个分为类别2。
  • 第三行表示真实标签为类别2的样本,模型将其都分为类别1。

然后对照我们原本设定的数据也是完全符合的。

4、使用pytorch构建分类指标

将混淆矩阵 mat 转换为浮点数张量 h ,以便进行后续计算

h = mat.float()

全局预测准确率,混淆矩阵的对角线表示的是真实和预测相对应的个数

acc_global = torch.diag(h).sum()/h.sum()

计算每个类别的准确率

acc = torch.diag(h)/h.sum(1)

 计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)

iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))

我们打印出来,大家可以对照着公式进行对比即可:

tensor(0.5000)   tensor([1.0000, 0.5000, 0.0000])   tensor([1.0000, 0.2500, 0.0000])

其他的指标以后在补充,我想的是在评估的时候使用,所以这里的指标最好还是写一个类定义。

我也将其放进了pyzjr当中,欢迎大家pip安装使用。

class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, t, p):n = self.num_classesif self.mat is None:# 创建混淆矩阵self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)with torch.no_grad():# 寻找GT中为目标的像素索引k = (t >= 0) & (t < n)# 统计像素真实类别t[k]被预测成类别p[k]的个数inds = n * t[k].to(torch.int64) + p[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):"""计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)计算每个类别的准确率计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)"""h = self.mat.float()acc_global = torch.diag(h).sum() / h.sum()acc = torch.diag(h) / h.sum(1)iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)if __name__=="__main__":num_classes = 3confusion_matrix = ConfusionMatrix(num_classes)# 模拟一些真实标签和预测结果true_labels = torch.tensor([0, 1, 2, 0, 1, 2])  # 真实标签predicted_labels = torch.tensor([0, 1, 1, 0, 2, 1])  # 预测结果# 更新混淆矩阵confusion_matrix.update(true_labels, predicted_labels)# 打印混淆矩阵及评估指标报告print("Confusion Matrix:")print(confusion_matrix.mat)print("\nEvaluation Report:")print(confusion_matrix)

打印的信息:

Confusion Matrix:
tensor([[2, 0, 0],
        [0, 1, 1],
        [0, 2, 0]])

Evaluation Report:
global correct: 50.0
average row correct: ['100.0', '50.0', '0.0']
IoU: ['100.0', '25.0', '0.0']
mean IoU: 41.7

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/79712.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity基础】2.网格材质贴图与资源打包

【Unity基础】2.网格材质贴图与资源打包 大家好&#xff0c;我是Lampard~~ 欢迎来到Unity基础系列博客&#xff0c;所学知识来自B站阿发老师~感谢 &#xff08;一&#xff09;网格材质纹理 第一次接触3D物体的话&#xff0c;会觉得好神奇啊&#xff0c;这个物体究竟是由什么组…

基于安卓Java试题库在线考试系统uniapp 微信小程序

本文首先分析了题库app应用程序的需求&#xff0c;从系统开发环境、系统目标、设计流程、功能设计等几个方面对系统进行了系统设计。开发出本题库app&#xff0c;主要实现了学生、教师、测试卷、试题、考试等。总体设计主要包括系统功能设计、该系统里充分综合应用Mysql数据库、…

Ae 效果:CC Particle Systems II

模拟/CC Particle Systems II Simulation/CC Particle Systems II CC Particle Systems II&#xff08;CC 粒子系统 II&#xff09;可用于生成和模拟各种类型的粒子系统&#xff0c;包括火焰、雨、雪、爆炸、烟雾等等。 与 CC Particle World 效果相比有许多类似的属性。最大的…

前端该了解的网络知识

网络 前端开发需要了解的网络知识 URL URL(uniform resource locator,统一资源定位符)用于定位网络服务. URL是一个固定格式的字符串 它表达了: 从网络中哪台计算机(domain)中的哪个服务(port),获取服务器上资源的路径(path),以及要用什么样的协议通信(schema). 注意: 当…

C# wpf 实现桌面放大镜

文章目录 前言一、如何实现&#xff1f;1、制作无边框窗口2、Viewbox放大3、截屏显示&#xff08;1&#xff09;、截屏&#xff08;2&#xff09;、转BitmapSource&#xff08;3&#xff09;、显示 4、定时截屏 二、完整代码三、效果预览总结 前言 做桌面截屏功能时需要放大镜…

卫星物联网生态建设全面加速,如何抓住机遇?

当前&#xff0c;卫星通信无疑是行业最热门的话题之一。近期发布的华为Mate 60 Pro“向上捅破天”技术再次升级&#xff0c;成为全球首款支持卫星通话的大众智能手机&#xff0c;支持拨打和接听卫星电话&#xff0c;还可自由编辑卫星消息。 据悉&#xff0c;华为手机的卫星通话…

【Unity每日一记】资源加载相关和检测相关

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;uni…

【计算机网络】Tcp详解

文章目录 前言Tcp协议段格式TCP的可靠性面向字节流应答机制超时重传流量控制滑动窗口&#xff08;重要&#xff09;拥塞控制延迟应答捎带应答标志位具体标志位三次握手四次挥手粘包问题TCP异常情况listen的第二个参数 前言 前面我们学习了传输层协议Udp&#xff0c;今天我们一…

使用FFmpeg+ubuntu系统转化flac无损音频为mp3

功能需求如上题,我们来具体的操作一下: 1.先在ubuntu上面安装FFmpeg:sudo apt install ffmpeg 2.进入有flac音频文件的目录使用下述命令: ffmpeg -i test.FLAC -c:a libmp3lame -q:a 2 output.mp3 3.如果没有什么意外的话,你就能看到你的文件夹里面已经有转化好的mp3文件了 批…

ubuntu中如何用docker下载华为opengauss数据库(超简单)

ubuntu中如何下载华为opengauss数据库 前言一、安装docker1.方法一&#xff1a;2.方法二 二、拉取openguass镜像三、创建容器四、连接数据库 ,切换到omm用户 &#xff0c;用gsql连接到数据库五.最后用DateGrip远程连接测试(1&#xff09;选择数据源(2&#xff09;查看虚拟机ip地…

#循循渐进学51单片机#定时器与数码管#not.4

1、熟练掌握单片机定时器的原理和应用方法。 1&#xff09;时钟周期&#xff1a;单片机时序中的最小单位&#xff0c;具体计算的方法就是时钟源分之一。 2&#xff09;机器周期&#xff1a;我们的单片机完成一个操作的最短时间。 3)定时器&#xff1a;打开定时器“储存寄存器…

Python提取JSON数据中的键值对并保存为.csv文件

本文介绍基于Python&#xff0c;读取JSON文件数据&#xff0c;并将JSON文件中指定的键值对数据转换为.csv格式文件的方法。 在之前的文章Python提取JSON文件中的指定数据并保存在CSV或Excel表格文件内&#xff08;https://blog.csdn.net/zhebushibiaoshifu/article/details/132…

Windows PostgreSql 创建多个数据库目录

1 使用默认用户Administrator 1.1初始化数据库目录 E:\Program Files\PostgreSQL\13> .\bin\initdb -D G:\DATA\pgsql\data3 -W -A md5 1.2连接数据库 这时User为Administrator&#xff0c;密码就是你刚才设置的&#xff0c;我设置的为123456&#xff0c;方便测试。 2 添加…

黑马JVM总结(九)

&#xff08;1&#xff09;StringTable_调优1 我们知道StringTable底层是一个哈希表&#xff0c;哈希表的性能是跟它的大小相关的&#xff0c;如果哈希表这个桶的个数比较多&#xff0c;元素相对分散&#xff0c;哈希碰撞的几率就会减少&#xff0c;查找的速度较快&#xff0c…

【微服务】六. Nacos配置管理

6.1 Nacos实现配置管理 配置更改热更新 在nacos左侧新建配置管理 Data ID&#xff1a;就是配置文件名称 一般命名规则&#xff1a;服务名称-环境名称.yaml 配置内容填写&#xff1a;需要热更新需求的配置 配置文件的id&#xff1a;[服务名称]-[profile].[后缀名] 分组&#…

Vuex详解:Vue.js的状态管理方案

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

启动微服务,提示驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接

说明&#xff1a;启动一些微服务后&#xff0c;一直在报下面这个错误&#xff1b; com.microsoft.sqlserver.jdbc.SQLServerException: 驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接。错误:“The server selected protocol version TLS10 is not acc…

uniapp抽取组件绑定事件中箭头函数含花括号无法解析

版本: "dcloudio/uni-ui": "^1.4.27", "vue": "> 2.6.14 < 2.7"... 箭头函数后含有花括号的时候, getData就拿不到val参数 , 解决办法就是去除花括号 // 错误代码: <SearchComp change"(val) > { getData({ val …

跨域问题解决方案(三种)

Same Origin Policy同源策略&#xff08;SOP&#xff09; 具有相同的Origin&#xff0c;也即是拥有相同的协议、主机地址以及端口。一旦这三项数据中有一项不同&#xff0c;那么该资源就将被认为是从不同的Origin得来的&#xff0c;进而不被允许访问。 Cross-origin resource…

Jsoup | Document | HTML解析器

Jsoup 一、获取 <p>标签下的所有图片 一、获取 <p>标签下的所有图片 <p> <img style"max-width: 100%;" src"http://image.svipjf.cn/1678271098160-480_01.jpg"/><img src"http://image.svipjf.cn/1678271097994-480_02…