前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose
类和 to_multi
包装函数。不过 [1] 没考虑各图用不同 augmentation 的情况,如:
- ColorJitter 只对 image 做,而不对 label 做;
- image 的 resize interpolation 可任选,但 label 只能用
nearest
。
本篇更新写法,支持各图同用、独用 augmentation。
Code
- 对比 [1],主要改变是改写
MultiCompose
类,并将to_multi
吸收入内。 MultiCompose
的用法还是和torchvision.transforms.Compose
几乎一致,不过支持独用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():"""不用单独的 to_multi 打包了,已并入 MultiCompose"""raise NotImplementedErrorclass MultiCompose:"""扩展 torchvision.transforms.Compose:支持输入多图,且保证各 augmentation 中所有输入都用同一随机状态(如旋转同一随机角度),分割任务有用。"""# numpy.random.seed range error:# ValueError: Seed must be between 0 and 2**32 - 1MIN_SEED = 0 # - 0x8000_0000_0000_0000MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)def __init__(self, transforms):"""输入:一个 list/tuple,其中每个元素可以是一个 augmentation 对象(transform)/函数,各输入同用;或一个嵌套的 list/tuple,为每个输入指定独用的 augmentation。"""# self.transforms = [to_multi(t) for t in transforms]no_op = lambda x: x # i.e. identity functionself.transforms = []for t in transforms:if isinstance(t, (tuple, list)):# convert `None` to `no_op` for convenienceself.transforms.append([no_op if _t is None else _t for _t in t])else:self.transforms.append(t)def __call__(self, *images):for t in self.transforms:if isinstance(t, (tuple, list)): # 独用assert len(images) <= len(t) # allow redundant transformelse: # 同用t = [t] * len(images)_aug_images = []_seed = random.randint(self.MIN_SEED, self.MAX_SEED)for _im, _t in zip(images, t):seed_everything(_seed)_aug_images.append(_t(_im))images = _aug_imagesif len(images) == 1:images = images[0]return images
Usage & Test
例程沿用 [1],但改一下 augmentation:
train_trans = MultiCompose([# image 用 bilinear,label 用 nearest(ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 独用transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行transforms.RandomHorizontalFlip(), # 同用# ColorJitter 只对 image 做,label 不做(None)[transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 独用
])
- 效果:
References
- pytorch一致数据增强