yolo极大抑制_pytorch实现yolov3(4) 非极大值抑制nms

在上一篇里我们实现了forward函数.得到了prediction.此时预测出了特别多的box以及各种class probability,现在我们要从中过滤出我们最终的预测box.

理解了yolov3的输出的格式及每一个位置的含义,并不难理解源码.我在阅读源码的过程中主要的困难在于对pytorch不熟悉,所以在这篇文章里,关于其中涉及的一些pytorch中的函数的用法我都已经用加粗标示了并且给出了相应的链接,测试代码等.

obj score threshold

我们设置一个obj score thershold,超过这个值的才认为是有效的.

conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)

prediction = prediction*conf_mask

prediction是1*boxnum*boxattr

prediction[:,:,4]是1*boxnum 元素值为boxattr的index=4的那个值.

torch中的Tensor index和numpy是类似的,参看下列代码输出

import torch

x = torch.Tensor(1,3,10) # Create an un-initialized Tensor of size 2x3

print(x)

print(x.shape) # Print out the Tensor

y = x[:,:,4]

print(y)

print(y.shape)

z = x[:,:,4:6]

print(z)

print(z.shape)

print((y>0.5).float().unsqueeze(2))

#### 输出如下

tensor([[[2.5226e-18, 1.6898e-04, 1.0413e-11, 7.7198e-10, 1.0549e-08,

4.0516e-11, 1.0681e-05, 2.9575e-18, 6.7333e+22, 1.7591e+22],

[1.7184e+25, 4.3222e+27, 6.1972e-04, 7.2443e+22, 1.7728e+28,

7.0367e+22, 5.9018e-10, 2.6540e-09, 1.2972e-11, 5.3370e-08],

[2.7001e-06, 2.6801e-09, 4.1292e-05, 2.1511e+23, 3.2770e-09,

2.5125e-18, 7.7052e+31, 1.9447e+31, 5.0207e+28, 1.1492e-38]]])

torch.Size([1, 3, 10])

tensor([[1.0549e-08, 1.7728e+28, 3.2770e-09]])

torch.Size([1, 3])

tensor([[[1.0549e-08, 4.0516e-11],

[1.7728e+28, 7.0367e+22],

[3.2770e-09, 2.5125e-18]]])

torch.Size([1, 3, 2])

tensor([[[0.],

[0.],

[0.]]])

Squeeze and unsqueeze 降低维度,升高维度.

t = torch.ones(2,1,2,1) # Size 2x1x2x1

r = torch.squeeze(t) # Size 2x2

r = torch.squeeze(t, 1) # Squeeze dimension 1: Size 2x2x1

# Un-squeeze a dimension

x = torch.Tensor([1, 2, 3])

r = torch.unsqueeze(x, 0) # Size: 1x3 表示在第0个维度添加1维

r = torch.unsqueeze(x, 1) # Size: 3x1 表示在第1个维度添加1维

这样prediction中objscore

nms

#得到box坐标(top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)

box_corner = prediction.new(prediction.shape)

box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)

box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)

box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)

box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)

prediction[:,:,:4] = box_corner[:,:,:4]

原始的prediction中boxattr存放的是x,y,w,h,...,不方便我们处理,我们将其转换成(top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)

接下来我们挨个处理每一张图片对应的feature map.

batch_size = prediction.size(0)

write = False

for ind in range(batch_size):

#image_pred.shape=boxnum\*boxattr

image_pred = prediction[ind] #image Tensor box_num*box_attr

#confidence threshholding

#NMS

#返回每一行的最大值,及最大值所在的列.

max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1)

#升级成和image_pred同样的维度

max_conf = max_conf.float().unsqueeze(1)

max_conf_score = max_conf_score.float().unsqueeze(1)

seq = (image_pred[:,:5], max_conf, max_conf_score)

#沿着列的方向拼接. 现在image_pred变成boxnum\*7

image_pred = torch.cat(seq, 1)

这里涉及到torch.max的用法,参见https://blog.csdn.net/Z_lbj/article/details/79766690

torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

按维度dim 返回最大值.可以这么记忆,沿着第dim维度比较.torch.max(0)即沿着行的方向比较,即得到每列的最大值.

假设input是二维矩阵,即行*列,行是第0维,列是第一维.

