非极大值抑制算法(Non-Maximum Suppression,NMS)

https://tcnull.github.io/nms/
https://blog.csdn.net/weicao1990/article/details/103857298

目标检测中检测出了许多的候选框,候选框之间是有重叠的,NMS作用重叠的候选框只保留一个

算法:

  1. 将所有候选框放入到集和B
  2. 从B中选出分数S最大的bm
  3. 将bm 从集和B中移除到集和D
  4. 计算bm与B中剩余的候选框之间的IOU。
  5. 如果iou>Nt则将其从B中删除。(去除重叠比较多的候选框)
  6. 循环直至B为空。
  7. D会越来越多。
def box_iou_union_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:"""Return intersection-over-union (Jaccard index) and  of boxes.Both sets of boxes are expected to be in (x1, y1, x2, y2) format.Arguments:boxes1: set of boxes (x1, y1, x2, y2)[N, 4]boxes2: set of boxes (x1, y1, x2, y2)[M, 4]eps: optional small constant for numerical stabilityReturns:iou (Tensor[N, M]): the NxM matrix containing the pairwiseIoU values for every element in boxes1 and boxes2union (Tensor[N, M]): the nxM matrix containing the pairwise unionvalues"""area1 = box_area(boxes1)area2 = box_area(boxes2)x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps  # [N, M]union = (area1[:, None] + area2 - inter)return inter / union, union
def nms_cpu(boxes, scores, thresh):"""Performs non-maximum suppression for 3d boxes on cpuArgs:boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]scores (Tensor): score for each box [N]iou_threshold (float): threshould when boxes are discardedReturns:keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,sorted in decreasing order of scores"""ious = box_iou(boxes, boxes)_, _idx = torch.sort(scores, descending=True)keep = []while _idx.nelement() > 0:keep.append(_idx[0])# get all elements that were not matched and discard all others.non_matches = torch.where((ious[_idx[0]][_idx] <= thresh))[0]_idx = _idx[non_matches]return torch.tensor(keep).to(boxes).long()

