论文标题:SegFix: Model-Agnostic Boundary Refinement for Segmentation
论文地址:https://arxiv.org/pdf/2007.04269.pdf
代码地址:https://github.com/openseg-group/openseg.pytorch
训练时使用两种 loss 进行监督:边界图(boundary map)的二分类 loss 和方向图(direction map)的多分类 loss
将方向离散化为八个类别,把方向的回归问题转化为分类问题
代码解读:
1、使用 Sobel 算子对距离变换图求梯度,得到每个像素的方向,从而把边界点的方向预测转换成分类问题
# For every class region, add its distance-transform map to ``depth_map`` and
# the Sobel gradient of that distance map to ``dir_map``.
for class_idx in range(1, len(label_list) + 1):
    # Binary mask of the current class (1 inside, 0 elsewhere), kept in the
    # same dtype as ``labelmap``.
    labelmap_i = (labelmap == class_idx).astype(labelmap.dtype)

    # Regions covering fewer than 100 pixels are skipped.
    if labelmap_i.sum() < 100:
        continue

    # Distance of every in-region pixel to the nearest pixel outside it.
    if args.metric == 'euc':
        depth_i = distance_transform_edt(labelmap_i)
    elif args.metric == 'taxicab':
        depth_i = distance_transform_cdt(labelmap_i, metric='taxicab')
    else:
        raise RuntimeError(
            'unsupported distance metric: {!r}'.format(args.metric))

    depth_map += depth_i

    # Convolve the distance map with the Sobel kernel(s) to get its gradient.
    # NOTE(review): presumably ``sobel_ker`` has two output channels (one per
    # axis) so the result is (H, W, 2) after permute -- confirm against its
    # definition.
    depth_tensor = torch.from_numpy(depth_i).float().view(
        1, 1, *depth_i.shape)
    dir_i = torch.nn.functional.conv2d(
        depth_tensor, sobel_ker, padding=ksize // 2,
    ).squeeze().permute(1, 2, 0).numpy()

    # The following line is necessary: the convolution window also produces
    # non-zero responses at pixels just outside the region, which must not
    # leak into the accumulated direction map.
    dir_i[(labelmap_i == 0), :] = 0
    dir_map += dir_i
2、根据偏移量对标签图进行平移采样
def shift(x, offset):
    """Resample ``x`` at positions displaced by a per-pixel ``offset``.

    Args:
        x: 2-D numpy array of shape (h, w).
        offset: numpy array of shape (2, h, w). Channel 0 is added to the
            first coordinate map returned by ``gen_coord_map`` and channel 1
            to the second one.

    Returns:
        A ``uint8`` numpy array of shape (h, w) holding the bilinearly
        sampled values of ``x``, rounded to the nearest integer. Sampling
        positions outside the image are clamped to the border.
    """
    h, w = x.shape
    x = torch.from_numpy(x).unsqueeze(0)
    offset = torch.from_numpy(offset).unsqueeze(0)

    # NOTE(review): assumes gen_coord_map(h, w) returns (row, col) pixel-index
    # maps in [0, h-1] / [0, w-1] -- confirm against its definition.
    coord_map = gen_coord_map(h, w)

    # Half-extent used to map pixel indices to [-1, 1]; max(..., 1) avoids a
    # division by zero when an axis has a single pixel.
    norm_factor = torch.FloatTensor([max(w - 1, 1) / 2, max(h - 1, 1) / 2])

    grid_h = offset[:, 0] + coord_map[0]
    grid_w = offset[:, 1] + coord_map[1]

    # grid_sample expects (x, y) order in the last dimension and a grid of the
    # same dtype as the input, hence the stack order and the float() cast.
    grid = (torch.stack([grid_w, grid_h], dim=-1) / norm_factor - 1).float()

    # The (size - 1) / 2 normalisation places -1 / +1 on the centres of the
    # corner pixels, which is the align_corners=True convention; the default
    # (False since PyTorch 1.3) would sample at slightly wrong positions.
    sampled = F.grid_sample(
        x.unsqueeze(1).float(),
        grid,
        padding_mode='border',
        mode='bilinear',
        align_corners=True,
    )

    # Index the batch/channel axes explicitly instead of squeeze() so that
    # inputs with h == 1 or w == 1 keep their 2-D shape.
    x = np.round(sampled[0, 0].numpy())
    return x.astype(np.uint8)
3、标签的编码与解码(原始类别 id 与连续的训练 id 互相转换)
class LabelTransformer:
    """Convert between raw label ids and contiguous class indices.

    ``encode`` maps each id in ``label_list`` to its position in the list;
    ``decode`` performs the inverse mapping. In both directions every value
    that has no counterpart becomes 255.
    """

    # Raw label ids; the position of an id in this list is its class index.
    label_list = [7, 8, 11, 12, 13, 17, 19, 20,
                  21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

    @staticmethod
    def encode(labelmap):
        """Map raw label ids to class indices.

        Args:
            labelmap: 2-D array-like of raw label ids.

        Returns:
            Integer numpy array of the same height and width where each
            pixel holds the index of its id in ``label_list``, or 255 when
            the id is not listed.
        """
        labelmap = np.array(labelmap)
        shape = labelmap.shape
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # exact type it used to alias.
        encoded_labelmap = np.full((shape[0], shape[1]), 255, dtype=int)
        for class_index, class_id in enumerate(LabelTransformer.label_list):
            encoded_labelmap[labelmap == class_id] = class_index
        return encoded_labelmap

    @staticmethod
    def decode(labelmap):
        """Map class indices back to raw label ids.

        Args:
            labelmap: 2-D array-like of class indices.

        Returns:
            ``uint8`` numpy array of the same height and width where each
            pixel holds ``label_list[index]``, or 255 when the index is not
            a valid position in ``label_list``.
        """
        labelmap = np.array(labelmap)
        shape = labelmap.shape
        decoded_labelmap = np.full((shape[0], shape[1]), 255, dtype=np.uint8)
        for class_index, class_id in enumerate(LabelTransformer.label_list):
            decoded_labelmap[labelmap == class_index] = class_id
        return decoded_labelmap