torch.max(a,0) 返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)

torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)

c=torch.Tensor([[1,2,3],[6,5,4]])

print(c)

a,b=torch.max(c,1)

print(a)

print(b)

##输出如下:

tensor([[1., 2., 3.],

[6., 5., 4.]])

tensor([3., 6.])

tensor([2, 0])

torch.cat(tensors, dim=0, out=None) → Tensor

>>> x = torch.randn(2, 3)

>>> x

tensor([[ 0.6580, -1.0969, -0.4614],

[-0.1034, -0.5790, 0.1497]])

>>> torch.cat((x, x, x), 0)

tensor([[ 0.6580, -1.0969, -0.4614],

[-0.1034, -0.5790, 0.1497],

[ 0.6580, -1.0969, -0.4614],

[-0.1034, -0.5790, 0.1497],

[ 0.6580, -1.0969, -0.4614],

[-0.1034, -0.5790, 0.1497]])

>>> torch.cat((x, x, x), 1)

tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,

-1.0969, -0.4614],

[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,

-0.5790, 0.1497]])

接下来我们只处理obj_score非0的数据(obj_score

non_zero_ind = (torch.nonzero(image_pred[:,4]))

try:

image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)

except:

continue

#For PyTorch 0.4 compatibility

#Since the above code with not raise exception for no detection

#as scalars are supported in PyTorch 0.4

if image_pred_.shape[0] == 0:

continue

ok,接下来我们对每一种class做nms.

首先取到我们有哪些类别

#Get the various classes detected in the image

img_classes = unique(image_pred_[:,-1]) # -1 index holds the class index

然后依次对每一种类别做处理

for cls in img_classes:

#perform NMS

#get the detections with one particular class

#取出当前class为当前class且class prob!=0的行

cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)

class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()

image_pred_class = image_pred_[class_mask_ind].view(-1,7)

#sort the detections such that the entry with the maximum objectness

#confidence is at the top

#按照obj score从高到低做排序

conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1]

image_pred_class = image_pred_class[conf_sort_index]

idx = image_pred_class.size(0) #Number of detections

for i in range(idx):

#Get the IOUs of all boxes that come after the one we are looking at

#in the loop

try:

#计算第i个和其后每一行的的iou

ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])

except ValueError:

break

except IndexError:

break

#Zero out all the detections that have IoU > treshhold

#把与第i行iou>nms_conf的认为是同一个目标的box,将其转成0

iou_mask = (ious < nms_conf).float().unsqueeze(1)

image_pred_class[i+1:] *= iou_mask

#把iou>nms_conf的移除掉

non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()

image_pred_class = image_pred_class[non_zero_ind].view(-1,7)

batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) #Repeat the batch_id for as many detections of the class cls in the image

seq = batch_ind, image_pred_class

其中计算iou的代码如下,不多解释了.iou=交叠面积/总面积

def bbox_iou(box1, box2):

"""

Returns the IoU of two bounding boxes

"""

#Get the coordinates of bounding boxes

b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]

b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]

#get the corrdinates of the intersection rectangle

inter_rect_x1 = torch.max(b1_x1, b2_x1)

inter_rect_y1 = torch.max(b1_y1, b2_y1)

inter_rect_x2 = torch.min(b1_x2, b2_x2)

inter_rect_y2 = torch.min(b1_y2, b2_y2)

#Intersection area

inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)

#Union Area

b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)

b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)

iou = inter_area / (b1_area + b2_area - inter_area)

return iou

tensor index操作用法如下:

image_pred_ = torch.Tensor([[1,2,3,4,9],[5,6,7,8,9]])

#print(image_pred_[:,-1] == 9)

has_9 = (image_pred_[:,-1] == 9)

print(has_9)

###执行顺序是(image_pred_[:,-1] == 9).float().unsqueeze(1) 再做tensor乘法

cls_mask = image_pred_*(image_pred_[:,-1] == 9).float().unsqueeze(1)

print(cls_mask)

class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()

image_pred_class = image_pred_[class_mask_ind]

输出如下:

tensor([1, 1], dtype=torch.uint8)

tensor([[1., 2., 3., 4., 9.],

[5., 6., 7., 8., 9.]])

torch.sort用法如下:

d=torch.Tensor([[1,2,3],[6,5,4]])

