pytorch的数据增强功能并非是事先对整个数据集进行数据增强处理,而是在从dataloader中获取训练数据的时候(获取每个epoch的时候)才进行数据增强。
举个例子,如下面的数据增强代码:
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), # 对图像四周各填充4个0像素,然后随机裁剪成32*32transforms.RandomHorizontalFlip(), # 按0.5的概率水平翻转图片transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
假设数据集一共有100张图片,pytorch并非对数据集中的每张图片进行随机裁剪,再随机翻转,将数据集扩增到200张,然后用这固定的200张图来训练网络,这是错误的理解。
正确的理解应该是dataloader在每次生成epoch时才对数据集进行以上数据增强操作。由于数据增强有些操作是具有随机性的(例如上面的随机裁剪和随机翻转),导致每次epoch产生的数据都不相同,例如同一张图片在有的epoch翻转了,在有的epoch没有翻转,或者同一张图片在各个epoch裁剪的位置不一样,所以每次用来训练的数据不相同,到达了数据增强的目的。
当然,有些数据增强操作不具有随机性,如CenterCrop,每次都是对图片中间位置进行裁剪,不管在哪个epoch,裁剪出来的图片都一样。