Slicing Aided Hyper Inference: a brief walkthrough of get_sliced_prediction in the sahi library

Source code: https://github.com/obss/sahi

The key to understanding the get_sliced_prediction source is the NMS/NMM postprocessing step. NMS comes up everywhere, so it is not covered here.

NMM, i.e. the Non-Max Merging algorithm, is the most important part; it is actually quite similar to NMS. A blog post that explains the principle very well: https://blog.roboflow.com/non-max-merging/

The most relevant part is excerpted below; a small code sketch of the pairwise merge rule follows the list:

Here are the steps Non-Max Merge takes:

  1. First, it sorts all detections by their confidence score, from highest to lowest.
  2. It then takes all pairs of detections and computes their IOU, checking how much the pair overlaps.
  3. From most confident to least, it will build groups of overlapping detections.
    1. It starts by creating a new group with the most confident non-grouped detection D1.
    2. Then, each non-grouped detection that overlaps with D1 by at least iou_threshold (specified by the user) is placed in the same group.
    3. By repeating these two steps, we end up with mutually exclusive groups, such as [[D1, D2, D4], [D3], [D5, D6]].
  4. Then merging begins. This is done with detection pairs (D1, D2) and is implementation-specific. In supervision we:
    1. Make a new bounding box xyxy to fit both D1 and D2.
    2. Make a new mask containing pixels where the masks of D1 or D2 were.
    3. Create a new confidence value, adding together the confidence of D1 and D2, normalized by their xyxy areas.
      New Conf = (Conf 1 * Area 1 + Conf 2 * Area 2) / (Area 1 + Area 2)
    4. Copy class_id, tracker_id, and data from the Detection with the higher confidence.
  5. The prior step is done on detection pairs. How do we merge the whole group?
    1. Create an empty list for results.
    2. If there's only one detection in a group, add it to the results list.
    3. Otherwise, pick the first two detections, compute the IOU again, and if it's above the user-specified iou_threshold, pairwise merge it as outlined in the prior step.
      The resulting merged detection stays in the group as the new first element, and the group is shortened by 1. Continue pairwise merging while there are at least two elements in a group.
      Note that the IOU calculation makes the algorithm more costly but is required to prevent the merged detection from growing boundlessly.
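
To make step 4 concrete, here is a minimal sketch of the pairwise merge rule for boxes and confidences; merge_pair is a hypothetical helper written for illustration, not code taken from supervision:

# pairwise merge: union box plus area-weighted confidence, as in the formula above
def merge_pair(box1, conf1, box2, conf2):
    # new bounding box that fits both inputs (xyxy format)
    merged_box = (
        min(box1[0], box2[0]),
        min(box1[1], box2[1]),
        max(box1[2], box2[2]),
        max(box1[3], box2[3]),
    )
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    # New Conf = (Conf1 * Area1 + Conf2 * Area2) / (Area1 + Area2)
    merged_conf = (conf1 * area1 + conf2 * area2) / (area1 + area2)
    return merged_box, merged_conf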

The get_sliced_prediction function in the sahi library:

def get_sliced_prediction(
    image,
    detection_model=None,
    slice_height: int = None,
    slice_width: int = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: int = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: str = None,
    slice_dir: str = None,
) -> PredictionResult:
    """
    Function for slice image + get prediction for each slice + combine predictions in full image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        slice_height: int
            Height of each slice. Defaults to ``None``.
        slice_width: int
            Width of each slice. Defaults to ``None``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels). Defaults to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels). Defaults to ``0.2``.
        perform_standard_pred: bool
            Perform a standard prediction on top of sliced predictions to increase large object
            detection accuracy. Default: True.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating
            predictions. Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print number of slices (default)
            2: print number of slices and slice/prediction durations
        merge_buffer_length: int
            The length of buffer for slices to be used during sliced prediction, which is
            suitable for low memory. It may affect the AP if it is specified. The higher the
            amount, the closer the results to the non-buffered scenario.
            See [the discussion](https://github.com/obss/sahi/pull/445).
        auto_slice_resolution: bool
            If slice parameters (slice_height, slice_width) are not given, automatically
            calculate them from the image resolution and orientation.
        slice_export_prefix: str
            Prefix for the exported slices. Defaults to None.
        slice_dir: str
            Directory to save the slices. Defaults to None.

    Returns:
        A Dict with fields:
            object_prediction_list: a list of sahi.prediction.ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    # for profiling
    durations_in_seconds = dict()

    # currently only 1 batch supported
    num_batch = 1

    # create slices from full image
    time_start = time.time()
    # slice the input image
    slice_image_result = slice_image(
        image=image,
        output_file_name=slice_export_prefix,
        output_dir=slice_dir,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        auto_slice_resolution=auto_slice_resolution,
    )
    num_slices = len(slice_image_result)
    time_end = time.time() - time_start
    durations_in_seconds["slice"] = time_end

    # init match postprocess instance
    # supported postprocess types are "GREEDYNMM", "NMM", "NMS", "LSNMS"; GREEDYNMM is the default
    if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
        raise ValueError(
            f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
        )
    elif postprocess_type == "UNIONMERGE":
        # deprecated in v0.9.3
        raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
    # pick one of the postprocess classes
    postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
    postprocess = postprocess_constructor(
        match_threshold=postprocess_match_threshold,
        match_metric=postprocess_match_metric,
        class_agnostic=postprocess_class_agnostic,
    )

    # create prediction input
    num_group = int(num_slices / num_batch)
    if verbose == 1 or verbose == 2:
        tqdm.write(f"Performing prediction on {num_slices} slices.")
    object_prediction_list = []
    # perform sliced prediction
    # run standard inference on every slice, then shift the slice-local box
    # coordinates back onto the original full image
    for group_ind in range(num_group):
        # prepare batch (currently supports only 1 batch)
        image_list = []
        shift_amount_list = []
        for image_ind in range(num_batch):
            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
            shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
        # perform batch prediction
        # num_batch = 1, so image_list always holds a single slice
        prediction_result = get_prediction(
            image=image_list[0],
            detection_model=detection_model,
            shift_amount=shift_amount_list[0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
        )
        # convert sliced predictions to full predictions
        for object_prediction in prediction_result.object_prediction_list:
            if object_prediction:  # if not empty
                object_prediction_list.append(object_prediction.get_shifted_object_prediction())

        # merge matching predictions during sliced prediction
        if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
            object_prediction_list = postprocess(object_prediction_list)

    # perform standard prediction
    if num_slices > 1 and perform_standard_pred:
        prediction_result = get_prediction(
            image=image,
            detection_model=detection_model,
            shift_amount=[0, 0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
            postprocess=None,
        )
        object_prediction_list.extend(prediction_result.object_prediction_list)

    # merge matching predictions
    # postprocess the combined results: merge overlapping predictions
    if len(object_prediction_list) > 1:
        object_prediction_list = postprocess(object_prediction_list)

    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end

    if verbose == 2:
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
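
For context, a typical call looks roughly like the sketch below. The loader API and the model_type/model_path values are assumptions based on the sahi README and may differ across sahi versions:

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# hypothetical model config; any sahi-supported backend would work here
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path="yolov8n.pt",
    confidence_threshold=0.3,
    device="cpu",
)

result = get_sliced_prediction(
    "demo.jpg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
)
print(len(result.object_prediction_list))  # merged detections on the full image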

1. The postprocess object in get_sliced_prediction defaults to GreedyNMMPostprocess from combine.py, which is the core of the merging logic:

class GreedyNMMPostprocess(PostprocessPredictions):
    def __call__(
        self,
        object_predictions: List[ObjectPrediction],
    ):
        object_prediction_list = ObjectPredictionList(object_predictions)
        # convert to a pytorch tensor
        object_predictions_as_torch = object_prediction_list.totensor()
        if self.class_agnostic:  # usually not taken
            keep_to_merge_list = greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )
        else:  # usual path: compute, per category, which detections should be merged
            keep_to_merge_list = batched_greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )

        selected_object_predictions = []
        # a small excerpt of keep_to_merge_list from an actual run: {34: [45, 53], 5: [29], 6: []}
        # 34: [45, 53] means detection 34 is merged with detections 45 and 53;
        # 5: [29] means detection 5 is merged with detection 29;
        # 6: [] means detection 6 is not merged with any box.
        # merging is incremental: the result of each merge becomes the input of the next one
        for keep_ind, merge_ind_list in keep_to_merge_list.items():
            for merge_ind in merge_ind_list:
                # only merge when IoU or IoS exceeds the given threshold
                if has_match(
                    object_prediction_list[keep_ind].tolist(),
                    object_prediction_list[merge_ind].tolist(),
                    self.match_metric,
                    self.match_threshold,
                ):
                    # main merge function: merges box coordinates, score and category
                    object_prediction_list[keep_ind] = merge_object_prediction_pair(
                        object_prediction_list[keep_ind].tolist(), object_prediction_list[merge_ind].tolist()
                    )
            selected_object_predictions.append(object_prediction_list[keep_ind].tolist())

        return selected_object_predictions

Note that sahi's score merging simply keeps the larger of the two scores; it does not use the area-weighted formula from the blog post above.
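
A simplified sketch of what that amounts to for boxes and scores; merge_boxes_sahi_style is a hypothetical helper for illustration, not the actual merge_object_prediction_pair implementation:

# union box plus max score, mirroring the behavior described above
def merge_boxes_sahi_style(box1, score1, box2, score2):
    merged_box = [
        min(box1[0], box2[0]),
        min(box1[1], box2[1]),
        max(box1[2], box2[2]),
        max(box1[3], box2[3]),
    ]
    merged_score = max(score1, score2)  # sahi keeps the larger score, not an area-weighted average
    return merged_box, merged_score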

2. The batched_greedy_nmm function:

def batched_greedy_nmm(
    object_predictions_as_tensor: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging per category to avoid detecting
    too many overlapping bounding boxes for a given object.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores and category ids, Shape: [num_boxes, 6].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # each row: x1, y1, x2, y2, score, category_id
    category_ids = object_predictions_as_tensor[:, 5].squeeze()
    keep_to_merge_list = {}
    for category_id in torch.unique(category_ids):
        curr_indices = torch.where(category_ids == category_id)[0]  # global indices
        # merge candidates for the current category, expressed in local indices
        curr_keep_to_merge_list = greedy_nmm(object_predictions_as_tensor[curr_indices], match_metric, match_threshold)
        curr_indices_list = curr_indices.tolist()
        # map the local indices back to global indices
        for curr_keep, curr_merge_list in curr_keep_to_merge_list.items():
            keep = curr_indices_list[curr_keep]
            merge_list = [curr_indices_list[curr_merge_ind] for curr_merge_ind in curr_merge_list]
            keep_to_merge_list[keep] = merge_list
    return keep_to_merge_list
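
A toy example of the local-to-global index mapping; the expected output was worked out by hand against the code above:

import torch

# 4 boxes, two categories; columns are x1, y1, x2, y2, score, category_id
preds = torch.tensor([
    [0.0, 0.0, 10.0, 10.0, 0.9, 0.0],
    [1.0, 1.0, 11.0, 11.0, 0.8, 0.0],
    [50.0, 50.0, 60.0, 60.0, 0.7, 1.0],
    [0.0, 0.0, 10.0, 10.0, 0.6, 1.0],
])

# category 1 occupies global rows 2 and 3; inside greedy_nmm they become local
# indices 0 and 1, and curr_indices_list = [2, 3] maps them back to global indices
print(batched_greedy_nmm(preds, match_metric="IOU", match_threshold=0.5))
# expected: {0: [1], 2: [], 3: []}
# (boxes 0 and 1 overlap with IoU of about 0.68; boxes 2 and 3 do not overlap)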

3. The greedy_nmm function:

def greedy_nmm(
    object_predictions_as_tensor: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes, 5] (a trailing
            category id column is ignored here).
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # non-maximum merging, structured much like nms
    keep_to_merge_list = {}

    # we extract coordinates for every
    # prediction box present in P
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]

    # we extract the confidence scores as well
    scores = object_predictions_as_tensor[:, 4]

    # calculate area of every block in P
    areas = (x2 - x1) * (y2 - y1)

    # sort the prediction boxes in P
    # according to their confidence scores
    order = scores.argsort()

    while len(order) > 0:
        # extract the index of the
        # prediction with highest score
        # we call this prediction S
        idx = order[-1]

        # remove S from P
        order = order[:-1]

        # sanity check
        if len(order) == 0:
            keep_to_merge_list[idx.tolist()] = []
            break

        # select coordinates of BBoxes according to
        # the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # take max with 0.0 to avoid negative w and h
        # due to non-overlapping boxes
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # find the intersection area
        inter = w * h

        # find the areas of BBoxes according the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        if match_metric == "IOU":
            # find the union of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            union = (rem_areas - inter) + areas[idx]
            # find the IoU of every prediction in P with S
            match_metric_value = inter / union
        elif match_metric == "IOS":
            # find the smaller area of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            smaller = torch.min(rem_areas, areas[idx])
            # find the IoS of every prediction in P with S
            match_metric_value = inter / smaller
        else:
            raise ValueError()

        # keep the boxes with IoU/IoS less than thresh_iou
        mask = match_metric_value < match_threshold

        # matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,))
        ids = (mask == False).nonzero().flatten()
        matched_box_indices0 = order[ids]
        # reverse so that the matched indices are in descending score order
        matched_box_indices = matched_box_indices0.flip(dims=(0,))
        unmatched_indices = order[(mask == True).nonzero().flatten()]

        # update box pool
        order = unmatched_indices[scores[unmatched_indices].argsort()]

        # create keep_ind to merge_ind_list mapping
        keep_to_merge_list[idx.tolist()] = []
        for matched_box_ind in matched_box_indices.tolist():
            keep_to_merge_list[idx.tolist()].append(matched_box_ind)

    return keep_to_merge_list
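
The choice of match metric matters: a small box fully contained in a large one has a low IoU but an IoS of 1.0, which is presumably why sahi defaults to "IOS" (a partial detection from one slice should be absorbed into the full-object detection). A small check, assuming the greedy_nmm code above:

import torch

big = torch.tensor([[0.0, 0.0, 100.0, 100.0, 0.9, 0.0]])
small = torch.tensor([[10.0, 10.0, 30.0, 30.0, 0.8, 0.0]])
preds = torch.cat([big, small])

print(greedy_nmm(preds, match_metric="IOU", match_threshold=0.5))
# expected: {0: [], 1: []} since IoU = 400 / 10000 = 0.04, below the threshold
print(greedy_nmm(preds, match_metric="IOS", match_threshold=0.5))
# expected: {0: [1]} since IoS = 400 / 400 = 1.0, so the small box is merged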
