非极大值抑制算法(Non-Maximum Suppression,NMS)

https://tcnull.github.io/nms/
https://blog.csdn.net/weicao1990/article/details/103857298

目标检测中检测出了许多的候选框,候选框之间是有重叠的,NMS作用重叠的候选框只保留一个

算法:

  1. 将所有候选框放入到集和B
  2. 从B中选出分数S最大的bm
  3. 将bm 从集和B中移除到集和D
  4. 计算bm与B中剩余的候选框之间的IOU。
  5. 如果iou>Nt则将其从B中删除。(去除重叠比较多的候选框)
  6. 循环直至B为空。
  7. D会越来越多。
def box_iou_union_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:"""Return intersection-over-union (Jaccard index) and  of boxes.Both sets of boxes are expected to be in (x1, y1, x2, y2) format.Arguments:boxes1: set of boxes (x1, y1, x2, y2)[N, 4]boxes2: set of boxes (x1, y1, x2, y2)[M, 4]eps: optional small constant for numerical stabilityReturns:iou (Tensor[N, M]): the NxM matrix containing the pairwiseIoU values for every element in boxes1 and boxes2union (Tensor[N, M]): the nxM matrix containing the pairwise unionvalues"""area1 = box_area(boxes1)area2 = box_area(boxes2)x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps  # [N, M]union = (area1[:, None] + area2 - inter)return inter / union, union
def nms_cpu(boxes, scores, thresh):"""Performs non-maximum suppression for 3d boxes on cpuArgs:boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]scores (Tensor): score for each box [N]iou_threshold (float): threshould when boxes are discardedReturns:keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,sorted in decreasing order of scores"""ious = box_iou(boxes, boxes)_, _idx = torch.sort(scores, descending=True)keep = []while _idx.nelement() > 0:keep.append(_idx[0])# get all elements that were not matched and discard all others.non_matches = torch.where((ious[_idx[0]][_idx] <= thresh))[0]_idx = _idx[non_matches]return torch.tensor(keep).to(boxes).long()

