class Grid(object):
    """GridMask augmentation applied to a single image tensor.

    Builds a binary mask made of horizontal and vertical stripes with a random
    period and offset, optionally rotates it, and multiplies it into the image.

    Args:
        d1, d2: the stripe period ``d`` is drawn with ``np.random.randint(d1, d2)``,
            i.e. ``d1 <= d < d2``.
        rotate: the rotation angle (degrees) is drawn with
            ``np.random.randint(rotate)``; ``rotate=1`` always gives 0 (no rotation).
        ratio: stripe width as a fraction of ``d`` (``ceil(d * ratio)`` pixels).
        mode: 0 zeroes the stripes and keeps the cells between them;
            1 inverts the mask, keeping the stripes and zeroing the cells.
        prob: probability of applying the mask; ``set_prob`` scales it over
            training from the initial value stored in ``st_prob``.
    """

    def __init__(self, d1, d2, rotate=1, ratio=0.5, mode=0, prob=0.8):
        self.d1 = d1
        self.d2 = d2
        self.rotate = rotate
        self.ratio = ratio
        self.mode = mode
        self.st_prob = self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Ramp ``prob`` linearly from 0 to its initial value over ``max_epoch`` epochs."""
        self.prob = self.st_prob * min(1, epoch / max_epoch)

    def __call__(self, img):
        """Apply a random grid mask to ``img`` with probability ``self.prob``.

        Args:
            img: tensor whose last two dims are (H, W), e.g. (C, H, W) or (H, W).
                The same mask is broadcast over any leading dims.

        Returns:
            ``img`` unchanged when the mask is not applied, else ``img * mask``.
        """
        if np.random.rand() > self.prob:
            return img
        h = img.size(-2)
        w = img.size(-1)
        # A square canvas whose side equals the image diagonal still covers
        # the whole (h, w) centre crop after rotation; 1.5*h only works for
        # square inputs.
        hh = math.ceil(math.sqrt(h * h + w * w))
        d = np.random.randint(self.d1, self.d2)
        # Stripe width for this draw (kept on the instance as before).
        self.l = math.ceil(d * self.ratio)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)

        # Index k falls inside a stripe iff (k - start) mod d < l; this matches
        # stripes starting at d*i + start for every i that touches the canvas.
        mask = np.ones((hh, hh), np.uint8)
        idx = np.arange(hh)
        mask[(idx - st_h) % d < self.l, :] = 0
        mask[:, (idx - st_w) % d < self.l] = 0

        r = np.random.randint(self.rotate)
        if r != 0:
            # Rotating by 0 degrees is an identity, so PIL is only needed here;
            # np.array makes a writable copy for torch.from_numpy.
            mask = np.array(Image.fromarray(mask).rotate(r))

        top = (hh - h) // 2
        left = (hh - w) // 2
        mask = torch.from_numpy(mask[top:top + h, left:left + w])
        # Follow the image's device so CUDA inputs work; dtype stays uint8 so
        # type promotion in ``img * mask`` is unchanged.
        mask = mask.to(img.device)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(img)
        return img * mask


class GridMask(nn.Module):
    """``nn.Module`` wrapper around :class:`Grid` that is a no-op in eval mode.

    Accepts a single image (e.g. (C, H, W)) or a batch (N, C, H, W); for a
    batch, each sample independently draws whether and how it is masked.
    """

    def __init__(self, d1=56, d2=128, rotate=360, ratio=0.4, mode=1, prob=0.8):
        super(GridMask, self).__init__()
        self.rotate = rotate
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.grid = Grid(d1, d2, rotate, ratio, mode, prob)

    def set_prob(self, epoch, max_epoch):
        """Forward the probability schedule to the wrapped :class:`Grid`."""
        self.grid.set_prob(epoch, max_epoch)

    def forward(self, x):
        if not self.training:
            return x
        if x.dim() == 4:
            # Grid reads H/W from the last two dims and builds one 2-D mask,
            # so masking a batch needs one call per sample.
            return torch.stack([self.grid(sample) for sample in x.unbind(0)])
        return self.grid(x)