Code: https://github.com/obss/sahi
The key to the get_sliced_prediction source code is understanding NMS and NMM. NMS comes up all the time, so I won't cover it here.
NMM, the Non-Max Merging algorithm, is the most important part, and it is quite similar to NMS. I found a blog post that explains the principle very well: https://blog.roboflow.com/non-max-merging/
The most important part is excerpted below:
Here are the steps Non-Max Merge takes:
- First, it sorts all detections by their confidence score, from highest to lowest.
- It then takes all pairs of detections and computes their IOU, checking how much the pair overlaps.
- From most confident to least, it will build groups of overlapping detections:
  - It starts by creating a new group with the most confident non-grouped detection D1.
  - Then, each non-grouped detection that overlaps with D1 by at least `iou_threshold` (specified by the user) is placed in the same group.
  - By repeating these two steps, we end up with mutually exclusive groups, such as [[D1, D2, D4], [D3], [D5, D6]].
- Then merging begins. This is done with detection pairs (D1, D2) and is implementation-specific. In supervision we:
  - Make a new bounding box `xyxy` to fit both D1 and D2.
  - Make a new `mask` containing pixels where the masks of D1 or D2 were.
  - Create a new `confidence` value, adding together the `confidence` of D1 and D2, normalized by their `xyxy` areas: `New Conf = (Conf1 * Area1 + Conf2 * Area2) / (Area1 + Area2)`.
  - Copy `class_id`, `tracker_id`, and `data` from the detection with the higher confidence.
- The prior step is done on detection pairs. How do we merge the whole group? (A sketch of the pairwise merge follows this list.)
  - Create an empty list for results.
  - If there's only one detection in a group, add it to the results list.
  - Otherwise, pick the first two detections, compute the IOU again, and if it's above the user-specified `iou_threshold`, pairwise-merge them as outlined in the prior step. The resulting merged detection stays in the group as the new first element, and the group is shortened by 1. Continue pairwise merging while there are at least two elements in a group.

Note that the IOU calculation makes the algorithm more costly but is required to prevent the merged detection from growing boundlessly.
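To make the pairwise merge concrete, here is a minimal Python sketch of the steps above (masks omitted, boxes as plain `(x1, y1, x2, y2)` tuples). It illustrates the description only; it is not supervision's actual implementation, and `merge_pair` is a hypothetical helper:

```python
def merge_pair(xyxy1, conf1, class_id1, xyxy2, conf2, class_id2):
    # new box tightly fits both inputs
    merged_xyxy = (
        min(xyxy1[0], xyxy2[0]),
        min(xyxy1[1], xyxy2[1]),
        max(xyxy1[2], xyxy2[2]),
        max(xyxy1[3], xyxy2[3]),
    )
    area1 = (xyxy1[2] - xyxy1[0]) * (xyxy1[3] - xyxy1[1])
    area2 = (xyxy2[2] - xyxy2[0]) * (xyxy2[3] - xyxy2[1])
    # area-weighted confidence, following the formula above
    merged_conf = (conf1 * area1 + conf2 * area2) / (area1 + area2)
    # class_id comes from the more confident detection
    merged_class_id = class_id1 if conf1 >= conf2 else class_id2
    return merged_xyxy, merged_conf, merged_class_id

print(merge_pair((10, 10, 50, 50), 0.9, 3, (12, 12, 48, 52), 0.6, 3))
# -> ((10, 10, 50, 52), ~0.758, 3)
```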
The get_sliced_prediction function in the sahi library is as follows.
```python
def get_sliced_prediction(
    image,
    detection_model=None,
    slice_height: int = None,
    slice_width: int = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: int = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: str = None,
    slice_dir: str = None,
) -> PredictionResult:
    """
    Function for slice image + get prediction for each slice + combine predictions in full image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        slice_height: int
            Height of each slice. Defaults to ``None``.
        slice_width: int
            Width of each slice. Defaults to ``None``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels). Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels). Default to ``0.2``.
        perform_standard_pred: bool
            Perform a standard prediction on top of sliced predictions to increase large object
            detection accuracy. Default: True.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating
            predictions. Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print number of slices (default)
            2: print number of slices and slice/prediction durations
        merge_buffer_length: int
            The length of buffer for slices to be used during sliced prediction, which is
            suitable for low memory. It may affect the AP if it is specified. The higher the
            amount, the closer the results to the non-buffered scenario.
            See [the discussion](https://github.com/obss/sahi/pull/445).
        auto_slice_resolution: bool
            If slice parameters (slice_height, slice_width) are not given, it enables
            automatically calculating these params from image resolution and orientation.
        slice_export_prefix: str
            Prefix for the exported slices. Defaults to None.
        slice_dir: str
            Directory to save the slices. Defaults to None.

    Returns:
        A Dict with fields:
            object_prediction_list: a list of sahi.prediction.ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    """
    # for profiling
    durations_in_seconds = dict()

    # currently only 1 batch supported
    num_batch = 1

    # create slices from full image
    time_start = time.time()
    # slice the image
    slice_image_result = slice_image(
        image=image,
        output_file_name=slice_export_prefix,
        output_dir=slice_dir,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        auto_slice_resolution=auto_slice_resolution,
    )
    num_slices = len(slice_image_result)
    time_end = time.time() - time_start
    durations_in_seconds["slice"] = time_end

    # init match postprocess instance
    # supported postprocess types: "GREEDYNMM", "NMM", "NMS", "LSNMS"; "GREEDYNMM" is the default
    if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
        raise ValueError(
            f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
        )
    elif postprocess_type == "UNIONMERGE":
        # deprecated in v0.9.3
        raise ValueError("'UNIONMERGE' postprocess_type is deprecated, use 'GREEDYNMM' instead.")
    # pick a postprocess class
    postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
    postprocess = postprocess_constructor(
        match_threshold=postprocess_match_threshold,
        match_metric=postprocess_match_metric,
        class_agnostic=postprocess_class_agnostic,
    )

    # create prediction input
    num_group = int(num_slices / num_batch)
    if verbose == 1 or verbose == 2:
        tqdm.write(f"Performing prediction on {num_slices} slices.")
    object_prediction_list = []
    # perform sliced prediction
    # run standard inference on every slice; the only extra work is shifting the
    # slice-local coordinates back onto the original full image
    for group_ind in range(num_group):
        # prepare batch (currently supports only 1 batch)
        image_list = []
        shift_amount_list = []
        for image_ind in range(num_batch):
            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
            shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
        # perform batch prediction
        # num_batch=1, so image_list always holds a single slice
        prediction_result = get_prediction(
            image=image_list[0],
            detection_model=detection_model,
            shift_amount=shift_amount_list[0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
        )
        # convert sliced predictions to full predictions
        for object_prediction in prediction_result.object_prediction_list:
            if object_prediction:  # if not empty
                object_prediction_list.append(object_prediction.get_shifted_object_prediction())

        # merge matching predictions during sliced prediction
        if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
            object_prediction_list = postprocess(object_prediction_list)

    # perform standard prediction
    if num_slices > 1 and perform_standard_pred:
        prediction_result = get_prediction(
            image=image,
            detection_model=detection_model,
            shift_amount=[0, 0],
            full_shape=[
                slice_image_result.original_image_height,
                slice_image_result.original_image_width,
            ],
            postprocess=None,
        )
        object_prediction_list.extend(prediction_result.object_prediction_list)

    # merge matching predictions
    # postprocess the results: merge overlapping predictions
    if len(object_prediction_list) > 1:
        object_prediction_list = postprocess(object_prediction_list)

    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end

    if verbose == 2:
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
```
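For context, here is a minimal usage sketch. The model type, weight path, and image path are placeholders, and it assumes sahi's AutoDetectionModel API (available in recent sahi versions):

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# hypothetical model/weights/image; adjust to your setup
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path="yolov8n.pt",
    confidence_threshold=0.3,
    device="cpu",
)
result = get_sliced_prediction(
    "demo.jpg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="GREEDYNMM",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
)
print(len(result.object_prediction_list), result.durations_in_seconds)
```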
1. The postprocess in get_sliced_prediction defaults to GreedyNMMPostprocess in combine.py, which is the core class:
```python
class GreedyNMMPostprocess(PostprocessPredictions):
    def __call__(
        self,
        object_predictions: List[ObjectPrediction],
    ):
        object_prediction_list = ObjectPredictionList(object_predictions)
        # convert to a pytorch tensor
        object_predictions_as_torch = object_prediction_list.totensor()
        if self.class_agnostic:  # usually not taken
            keep_to_merge_list = greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )
        else:  # usual path: compute which predictions need merging
            keep_to_merge_list = batched_greedy_nmm(
                object_predictions_as_torch,
                match_threshold=self.match_threshold,
                match_metric=self.match_metric,
            )

        selected_object_predictions = []
        # a small excerpt of keep_to_merge_list from one run: {34: [45, 53], 5: [29], 6: []}
        # 34: [45, 53] means prediction 34 is merged with 45 and 53; 5: [29] means prediction 5
        # is merged with 29; 6: [] means prediction 6 is not merged with any box.
        # merging is incremental: the result of one merge is the input to the next
        for keep_ind, merge_ind_list in keep_to_merge_list.items():
            for merge_ind in merge_ind_list:
                # iou or ios above the given threshold
                if has_match(
                    object_prediction_list[keep_ind].tolist(),
                    object_prediction_list[merge_ind].tolist(),
                    self.match_metric,
                    self.match_threshold,
                ):
                    # main merge function: merges box coordinates, score, and category
                    object_prediction_list[keep_ind] = merge_object_prediction_pair(
                        object_prediction_list[keep_ind].tolist(), object_prediction_list[merge_ind].tolist()
                    )
            selected_object_predictions.append(object_prediction_list[keep_ind].tolist())

        return selected_object_predictions
```
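To see how the incremental merging behaves, here is a self-contained toy version of the loop above. `simple_match` and `simple_merge` are simplified stand-ins for sahi's `has_match` and `merge_object_prediction_pair`, not the library functions:

```python
from typing import Dict, List, Tuple

Box = Tuple[float, float, float, float, float]  # x1, y1, x2, y2, score

def simple_match(a: Box, b: Box, threshold: float = 0.5) -> bool:
    # IOU between two boxes
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) >= threshold

def simple_merge(a: Box, b: Box) -> Box:
    # union box; the score is the max of the two (matching sahi, see the note below)
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]), max(a[4], b[4]))

def merge_groups(boxes: List[Box], keep_to_merge: Dict[int, List[int]]) -> List[Box]:
    results = []
    for keep_ind, merge_inds in keep_to_merge.items():
        merged = boxes[keep_ind]
        for merge_ind in merge_inds:
            # re-check overlap against the *current* merged box: earlier merges may
            # have grown it, which is why sahi re-runs has_match inside the loop
            if simple_match(merged, boxes[merge_ind]):
                merged = simple_merge(merged, boxes[merge_ind])
        results.append(merged)
    return results

boxes = [(10, 10, 50, 50, 0.9), (12, 12, 48, 52, 0.8), (100, 100, 140, 140, 0.7)]
print(merge_groups(boxes, {0: [1], 2: []}))
# -> [(10, 10, 50, 52, 0.9), (100, 100, 140, 140, 0.7)]
```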
Note that sahi merges scores simply by taking the maximum of the two; it does not use the area-weighted formula from the blog post.
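A toy comparison of the two strategies (values are made up):

```python
conf1, area1 = 0.9, 1600.0
conf2, area2 = 0.6, 1400.0

sahi_style = max(conf1, conf2)                                # 0.9
weighted = (conf1 * area1 + conf2 * area2) / (area1 + area2)  # 0.76
print(sahi_style, weighted)
```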
2. The batched_greedy_nmm function
```python
def batched_greedy_nmm(
    object_predictions_as_tensor: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging per category to avoid detecting
    too many overlapping bounding boxes for a given object.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes, 5].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # columns: x1, y1, x2, y2, score, category_id
    category_ids = object_predictions_as_tensor[:, 5].squeeze()
    keep_to_merge_list = {}
    for category_id in torch.unique(category_ids):
        curr_indices = torch.where(category_ids == category_id)[0]  # global indices
        # per-category merge candidates, expressed in local indices
        curr_keep_to_merge_list = greedy_nmm(object_predictions_as_tensor[curr_indices], match_metric, match_threshold)
        curr_indices_list = curr_indices.tolist()
        # map local indices back to global indices
        for curr_keep, curr_merge_list in curr_keep_to_merge_list.items():
            keep = curr_indices_list[curr_keep]
            merge_list = [curr_indices_list[curr_merge_ind] for curr_merge_ind in curr_merge_list]
            keep_to_merge_list[keep] = merge_list
    return keep_to_merge_list
```
3. The greedy_nmm function
```python
def greedy_nmm(
    object_predictions_as_tensor: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply greedy version of non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        object_predictions_as_tensor: (tensor) The location preds for the image
            along with the class pred scores, Shape: [num_boxes, 5].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for match metric.

    Returns:
        keep_to_merge_list: (Dict[int:List[int]]) mapping from prediction indices
        to keep to a list of prediction indices to be merged.
    """
    # non-maximum merging, similar to nms
    keep_to_merge_list = {}

    # we extract coordinates for every
    # prediction box present in P
    x1 = object_predictions_as_tensor[:, 0]
    y1 = object_predictions_as_tensor[:, 1]
    x2 = object_predictions_as_tensor[:, 2]
    y2 = object_predictions_as_tensor[:, 3]

    # we extract the confidence scores as well
    scores = object_predictions_as_tensor[:, 4]

    # calculate area of every block in P
    areas = (x2 - x1) * (y2 - y1)

    # sort the prediction boxes in P
    # according to their confidence scores
    order = scores.argsort()

    while len(order) > 0:
        # extract the index of the
        # prediction with highest score
        # we call this prediction S
        idx = order[-1]

        # remove S from P
        order = order[:-1]

        # sanity check
        if len(order) == 0:
            keep_to_merge_list[idx.tolist()] = []
            break

        # select coordinates of BBoxes according to
        # the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # take max with 0.0 to avoid negative w and h
        # due to non-overlapping boxes
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # find the intersection area
        inter = w * h

        # find the areas of BBoxes according the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        if match_metric == "IOU":
            # find the union of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            union = (rem_areas - inter) + areas[idx]
            # find the IoU of every prediction in P with S
            match_metric_value = inter / union
        elif match_metric == "IOS":
            # find the smaller area of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            smaller = torch.min(rem_areas, areas[idx])
            # find the IoS of every prediction in P with S
            match_metric_value = inter / smaller
        else:
            raise ValueError()

        # keep the boxes with IoU/IoS less than thresh_iou
        mask = match_metric_value < match_threshold
        # matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,))
        ids = (mask == False).nonzero().flatten()
        matched_box_indices0 = order[ids]
        matched_box_indices = matched_box_indices0.flip(dims=(0,))  # flip so matched indices are in descending score order
        unmatched_indices = order[(mask == True).nonzero().flatten()]

        # update box pool
        order = unmatched_indices[scores[unmatched_indices].argsort()]

        # create keep_ind to merge_ind_list mapping
        keep_to_merge_list[idx.tolist()] = []
        for matched_box_ind in matched_box_indices.tolist():
            keep_to_merge_list[idx.tolist()].append(matched_box_ind)

    return keep_to_merge_list
```