template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {// a, b hold box coords as (y1, x1, y2, x2) with y1 < y2 etc.T bottom = max(a[0], b[0]), top = min(a[2], b[2]);T left = max(a[1], b[1]), right = min(a[3], b[3]);T width = max(right - left, (T)0), height = max(top - bottom, (T)0);T interS = width * height;T Sa = (a[2] - a[0]) * (a[3] - a[1]);T Sb = (b[2] - b[0]) * (b[3] - b[1]);return interS / (Sa + Sb - interS);
}
at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) {/* dets expected as (n_dets, dim) where dim=4 in 2D, dim=6 in 3D */AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");at::cuda::CUDAGuard device_guard(dets.device());//管理CUDA设备上下文,并指制定使用的CUDA设备bool is_3d = dets.size(1) == 6;//按照第一维(索引为0)对scores降序排序,并返回一个包含排序后索引的 Tensor,std::get<1>(...) 提取排序后索引的 Tensor。auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));//使用排序后的索引 order_t 从 dets 中选择对应的元素,返回一个新的 Tensor; dets_sorted,其中的元素按照排序后的顺序排列。auto dets_sorted = dets.index_select(0, order_t);//bbox个数int dets_num = dets.size(0);//该函数用于计算在CUDA中进行并行化计算时所需的最大列块数。其中,dets_num代表待处理的数据数量,threadsPerBlock表示每个CUDA块中的线程数量。//函数通过将dets_num除以threadsPerBlock并向上取整得到最大列块数col_blocks。这个函数常用于确定CUDA并行计算中需要启动多少个CUDA块来处理所有数据。const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);//获取block个数//该函数创建了一个名为mask的空Tensor,其大小为dets_num * col_blocks,数据类型为长整型(at::kLong)。该Tensor的属性与dets相同。at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));//创建一个空容器  Ddim3 blocks(col_blocks, col_blocks);//定义了一个二维网格,其维度为col_blocks行col_blocks列dim3 threads(threadsPerBlock);//定义了一个线程块,其维度为threadsPerBlock;cudaStream_t stream = at::cuda::getCurrentCUDAStream();//获取当前的CUDA流,用于异步计算。if (is_3d) {//std::cout << "performing NMS on 3D boxes in CUDA" << std::endl;AT_DISPATCH_FLOATING_TYPES_AND_HALF(dets_sorted.type(), "nms_kernel_cuda", [&] {nms_kernel_3d<scalar_t><<<blocks, threads, 0, stream>>>(dets_num,iou_threshold,dets_sorted.data_ptr<scalar_t>(),(unsigned long long*)mask.data_ptr<int64_t>());});}else {//该函数是PyTorch CUDA扩展中的一个宏定义,用于在CUDA代码中处理浮点类型数据(包括float、double、half)的泛型编程。//它会根据传入的输入数据类型,自动选择对应的CUDA内核函数进行计算。//这样可以避免为每种数据类型编写重复的代码,提高代码的可维护性和可扩展性。//相当与模板类AT_DISPATCH_FLOATING_TYPES_AND_HALF(dets_sorted.type(), "nms_kernel_cuda", [&] {nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(dets_num,iou_threshold,dets_sorted.data_ptr<scalar_t>(),(unsigned long long*)mask.data_ptr<int64_t>());});}//将mask_cpu的数据拷贝到主机内存中at::Tensor mask_cpu = mask.to(at::kCPU);unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();//创建一个空的remv Tensor,其大小为col_blocks,数据类型为unsigned long long。用于记录每个block中哪些线程被标记为无效。std::vector<unsigned long long> remv(col_blocks);memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);//创建一个空的keep Tensor,其大小为dets_num,数据类型为long。用于存放整个筛选后的索引。at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));int64_t* keep_out = keep.data_ptr<int64_t>();int num_to_keep = 0;for (int i = 0; i < dets_num; i++) {int nblock = i / threadsPerBlock;//第几个blockint inblock = i % threadsPerBlock;//block中第几个threads//判断当前线程是否被标记为无效if (!(remv[nblock] & (1ULL << inblock))) {keep_out[num_to_keep++] = i;unsigned long long* p = mask_host + i * col_blocks;//实际上将所有的形式设置成one hot形式。for (int j = nblock; j < col_blocks; j++) {remv[j] |= p[j];//按位或运算,并将结果保留在remv数组中}}}AT_CUDA_CHECK(cudaGetLastError());//检测CUDA 是否发生错误return order_t.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)//从张量keep中选择制定维度、起始位置和长度的子张量.to(order_t.device(), keep.scalar_type())});
}
template <typename T>
__global__ void nms_kernel(const int n_boxes, const float iou_threshold, const T* dev_boxes, unsigned long long* dev_mask) {const int row_start = blockIdx.y;const int col_start = blockIdx.x;// if (row_start > col_start) return;const int row_size = min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);//当前block行可以放多少个boxconst int col_size = min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);//获取block列可以放多少个box__shared__ T block_boxes[threadsPerBlock * 4];if (threadIdx.x < col_size) {//将当前block列的box拷贝到shared memoryblock_boxes[threadIdx.x * 4 + 0] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];block_boxes[threadIdx.x * 4 + 1] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];block_boxes[threadIdx.x * 4 + 2] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];block_boxes[threadIdx.x * 4 + 3] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];}__syncthreads();//同步if (threadIdx.x < row_size) {const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;//当前block行box的索引const T* cur_box = dev_boxes + cur_box_idx * 4;//当前block行box的指针int i = 0;unsigned long long t = 0;int start = 0;if (row_start == col_start) {start = threadIdx.x + 1;//自己的IOU就不算了}for (i = start; i < col_size; i++) {if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {//计算各自的IOUt |= 1ULL << i;//以二进制的形式表示重叠关系成立}}const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);//列block数dev_mask[cur_box_idx * col_blocks + col_start] = t;//将重叠关系写入shared memory}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu安装QT

以QT5.15.14为例 下载地址&#xff1a;Index of /archive/qt 安装步骤&#xff1a; 解压qt-everywhere-src-5.15.14运行&#xff1a; cd qt-everywhere-src-5.15.14 mkdir build cd build ../configure -prefix /opt/qt5.15.14 -opensource -confirm-license make -j16 sudo…

keep-alive页面切回原滚动位置hook方法

keep-alive页面切回原滚动位置hook方法 原理hook使用 原理 如果使用了keep-alive组件&#xff0c;当前组件会额外增加两个生命周期。 activated&#xff1a;被 keep-alive 缓存的组件激活时调用 deactivated&#xff1a;被 keep-alive 缓存的组件失活时调用。 通过这两个声明周…

VBA技术资料MF168:移动工作表为单独工作簿

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套&#xff0c;分为初级、中级、高级三大部分&#xff0c;教程是对VBA的系统讲解&#…

eventbus和vuex

EventBus和Vuex EventBus 工作原理 创建一个vue实例&#xff0c;然后通过空的vue实例作为组件之间的桥梁&#xff0c;进行通信&#xff0c;利用到的设计模式有发布订阅模式 Vuex 工作原理 维护了一个state树&#xff0c;是独立的状态树&#xff0c;有明显的层级关系。不论…

云计算-期末复习题-框架设计/选择/填空/简答(2)

目录 框架设计 1.负载分布架构 2.动态可扩展架构 3.弹性资源容量架构 4.服务负载均衡架构 5.云爆发结构 6.弹性磁盘供给结构 7.负载均衡的虚拟服务器实例架构 填空题/简答题 单选题 多选题 云计算期末复习部分练习题&#xff0c;包括最后的部分框架设计大题(只是部分…

vue3工程引用vue2模块文件时所做的修改

vue3工程引用vue2模块文件&#xff0c;如果代码没改的话&#xff0c;编译运行时可能会报错&#xff1a; SyntaxError: The requested module /fs/D:/HBuilderX.3.8.7.20230703/HBuilderX/plugins/uniapp-cli-vite/node_modules/dcloudio/uni-h5-vue/dist/vue.runtime.esm.js do…

SchedulerLock LockProvider参数配置表,列,大小写等 分布式锁 定时任务锁 学习总结

一、SchedulerLock 使用场景 如果是分布式任务&#xff0c;即多个相同的应用执行定时任务&#xff0c;那么为了防止重复执行可以使用其他分布式锁做内部判断或其他形式的锁机制来防止重复执行。 SchedulerLock 提供了现成的封装好的分布式锁机制来防止定时任务被重复执行 gi…

jQuery动画与特效

显示与隐藏动画 语法&#xff1a; $(obj).show(duration,fn);显示 $(obj).hide(duration,fn);隐藏 $(obj).toggle(); 功能&#xff1a; 1. show()方法能动态地改变当前元素的高度、宽度和不透明度&#xff0c;最终显示当前元素&#xff1b; 2. hide()方法会动态地改变当前元素…

AI办公自动化:多音频轨电影视频抽取出英语音频

很多电影视频是有中、英、粤语等多个音频轨的&#xff0c;如果直接转换成音频&#xff0c;很有可能不是自己想要的那种语音。 可以先查看音频流信息&#xff0c;确定属于哪个音频轨&#xff1a; Reading video file: E:\1-7\比得兔1.mp4 输出音频流信息 Available audio str…

利用viztracer进行性能分析和优化

上一篇文章&#xff0c;我们详细讲解了scalene这个性能分析和优化工具的使用流程&#xff1b;今天&#xff0c;我们将深入探讨另一个性能分析和优化工具——viztracer。 什么是viztracer&#xff1f; viztracer是一个非常强大的分析器&#xff0c;可以生成详细的性能报告和可…

设计师进阶指南:掌握这6条版式设计要点

布局设计是设计师的必修课。优秀的排版不是强制性的“东拼西凑”&#xff0c;而是通过设计师独特的排版获得的。这不是简单的信息列表&#xff0c;而是认真思考如何分层、有节奏地组织和安排元素。今天我将给你带来它 6 文章还附带了布局设计模板资源&#xff0c;设计师朋友一定…

《Windows API每日一练》6.1 鼠标基础知识

本节我们讲述鼠标的一些基础知识。 本节必须掌握的知识点&#xff1a; 鼠标 6.1.1 鼠标 鼠标是1964年由Douglas Engelbart发明的&#xff0c;用来取代由键盘输入的繁琐指令&#xff0c;简化电脑操作。早期的鼠标是单键鼠标&#xff0c;只有一个键&#xff0c;后来逐步改进为双…

EthernetIP IO从站设备数据 转opc ua项目案例

1 案例说明 设置网关采集EthernetIP IO设备数据把采集的数据转成opc ua协议转发给其他系统。 2 VFBOX网关工作原理 VFBOX网关是协议转换网关&#xff0c;是把一种协议转换成另外一种协议。网关可以采集西门子&#xff0c;欧姆龙&#xff0c;三菱&#xff0c;AB PLC&#xff0…

Element 页面滚动表头置顶

在开发后台管理系统时&#xff0c;表格是最常用的一个组件&#xff0c;为了看数据方便&#xff0c;时常需要固定表头。 如果页面基本只有一个表格区域&#xff0c;我们可以根据屏幕的高度动态的计算出一个值&#xff0c;给表格设定一个固定高度&#xff0c;这样表头就可以固定…

红酒达人教你秘技:选酒、存酒,一招一式皆学问

在繁忙的都市生活中&#xff0c;红酒不仅仅是一种饮品&#xff0c;更是一种生活态度&#xff0c;一种品味的象征。然而&#xff0c;面对琳琅满目的红酒品牌与种类&#xff0c;如何选择一瓶心仪的红酒&#xff0c;又如何妥善保存&#xff0c;使其保持很好口感&#xff0c;成为了…

python函数练习

1、编写函数&#xff0c;传入N&#xff0c;求123…N的和 def s_sum(num):i 1sum1 0while i < num:sum1 ii 1return sum1num int(input(请输入一个整数&#xff1a;)) print(和为:,s_sum(num))2、编写一个函数&#xff0c;定义一个列表&#xff0c;求列表中的最大值 d…

LabVIEW遇到无法控制国外设备时怎么办

当使用LabVIEW遇到无法控制国外产品的问题时&#xff0c;解决此类问题需要系统化的分析和处理方法。以下是详细的解决思路和具体办法&#xff0c;以及不同方法的分析和比较&#xff0c;包括寻求代理、国外技术支持、国内用过的人请教等内容。 1. 了解产品的通信接口和协议 思路…

LeetCode.25K个一组翻转链表详解

问题描述 给你链表的头节点 head &#xff0c;每 k 个节点一组进行翻转&#xff0c;请你返回修改后的链表。 k 是一个正整数&#xff0c;它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍&#xff0c;那么请将最后剩余的节点保持原有顺序。 你不能只是单纯的改变节…

Python 基础 (标准库):collections (集合类)

1. 官方文档 collections --- 容器数据类型 — Python 3.12.4 文档 Python 的 collections 模块提供了许多有用的数据类型&#xff08;包括 OrderedDict、Counter、defaultdict、deque 和 namedtuple&#xff09;用于扩展 Python 的标准数据类型。掌握 collections 中的数据类…

五子棋纯python手写,需要的拿去

import pygame,sys from pygame import * pygame.init()game pygame.display.set_mode((600,600)) gameover False circlebox [] # 棋盘坐标点存储 box [] def xy():for x in range(0,800//40): for y in range(0,800//40): box.append((x*40,y*40)) xy() defaultColor wh…