template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {// a, b hold box coords as (y1, x1, y2, x2) with y1 < y2 etc.T bottom = max(a[0], b[0]), top = min(a[2], b[2]);T left = max(a[1], b[1]), right = min(a[3], b[3]);T width = max(right - left, (T)0), height = max(top - bottom, (T)0);T interS = width * height;T Sa = (a[2] - a[0]) * (a[3] - a[1]);T Sb = (b[2] - b[0]) * (b[3] - b[1]);return interS / (Sa + Sb - interS);
}
at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) {/* dets expected as (n_dets, dim) where dim=4 in 2D, dim=6 in 3D */AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");at::cuda::CUDAGuard device_guard(dets.device());//管理CUDA设备上下文,并指制定使用的CUDA设备bool is_3d = dets.size(1) == 6;//按照第一维(索引为0)对scores降序排序,并返回一个包含排序后索引的 Tensor,std::get<1>(...) 提取排序后索引的 Tensor。auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));//使用排序后的索引 order_t 从 dets 中选择对应的元素,返回一个新的 Tensor; dets_sorted,其中的元素按照排序后的顺序排列。auto dets_sorted = dets.index_select(0, order_t);//bbox个数int dets_num = dets.size(0);//该函数用于计算在CUDA中进行并行化计算时所需的最大列块数。其中,dets_num代表待处理的数据数量,threadsPerBlock表示每个CUDA块中的线程数量。//函数通过将dets_num除以threadsPerBlock并向上取整得到最大列块数col_blocks。这个函数常用于确定CUDA并行计算中需要启动多少个CUDA块来处理所有数据。const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);//获取block个数//该函数创建了一个名为mask的空Tensor,其大小为dets_num * col_blocks,数据类型为长整型(at::kLong)。该Tensor的属性与dets相同。at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));//创建一个空容器  Ddim3 blocks(col_blocks, col_blocks);//定义了一个二维网格,其维度为col_blocks行col_blocks列dim3 threads(threadsPerBlock);//定义了一个线程块,其维度为threadsPerBlock;cudaStream_t stream = at::cuda::getCurrentCUDAStream();//获取当前的CUDA流,用于异步计算。if (is_3d) {//std::cout << "performing NMS on 3D boxes in CUDA" << std::endl;AT_DISPATCH_FLOATING_TYPES_AND_HALF(dets_sorted.type(), "nms_kernel_cuda", [&] {nms_kernel_3d<scalar_t><<<blocks, threads, 0, stream>>>(dets_num,iou_threshold,dets_sorted.data_ptr<scalar_t>(),(unsigned long long*)mask.data_ptr<int64_t>());});}else {//该函数是PyTorch CUDA扩展中的一个宏定义,用于在CUDA代码中处理浮点类型数据(包括float、double、half)的泛型编程。//它会根据传入的输入数据类型,自动选择对应的CUDA内核函数进行计算。//这样可以避免为每种数据类型编写重复的代码,提高代码的可维护性和可扩展性。//相当与模板类AT_DISPATCH_FLOATING_TYPES_AND_HALF(dets_sorted.type(), "nms_kernel_cuda", [&] {nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(dets_num,iou_threshold,dets_sorted.data_ptr<scalar_t>(),(unsigned long long*)mask.data_ptr<int64_t>());});}//将mask_cpu的数据拷贝到主机内存中at::Tensor mask_cpu = mask.to(at::kCPU);unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();//创建一个空的remv Tensor,其大小为col_blocks,数据类型为unsigned long long。用于记录每个block中哪些线程被标记为无效。std::vector<unsigned long long> remv(col_blocks);memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);//创建一个空的keep Tensor,其大小为dets_num,数据类型为long。用于存放整个筛选后的索引。at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));int64_t* keep_out = keep.data_ptr<int64_t>();int num_to_keep = 0;for (int i = 0; i < dets_num; i++) {int nblock = i / threadsPerBlock;//第几个blockint inblock = i % threadsPerBlock;//block中第几个threads//判断当前线程是否被标记为无效if (!(remv[nblock] & (1ULL << inblock))) {keep_out[num_to_keep++] = i;unsigned long long* p = mask_host + i * col_blocks;//实际上将所有的形式设置成one hot形式。for (int j = nblock; j < col_blocks; j++) {remv[j] |= p[j];//按位或运算,并将结果保留在remv数组中}}}AT_CUDA_CHECK(cudaGetLastError());//检测CUDA 是否发生错误return order_t.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)//从张量keep中选择制定维度、起始位置和长度的子张量.to(order_t.device(), keep.scalar_type())});
}
template <typename T>
__global__ void nms_kernel(const int n_boxes, const float iou_threshold, const T* dev_boxes, unsigned long long* dev_mask) {const int row_start = blockIdx.y;const int col_start = blockIdx.x;// if (row_start > col_start) return;const int row_size = min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);//当前block行可以放多少个boxconst int col_size = min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);//获取block列可以放多少个box__shared__ T block_boxes[threadsPerBlock * 4];if (threadIdx.x < col_size) {//将当前block列的box拷贝到shared memoryblock_boxes[threadIdx.x * 4 + 0] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];block_boxes[threadIdx.x * 4 + 1] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];block_boxes[threadIdx.x * 4 + 2] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];block_boxes[threadIdx.x * 4 + 3] =  dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];}__syncthreads();//同步if (threadIdx.x < row_size) {const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;//当前block行box的索引const T* cur_box = dev_boxes + cur_box_idx * 4;//当前block行box的指针int i = 0;unsigned long long t = 0;int start = 0;if (row_start == col_start) {start = threadIdx.x + 1;//自己的IOU就不算了}for (i = start; i < col_size; i++) {if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {//计算各自的IOUt |= 1ULL << i;//以二进制的形式表示重叠关系成立}}const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);//列block数dev_mask[cur_box_idx * col_blocks + col_start] = t;//将重叠关系写入shared memory}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VBA技术资料MF168:移动工作表为单独工作簿

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套&#xff0c;分为初级、中级、高级三大部分&#xff0c;教程是对VBA的系统讲解&#…

eventbus和vuex

EventBus和Vuex EventBus 工作原理 创建一个vue实例&#xff0c;然后通过空的vue实例作为组件之间的桥梁&#xff0c;进行通信&#xff0c;利用到的设计模式有发布订阅模式 Vuex 工作原理 维护了一个state树&#xff0c;是独立的状态树&#xff0c;有明显的层级关系。不论…

云计算-期末复习题-框架设计/选择/填空/简答(2)

目录 框架设计 1.负载分布架构 2.动态可扩展架构 3.弹性资源容量架构 4.服务负载均衡架构 5.云爆发结构 6.弹性磁盘供给结构 7.负载均衡的虚拟服务器实例架构 填空题/简答题 单选题 多选题 云计算期末复习部分练习题&#xff0c;包括最后的部分框架设计大题(只是部分…

AI办公自动化:多音频轨电影视频抽取出英语音频

很多电影视频是有中、英、粤语等多个音频轨的&#xff0c;如果直接转换成音频&#xff0c;很有可能不是自己想要的那种语音。 可以先查看音频流信息&#xff0c;确定属于哪个音频轨&#xff1a; Reading video file: E:\1-7\比得兔1.mp4 输出音频流信息 Available audio str…

利用viztracer进行性能分析和优化

上一篇文章&#xff0c;我们详细讲解了scalene这个性能分析和优化工具的使用流程&#xff1b;今天&#xff0c;我们将深入探讨另一个性能分析和优化工具——viztracer。 什么是viztracer&#xff1f; viztracer是一个非常强大的分析器&#xff0c;可以生成详细的性能报告和可…

设计师进阶指南:掌握这6条版式设计要点

布局设计是设计师的必修课。优秀的排版不是强制性的“东拼西凑”&#xff0c;而是通过设计师独特的排版获得的。这不是简单的信息列表&#xff0c;而是认真思考如何分层、有节奏地组织和安排元素。今天我将给你带来它 6 文章还附带了布局设计模板资源&#xff0c;设计师朋友一定…

EthernetIP IO从站设备数据 转opc ua项目案例

1 案例说明 设置网关采集EthernetIP IO设备数据把采集的数据转成opc ua协议转发给其他系统。 2 VFBOX网关工作原理 VFBOX网关是协议转换网关&#xff0c;是把一种协议转换成另外一种协议。网关可以采集西门子&#xff0c;欧姆龙&#xff0c;三菱&#xff0c;AB PLC&#xff0…

Element 页面滚动表头置顶

在开发后台管理系统时&#xff0c;表格是最常用的一个组件&#xff0c;为了看数据方便&#xff0c;时常需要固定表头。 如果页面基本只有一个表格区域&#xff0c;我们可以根据屏幕的高度动态的计算出一个值&#xff0c;给表格设定一个固定高度&#xff0c;这样表头就可以固定…

红酒达人教你秘技:选酒、存酒,一招一式皆学问

在繁忙的都市生活中&#xff0c;红酒不仅仅是一种饮品&#xff0c;更是一种生活态度&#xff0c;一种品味的象征。然而&#xff0c;面对琳琅满目的红酒品牌与种类&#xff0c;如何选择一瓶心仪的红酒&#xff0c;又如何妥善保存&#xff0c;使其保持很好口感&#xff0c;成为了…

LabVIEW遇到无法控制国外设备时怎么办

当使用LabVIEW遇到无法控制国外产品的问题时&#xff0c;解决此类问题需要系统化的分析和处理方法。以下是详细的解决思路和具体办法&#xff0c;以及不同方法的分析和比较&#xff0c;包括寻求代理、国外技术支持、国内用过的人请教等内容。 1. 了解产品的通信接口和协议 思路…

Python 基础 (标准库):collections (集合类)

1. 官方文档 collections --- 容器数据类型 — Python 3.12.4 文档 Python 的 collections 模块提供了许多有用的数据类型&#xff08;包括 OrderedDict、Counter、defaultdict、deque 和 namedtuple&#xff09;用于扩展 Python 的标准数据类型。掌握 collections 中的数据类…

五子棋纯python手写,需要的拿去

import pygame,sys from pygame import * pygame.init()game pygame.display.set_mode((600,600)) gameover False circlebox [] # 棋盘坐标点存储 box [] def xy():for x in range(0,800//40): for y in range(0,800//40): box.append((x*40,y*40)) xy() defaultColor wh…

8.DELL R730服务器对RAID5进行扩容

如果服务器的空间不足了&#xff0c;如何进行扩容&#xff1f;我基本上按照如何重新配置虚拟磁盘或添加其他硬盘来进行操作。我的机器上已经有三块硬盘了&#xff0c;组了Raid5&#xff0c;现在再添加一块硬盘。 先把要添加的硬盘插入服务器&#xff0c;无论是在IDRAC还是管理…

物联网“此用户无权修改接入点名称设置”解决方案

根本原因apns-conf.xml里面没有 符合 物理网卡 的配置 可以先加一个APN试一下&#xff0c;看看默认的MCC和MNC是什么 然后在”命令行“查询一下 adb shell sqlite3 /data/user_de/0/com.android.providers.telephony/databases/telephony.db "select * from carriers wh…

乐鑫已支持Matter 1.2标准新增多种设备类型,启明云端乐鑫代理商

随着物联网技术的飞速发展&#xff0c;智能家居正逐渐成为现代生活的一部分。物联网和智能家居行业应用取得了巨大的增长&#xff0c;一系列无线连接的智能设备涌入家庭&#xff0c;为家庭生活带来自动化和便利。 像是可以连网的扬声器、灯泡和中控开关&#xff0c;它们都可以…

迁移学习——CycleGAN

CycleGAN 1.导入需要的包2.数据加载&#xff08;1&#xff09;to_img 函数&#xff08;2&#xff09;数据加载&#xff08;3&#xff09;图像转换 3.随机读取图像进行预处理&#xff08;1&#xff09;函数参数&#xff08;2&#xff09;数据路径&#xff08;3&#xff09;读取文…

MySQL笔记——索引

索引 SQL性能分析使用原则SQL提示覆盖索引前缀索引单列索引和联合索引索引设计原则 学习黑马MySQL课程&#xff0c;记录笔记&#xff0c;用于复习。 查询建表语句&#xff1a; show create table account;以下为建表语句&#xff1a; CREATE TABLE account (id int NOT NULL …

Redis-集群-环境搭建

文章目录 1、清空主从复制和哨兵模式留下的一些文件1.1、删除以rdb后缀名的文件1.2、删除主从复制的配置文件1.3、删除哨兵模式的配置文件 2、appendonly修改回no3、开启daemonize yes4、protect-mode no5、注释掉bind6、制作六个实例的配置文件6.1、制作配置文件redis6379.con…

使用 fvm 管理 Flutter 版本

文章目录 Github官网fvm 安装Mac/Linux 环境Windows 环境 fvm 环境变量fvm 基本命令 Github https://github.com/leoafarias/fvmhttps://github.com/flutter/flutter 官网 https://fvm.app/ fvm 安装 Mac/Linux 环境 Install.sh curl -fsSL https://fvm.app/install.sh …

20240627 每日AI必读资讯

&#x1f50d;挑战英伟达&#xff01;00 后哈佛辍学小哥研发史上最快 AI 芯片 - 3名大学辍学生创立、目前仅35 名员工、刚筹集1.2 亿美元的团队&#xff1a;Etched。 - 史上最快Transformer芯片诞生了&#xff01; - 用Sohu跑Llama 70B&#xff0c;推理性能已超B200十倍&…