PyTorch中的topk方法以及分类Top-K准确率的实现

PyTorch中的topk方法以及分类Top-K准确率的实现

Top-K 准确率

在分类任务中的类别数很多时(如ImageNet中1000类),通常任务是比较困难的,有时模型虽然不能准确地将ground truth作为最高概率预测出来,但通过学习,至少groud truth的准确率能够在所有类中处于很靠前的位置,这在现实生活中也是有一定应用意义的,因此除了常规的Top-1 Acc,放宽要求的Tok-K Acc也是某些分类任务的重要指标之一。

Tok-K准确率:即指在模型的预测结果中,前K个最高概率的类中有groud truth,就认为在Tok-K准确率的要求下,模型分类成功了。

PyTorch中的topk方法

PyTorch中并没有直接提供计算模型Top-K分类准确率的接口,但是提供了一个topk方法,用来获得某tensor某维度中最高或最低的K个值。

函数接口

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

同样有tensor.topk的使用方式,参数及返回值类似。

参数说明

input:输入张量

dim:指定在哪个维度取topk

k:前k大或前k小值

largest:取最大(True)或最小(False)

sorted:返回值是否有序

返回值说明

返回两个张量:values和indices,分别对应前k大/小值的数值和索引,注意返回值的各维度的意义,不要搞反了,后面实验会说。

实验

我们在这里模拟常见的分类任务的情况,设置batch size为4,类别数为10,这样模型输出应为形状为(10,4)的张量。

output = torch.rand(4, 10)
print(output)
print('*'*100)
values, indices = torch.topk(output, k=2, dim=1, largest=True, sorted=True)
print("values: ", values)
print("indices: ", indices)
print('*'*100)
print(output.topk(k=2, dim=1, largest=True, sorted=False))		# tensor.topk的用法

输出:

tensor([[0.7082, 0.5335, 0.9494, 0.7792, 0.3288, 0.6303, 0.0335, 0.6918, 0.0778,0.6404],[0.3881, 0.8676, 0.7700, 0.6266, 0.8843, 0.8902, 0.4336, 0.5385, 0.8372,0.1204],[0.9717, 0.2727, 0.9086, 0.7797, 0.1216, 0.4793, 0.1149, 0.1544, 0.7292,0.0459],[0.0424, 0.0809, 0.1597, 0.4177, 0.4798, 0.7107, 0.9683, 0.7502, 0.1536,0.3994]])
****************************************************************************************************
values:  tensor([[0.9494, 0.7792],[0.8902, 0.8843],[0.9717, 0.9086],[0.9683, 0.7502]])
indices:  tensor([[2, 3],[5, 4],[0, 2],[6, 7]])
****************************************************************************************************
torch.return_types.topk(
values=tensor([[0.9494, 0.7792],[0.8902, 0.8843],[0.9717, 0.9086],[0.9683, 0.7502]]),
indices=tensor([[2, 3],[5, 4],[0, 2],[6, 7]]))

注意输出的行是用户指定的dim的k个最大/小值(实验中sorted=True,所以是有序返回的),列是其他未指定的维度,不要搞反了。

分类Top-K准确率的实现

实现

借助刚刚介绍的PyTorch中的topk方法实现的分类任务的Top-K准确率计算方法。

def accuracy(output, target, topk=(1, )):       # output.shape (bs, num_classes), target.shape (bs, )"""Computes the accuracy over the k top predictions for the specified values of k"""with torch.no_grad():maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))res = []for k in topk:correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)res.append(correct_k.mul_(100.0 / batch_size))return res

实验

我们同样拿上面的分类任务做实验,batch size为4,类别数为10,给定label:2,1,8,5,为了方便观察,计算Top-1,2准确率(ImageNet-1K中通常计算Top-1,5准确率)。

测试代码:

