🚩🚩🚩Transformer实战-系列教程总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)
3、ConvertCocoPolysToMask类
位置:datasets/coco.py/ConvertCocoPolysToMask类
ConvertCocoPolysToMask类主要是进行数据预处理,主要在CocoDetection类中被调用
从的ConvertCocoPolysToMask
类的代码来看,主要涉及到以下几种计算机视觉任务的数据预处理步骤:
- 物体检测(Object Detection):
- 体现:通过处理
bbox
(边界框)信息。代码中提取和调整bbox
坐标来适应物体检测任务的需求。
- 体现:通过处理
- 实例分割(Instance Segmentation):
- 体现:如果
return_masks
为True,将COCO多边形标注(segmentation
)转换为掩码(mask
)。这对于实例分割任务来说是必要的,因为它需要精确地区分图像中各个对象的形状。
- 体现:如果
- 姿态估计(Pose Estimation):
- 体现:通过处理
keypoints
信息。当标注中包含关键点数据时,代码会提取这些数据,这些数据对于识别和估计图像中人物的姿态非常有用。
- 体现:通过处理
class ConvertCocoPolysToMask(object):def __init__(self, return_masks=False):self.return_masks = return_masksdef __call__(self, image, target):w, h = image.sizeimage_id = target["image_id"]image_id = torch.tensor([image_id])anno = target["annotations"]anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]boxes = [obj["bbox"] for obj in anno] # x y w h# guard against no boxes via resizingboxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)boxes[:, 2:] += boxes[:, :2]boxes[:, 0::2].clamp_(min=0, max=w)boxes[:, 1::2].clamp_(min=0, max=h)classes = [obj["category_id"] for obj in anno]classes = torch.tensor(classes, dtype=torch.int64)if self.return_masks:segmentations = [obj["segmentation"] for obj in anno]masks = convert_coco_poly_to_mask(segmentations, h, w)keypoints = Noneif anno and "keypoints" in anno[0]:keypoints = [obj["keypoints"] for obj in anno]keypoints = torch.as_tensor(keypoints, dtype=torch.float32)num_keypoints = keypoints.shape[0]if num_keypoints:keypoints = keypoints.view(num_keypoints, -1, 3)keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])boxes = boxes[keep]classes = classes[keep]if self.return_masks:masks = masks[keep]if keypoints is not None:keypoints = keypoints[keep]target = {}target["boxes"] = boxestarget["labels"] = classesif self.return_masks:target["masks"] = maskstarget["image_id"] = image_idif keypoints is not None:target["keypoints"] = keypointsarea = torch.tensor([obj["area"] for obj in anno])iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])target["area"] = area[keep]target["iscrowd"] = iscrowd[keep]target["orig_size"] = torch.as_tensor([int(h), int(w)])target["size"] = torch.as_tensor([int(h), int(w)])return image, target
- 定义ConvertCocoPolysToMask类,用于处理COCO数据集的转换
- 类的初始化方法,参数return_masks用于控制是否返回标注的掩码信息
- return_masks
- 可调用方法,接收两个参数:image(PIL图像对象)和target(包含图像标注信息的字典)
- 图像w、h, 427 ∗ 640 427*640 427∗640,每张图片读进来的长宽都可能不一样
- 获取图像id,image id: 538686]
- 将id转化为Tensor,image id: tensor([538686])
- 获取标签的标注信息,包含面积、bbox框的长宽xy四个值、类别id、图像id、分割的标注信息
- 过滤标注信息,过滤掉有重叠框的,只保留对单个物体的框,包含重叠物体的不要,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
- 获取所有框,[[62.37, 135.48, 184.94, 364.52],…, [107.99, 46.17, 101.51, 157.66]]
- 框的数据转化为Tensor
- 将x、y、w、h
- 转化为
- x1、y1、x2、y2,tensor([[ 62.3700, 135.4800, 247.3100, 500.0000],…, [107.9900, 46.1700, 209.5000, 203.8300]])
- 获取当前图像的所有类别标签(可能对应有多个类别),[19, 21, 21, 1]
- 转化为Tensor,tensor([19, 21, 21, 1])
- 是否进行掩码转换
- 提取分割信息
- 调用函数将分割信息转化为掩码
- 初始化 keypoints (姿态估计任务使用)变量
- 判断标注信息中是否包含 keypoints 信息
- 提取所有标注的 keypoints 信息
- 将 keypoints 列表转换为PyTorch张量
- 获取 keypoints 的数量
- 判断是否存在 keypoints
- 重塑 keypoints 张量
- 过滤掉不合逻辑的边界框(即右下角坐标不大于左上角坐标的边界框),因为在标注数据的时候,外包人员如果没有按照标注说明去标,拉框不是从上面往下框住物体,而是从下往上,这会影响两个点的顺序判断
- 使用keep数组过滤边界框,保留有效的边界框
- 同样使用keep数组过滤类别ID,保留与有效边界框对应的类别ID
- 判断是否有掩码
- 使用keep数组过滤掩码,保留与有效边界框对应的掩码
- 如果 keypoints 信息存在
- 使用keep数组过滤 keypoints 信息
- 初始化一个新的字典target,用于存储处理后的标注信息
- 将过滤后的边界框信息添加到target字典
- 将过滤后的类别ID添加到target字典
- 判断是否有掩码
- 将过滤后的掩码添加到target字典
- 将图像ID添加到target字典
- 如果存在关键点信息
- 将过滤后的关键点信息添加到target字典
- 提取所有标注的面积信息,并转换为PyTorch张量
- 将过滤后的iscrowd信息添加到target字典,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
- 分别将原始图像的高度和宽度作为orig_size和size添加到target字典。这两个字段通常用于后续的处理或数据恢复步骤
- 返回处理后的图像和更新后的target字典
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)