YOLOv5 分类模型的后处理
flyfish
简化源码测试
import torch
import numpy as np
from torchvision import transforms
import torch.nn.functional as Fdata0 = np.random.random((1, 7))
data0 = np.round(data0,7)
print(data0.shape)
print(data0)
data1 = torch.from_numpy(data0)p = F.softmax(data1, dim=1) # probabilities
print("P1:",p)
i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indicesprint("result:",i)
测试代码
def soft_max(x):x = np.exp(x) / np.sum(np.exp(x), axis = 1, keepdims = True)return xp= soft_max(data0)
print("P2:",p)print("argmax:",np.argmax(p))
print("argsort:",(np.argsort(-p,1)[:, :5]).squeeze())
详细的softmax解释
np.argsort(-p,1)
返回将对数组进行排序后的索引,默认从小到大,这里-p
是 从大到小,因为是二维,1表示1轴
因为返回的维度 是二维,样子i是[[]],squeeze后变成1维
(1, 7)
[[0.2796715 0.1158704 0.8198501 0.6392486 0.4822099 0.1255404 0.4855295]]
P1: tensor([[0.1204, 0.1022, 0.2066, 0.1724, 0.1474, 0.1032, 0.1479]],dtype=torch.float64)
result: tensor([2, 3, 6, 4, 0])
P2: [[0.12036311 0.10217755 0.20658082 0.17244705 0.14738549 0.10317040.14787556]]
argmax: 2
argsort: [2 3 6 4 0]
两者相同