output = torch.rand(4, 10)
label = torch.Tensor([2, 1, 8, 5]).unsqueeze(dim=1)
print(output)
print('*'*100)
values, indices = torch.topk(output, k=2, dim=1, largest=True, sorted=True)
print("values: ", values)
print("indices: ", indices)
print('*'*100)print(accuracy(output, label, topk=(1, 2)))

输出:

tensor([[0.8721, 0.7391, 0.1365, 0.3017, 0.2840, 0.2400, 0.6473, 0.3965, 0.5449,0.7518],[0.7120, 0.8533, 0.2809, 0.9515, 0.2971, 0.8182, 0.5498, 0.0797, 0.8027,0.6916],[0.4540, 0.8468, 0.9022, 0.5144, 0.2007, 0.7292, 0.5559, 0.0290, 0.6664,0.2076],[0.1793, 0.0205, 0.7322, 0.4918, 0.6194, 0.9179, 0.1639, 0.6346, 0.8829,0.3573]])
****************************************************************************************************
values:  tensor([[0.8721, 0.7518],[0.9515, 0.8533],[0.9022, 0.8468],[0.9179, 0.8829]])
indices:  tensor([[0, 9],[3, 1],[2, 1],[5, 8]])
[tensor([25.]), tensor([50.])]

可以看到在top1准确率时只有最后一个样本与标签对应,故Top-1准确率为1 / 4 =25%,而在top2准确率时样本2,4预测成功了,Top-2准确率为50%,符合我们的预期。

有疑惑或异议欢迎留言讨论。

Ref:https://pytorch.org/docs/master/generated/torch.topk.html#torch-topk

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532845.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java高级语言特性,Java高级语言特性之注解

注解的定义Java 注解(Annotation)又称 Java 标注,是 JDK1.5 引入的一种注释机制。注解是元数据的一种形式,提供有关于程序但不属于程序本身的数据。注解对它们注解的代码的操作没有直接影响。注解本身没有任何意义,单独的注解就是一种注释&am…

C/C++中的typedef 和 #define

C/C中的typedef 和 #define typedef C/C中的关键字typedef允许用户为类型名来起一个新名字,通常会是缩写或者能够清晰表明类型含义的新名字。 例: typedef unsigned int UINT; UINT 100;值得注意的是,typedef除了为C/C内置的数据类型取别…

php3.2.3 升级,thinkphp3.2.3 升级到3.2.4时出错问题

有些项目最初用OneThink做的,而OneThink 默认使用的TP 是3.2.0 的,没事的时候就想给升级一下,但是直接复制进去的时候,有错误,导致OneThink 不能运行,排查后,需要修改两个地方1、修改 Applicati…

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。对其做各种改进的顶会论文也是层出不穷,本文将聚焦于各种最新的视觉trans…

mysql 分析查询语句,MySQL教程之SQL语句分析查询优化

怎么获取有功能问题的SQL1、经过用户反应获取存在功能问题的SQL2、经过慢查询日志获取功能问题的SQL3、实时获取存在功能问题的SQL运用慢查询日志获取有功能问题的SQL首要介绍下慢查询相关的参数1、slow_query_log 发动定制记载慢查询日志设置的办法,能够经过MySQL指…

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 f…

geoda权重矩阵导入matlab,空间计量经济学-分析解析.ppt

厦门大学 邓明 空间截面回归模型 地理加权回归模型 地理加权回归模型扩展了普通线性回归模型。在GWR模型中,特定区位的回归系数不再是利用全部信息获得的假定常数,而是利用邻近观测值的子样本数据信息进行局域(Local)回归估计而得,并随着空间…

树莓派摄像头基础配置及测试

树莓派摄像头基础配置 step 1 硬件连接 硬件连接,注意不要接反了,排线蓝色一段朝向网口的方向。(笔者的设备是树莓派4B) step 2 安装raspi-config 安装 raspi-config raspi-config在raspbian中是预装的,而在kali、…

matlab sobel锐化,sobel锐化 - yirui wu.ppt