e=d[:,2]

print(e)

print(torch.sort(e))

输出

tensor([3., 4.])

torch.return_types.sort(

values=tensor([3., 4.]),

indices=tensor([0, 1]))

总结一下我们做nms的流程

每一个image,会预测出N个detetction信息,包括4+1+C(4个坐标信息,1个obj score以及C个class probability)

首先过滤掉obj_score < confidence的行

每一行只取class probability最高的作为预测出来的类别

将所有的预测按照obj_score从大到小排序

循环每一种类别,开始做nms

比较第一个box与其后所有box的iou,删除iou>threshold的box,即剔除所有相似box

比较下一个box与其后所有box的iou,删除所有与该box相似的box

不断重复上述过程,直至不再有相似box

至此,实现了当前处理的类别的多个box均是独一无二的box.

write_results最终的返回值是一个n*8的tensor,其中8是(batch_index,4个坐标,1个objscore,1个class prob,一个class index)

def write_results(prediction, confidence, num_classes, nms_conf = 0.4):

print("prediction.shape=",prediction.shape)

#将obj_score < confidence的行置为0

conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2)

prediction = prediction*conf_mask

#得到box坐标(top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y)

box_corner = prediction.new(prediction.shape)

box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)

box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)

box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)

box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)

#修改prediction第三个维度的前四列

prediction[:,:,:4] = box_corner[:,:,:4]

batch_size = prediction.size(0)

write = False

for ind in range(batch_size):

#image_pred.shape=boxnum\*boxattr

image_pred = prediction[ind] #image Tensor

#confidence threshholding

#NMS

##取出每一行的class score最大的一个

max_conf_score,max_conf = torch.max(image_pred[:,5:5+ num_classes], 1)

max_conf = max_conf.float().unsqueeze(1)

max_conf_score = max_conf_score.float().unsqueeze(1)

seq = (image_pred[:,:5], max_conf_score, max_conf)

image_pred = torch.cat(seq, 1) #现在变成7列,分别为左上角x,左上角y,右下角x,右下角y,obj score,最大probabilty,相应的class index

print(image_pred.shape)

non_zero_ind = (torch.nonzero(image_pred[:,4]))

try:

image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7)

except:

continue

#For PyTorch 0.4 compatibility

#Since the above code with not raise exception for no detection

#as scalars are supported in PyTorch 0.4

if image_pred_.shape[0] == 0:

continue

#Get the various classes detected in the image

img_classes = unique(image_pred_[:,-1]) # -1 index holds the class index

for cls in img_classes:

#perform NMS

#get the detections with one particular class

#取出当前class为当前class且class prob!=0的行

cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1)

class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze()

image_pred_class = image_pred_[class_mask_ind].view(-1,7)

#sort the detections such that the entry with the maximum objectness

#confidence is at the top

#按照obj score从高到低做排序

conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1]

image_pred_class = image_pred_class[conf_sort_index]

idx = image_pred_class.size(0) #Number of detections

for i in range(idx):

#Get the IOUs of all boxes that come after the one we are looking at

#in the loop

try:

#计算第i个和其后每一行的的iou

ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:])

except ValueError:

break

except IndexError:

break

#Zero out all the detections that have IoU > treshhold

#把与第i行iou>nms_conf的认为是同一个目标的box,将其转成0

iou_mask = (ious < nms_conf).float().unsqueeze(1)

image_pred_class[i+1:] *= iou_mask

#把iou>nms_conf的移除掉

non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze()

image_pred_class = image_pred_class[non_zero_ind].view(-1,7)

batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) #Repeat the batch_id for as many detections of the class cls in the image

seq = batch_ind, image_pred_class

if not write:

output = torch.cat(seq,1) #沿着列方向,shape 1*8

write = True

else:

out = torch.cat(seq,1)

output = torch.cat((output,out)) #沿着行方向 shape n*8

try:

return output

except:

return 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/504951.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

抽象类可以生成自己的对象吗_大理石可以自己抛光吗?大理石自己抛光方法解答...

大理石可以自己抛光吗&#xff1f;大理石自己抛光方法解答。大理石只有经过抛光后&#xff0c;表面才会有更好的光泽和质感。具体抛光可以要求大理石抛光团队施工&#xff0c;也可以自己做。那么大理石自己抛光方法是什么呢&#xff1f;下面石大夫为您解答。大理石自己抛光方法…