sobel锐化 - yirui wu第六章 图像锐化 图像锐化的概念 图像锐化的目的是加强图像中景物的细节边缘和轮廓。 锐化的作用是使灰度反差增强。 因为边缘和轮廓都位于灰度突变的地方。所以锐化算法的实现是基于微分作用。 图像锐化方法 图像的景物细节特征; 一阶微分锐化…

使用百度云智能SDK和树莓派搭建简易的人脸识别系统 Python语言版

硬件 树莓派4B一个CSI摄像头一个 笔者使用的是树莓派4B和CSI摄像头,但是树莓派3和USB摄像头等相似设备均可。 百度云智能设置 Step 1 登录 百度云智能 网址https://cloud.baidu.com/ 首先登录百度账号,与百度云、百度贴吧等互通,可直接…

php 5.6 引用传递,升级到5.6.x后如何在php中修复引用传递

我最近将fom php 5.2升级到5.6,并且有一些代码我无法修复://Finds users with the same ip- or email-addressfunction find_related_users($user_id) {global $pdo;//print_R($pdo);//Let SQL do the magic!$sth $pdo->prepare(CALL find_related_users(?));$…

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip arc

RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip archive: failed finding central directory 原因分析 这个报错是出现在PyTorch在读入模型参数时: checkpoint torch.load(epoch_15.pth, map_locationcpu)…

xp搭建 php环境,windows xp 下 LAMP环境搭建

1. apache安装步骤如下图在浏览器中输入:localhost,出现下面页面说明已成功安装apache。2. mysql安装如下图显示在运行里面输入cmd ,然后连接测试mysql ,如图所示:3. php安装(1)将php压缩包解压到安装路径中的php目录…

C++中的虚函数(表)实现机制以及用C语言对其进行的模拟实现

C中的虚函数(表)实现机制以及用C语言对其进行的模拟实现 声明:本文非博主原创,转自https://blog.twofei.com/496/,博主读后受益良多,特地转载,一是希望好文能有更多人看到,二是为了日后自己查阅。 前言 …

php 前端模板 yii,php – Yii2高级模板:添加独立网页

我在backend / views / site下添加了help.php,并在SiteController.php下声明了一个能够识别链接的函数public function behaviors(){return [access > [class > AccessControl::className(),rules > [[actions > [login, error],allow > true,],[actions > […

C++中数组和指针的关系(区别)详解

C中数组和指针的关系(区别)详解 本文转自:http://c.biancheng.net/view/1472.html 博主在阅读后将文中几个知识点提出来放在前面: 没有方括号和下标的数组名称实际上代表数组的起始地址,这意味着数组名称实际上就是…

安装php独立环境,0507-php独立环境的安装与配置 Web程序 - 贪吃蛇学院-专业IT技术平台...

1.在一个纯英文目录下新建三个文件夹2.安装apache(选择好版本)过程中该填的按格式填好,其余的只更改安装目录即可如果报错1901是安装版本的问题。检查:安装完成后localhost打开为It works!添加到电脑属性环境变量:3.将php文件解压文档放到AMP…

linux中PATH变量-详细介绍

转自:https://blog.csdn.net/haozhepeng/article/details/100584451 转载者勘误 原文最后提到的 echo 命令对于环境变量的修改无影响。这是肯定的,echo 命令相当于只是一个打印的函数(比如 Python 中的 print)。这里要修改环境变…

php assert eval,代码执行函数之一句话木马

前言大家好,我是阿里斯,一名IT行业小白。非常抱歉,昨天的内容出现瑕疵比较多,今天重新整理后再次发出,修改并添加了细节,另增加了常见的命令执行函数如果哪里不足,还请各位表哥指出。eval和asse…

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理

显卡、显卡驱动、CUDA、CUDA Toolkit、cuDNN 梳理 转自:https://www.cnblogs.com/marsggbo/p/11838823.html#nvccnvidia-smi GPU型号含义 显卡: 简单理解这个就是我们前面说的GPU,尤其指NVIDIA公司生产的GPU系列,因为后面介绍的…