ajax 参数带百分号,Ajax请求中带有IPv6地址后的百分号的问题

IPv6地址后的百分号:对于连入网络但没有IPv6路由器或DHCPv6服务器的IPv6客户端&#xff0c;它们始终使用fe80::/64链路本地网络地址。如果运行Windows的计算机中有多个网络适配器连接到不同的网段&#xff0c;可以在IP地址后加百分号和区域ID数字来区分不同的网络&#xff0c;如…

高中数学40分怎么办_高二数学不会,准高三该怎么办?40分到高考140如何逆袭?...

原标题&#xff1a;高二数学不会&#xff0c;准高三该怎么办&#xff1f;40分到高考140如何逆袭&#xff1f;高二&#xff0c;这个年级是有点尴尬的&#xff0c;适应了高一的学习&#xff0c;感觉高二学习没有了动力&#xff0c;离高考还远&#xff0c;于是有些孩子就开始了放任…

python识别人脸多种属性_深度学习人脸识别仅9行python代码实现?同时高效处理100张相片?...

随着人脸识别、视频结构化等计算视觉相关技术在安防、自动驾驶、手机等领域走向商业化应用阶段&#xff0c;计算视觉技术行业市场迎来大规模的爆发。伴随人脸识别、物体识别等分类、分割算法不断提升精度。计算视觉的核心算法深度学习算法日渐成熟&#xff0c;通过对输出与对应…

服务器可以装2个系统吗,云服务器可以装多个系统吗

云服务器可以装多个系统吗 内容精选换一换共享云硬盘是一种支持多个云服务器并发读写访问的数据块级存储设备&#xff0c;具备多挂载点、高并发性、高性能、高可靠性等特点。主要应用于需要支持集群、HA(High Available&#xff0c;指高可用集群)能力的关键企业应用场景&#x…

关卡 动画 蓝图 运行_UE4无缝过场动画

最近有个哥们给我看他们最近在做的一个游戏&#xff0c;其中有这样一段镜头https://www.zhihu.com/video/1171378736917364736运用到了一个很常用的过场方式&#xff0c;就是平时我们所说的无缝过场。过场动画不通过黑屏转换&#xff0c;而是通过运镜来代入。这是一种比较容易实…

python条件循环叠加_Python基础:条件判断与循环的两个要点

一、条件判断&#xff1a;Python中&#xff0c;条件判断用if语句实现&#xff0c;多个条件判断时用if...elif实现&#xff1a;看下面一段程序#python 3.3.5#test if...elifage 20if age > 6:print (teenager)elif age > 18:print (adult)else:print (kid)程序输出结果&a…

H3C批量收集服务器信息,H3C设备服务器采集参数认证过程(包含redfish和restfull协议)...

该脚本针对H3C服务器分别对redfish和restfull两种协议的认证方式进行测试&#xff0c;并合并。有三个类&#xff0c;分别是redfish协议测试、restfull协议测试、以及两个合并测试文章最后使用redfish模块简单进行认证访问测试。import requestsimport jsonrequests.packages.ur…

个推的appid是指什么_推箱子软件介绍→安卓下最专业的推箱子软件(推箱快手)...

俗语说&#xff1a;工欲善其事必先利其器目前各安卓系统下的应用市场有很多很多推箱子软件&#xff0c;除了soko推箱子软件比较好以外&#xff0c;其余没有任何一款软件是推箱子好手想去使用的&#xff0c;为什么呢&#xff1f;先说说soko这款软件好在哪儿&#xff1f;点推式推…

list转字符串_剑指offer 38——字符串的排列

本题主要在于对回溯的理解&#xff0c;优化时可以结合 java 特性&#xff0c;以及排列的一些知识。原题输入一个字符串&#xff0c;打印出该字符串中字符的所有排列。你可以以任意顺序返回这个字符串数组&#xff0c;但里面不能有重复元素。示例:输入&#xff1a;s "abc&…

v5系列服务器后面板不存在以下哪款指示,群晖RS10613xs+ NAS服务器后面板简介

群晖RS10613xs NAS服务器后面板简介群晖RS10613xs NAS服务器后面板简介:NAS服务器的后面板往往承担着数据的输入、输出&#xff0c;电影的输入&#xff0c;网络的传输&#xff0c;容量的扩展&#xff0c;电能的支持以及产品的散热等重要功能&#xff0c;看似简单的后面板往往是…

怎么判断冠词用a还是an_【语法微课堂】英语冠词的用法,学会这4点,轻松玩转a、an、the...

点击上方??蓝色字&#xff0c;轻松关注&#xff01;Well begun is half done.良好的开端是成功的一半。准备了一下午&#xff0c;终于可以给大家更新了&#xff0c;给大家分享了冠词讲解的视频、音频和文字版&#xff0c;自行取用吧&#xff01;?冠词讲解视频版(小提示&…

pb自定义控件 事件_Android WebView与下拉刷新控件滑动冲突的解决方法

使用WebView时一般会在外层使用下拉刷新控件如(SwipeRefreshLayout)。但是测试时会发现网页无法上拉&#xff0c;往上滑动就会触发下拉刷新控件的refresh事件。所以这里记录一下解决该问题的办法。1、通过webView.getScrollY() 的值来判断是否滚动到顶部private SwipeRefreshLa…

双路服务器cpu必须型号相同,双路主板存在使用不同型号的cpu之说吗?还是必须使用一模一样相同的cpu型号?...

双路主板不存在使用不同型号的cpu一说&#xff0c; 可以使用不同型号的cpu&#xff0c; 不过参数差别不能过大(例如处理器的架构差别)多路主板就是一种主从结构&#xff0c; 处理器之间是协同工作&#xff0c;由中间的高速总线实现两个处理器的配合&#xff0c;不存在处理器必须…

打开多个界面_如何创建用户界面

CANBusKit&#xff0c;是一款集成汽车总线开发、测试、分析的专业软件工具。本章内容主要介绍如何使用CBK_OpenPanel工具为CANBUSKIT 工程创建用户界面&#xff0c;本工具目前支持Vector的xvp格式的面板文件导入。首先是启动软件(试用版软件只能从CANBusKit软件界面中启动该软件…

python命令行解析_python命令行解析函数

sys.argv在终端运行python 1.py hahahimportsysprint(sys.argv) #[1.py, hahah]argparsePython的命令行解析模块&#xff0c;这是一个python的内置库&#xff0c;通过在程序中我们定义好的参数&#xff0c;argparse将会从sys.argv中解析出这些参数&#xff0c;并自动生成帮助和…

汤姆克兰西全境封锁服务器维护时间,汤姆克兰西全境封锁无法登录怎么解决 无法登录解决方法攻略...

《汤姆克兰西&#xff1a;全境封锁》是款大型射击游戏&#xff0c;这款游戏的画面十分的精致&#xff0c;在这款游戏中会有各种不同的任务&#xff0c;玩家要带着武器来进行射击。在游戏的时候很多玩家们都反映无法登录怎么解决&#xff1f;那么下面小编就为玩家们详细解说下关…

需要的依赖_三十而已:夫妻关系中需要的是坦诚和依赖

最近三十而已大热播&#xff0c;开始时很多人都看好顾佳和许幻山这一对&#xff0c;顾佳有才有颜&#xff0c;上得厅堂下得厨房&#xff0c;处理事情干净利索&#xff0c;是难得的贤内助。许幻山温柔帅气还有才&#xff0c;关键是还对老婆好&#xff0c;他们的组合可以说是很让…

python代码导出_代码生成 – Python生成Python

我有一组对象,我正在创建一个类,我想要将每个对象存储为自己的文本文件.我真的希望将其存储为一个Python类定义,它会分类我正在创建的主类.所以,我做了一些戳,并在effbot.org上找到了一个Python代码生成器.我做了一些实验,这里是我想出来的&#xff1a;## a Python code genera…

语言建立一个学生籍贯管理簿_编写一个Excel自定义函数,身份证信息提取如探囊取物...

观看视频更直观我们建立信息表时不仅要输入性别、生日和年龄等信息&#xff0c;往往也需要输入身份证号码&#xff0c;而身份证号码中包含有籍贯、性别、生日和年龄等信息&#xff0c;从身份证号码中提取上述信息可以减少输入工作量&#xff0c;提高工作效率。利用Excel中的内置…