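# Given a stack of 2D images, assemble them into a 3D volume; then, along each of the x, y and z axes, predict every 2D slice patch-by-patch with a 2D model, apply flip test-time augmentation (TTA), and finally average all predictions per voxel.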
import torch

M, N, R = 40, 40, 4
CUT_SIZE = 10
OFFSET = 5

# Flip combinations over the two in-plane dims of a (b, c, h, w) patch used
# for TTA: identity, flip h, flip w, flip both.
_TTA_FLIPS = ((), (2,), (3,), (2, 3))


def get_data():
    """Return a random demo volume of shape (1, 1, M, N, R).

    R random (M, N) slices are stacked along the last axis to form the 3D
    volume, then a batch and a channel dimension (both of size 1) are
    prepended.
    """
    slices = [torch.rand(M, N) for _ in range(R)]
    volume = torch.stack(slices, dim=-1)  # (M, N, R)
    return volume[None, None]  # (1, 1, M, N, R)


def get_split(length, cut_size, offset):
    """Return sliding-window boundaries that cover ``range(length)``.

    Windows start every ``offset`` positions; one final window is anchored
    at ``length - cut_size`` so the end of the axis is always covered. When
    ``length <= cut_size`` a single window ``[0, length)`` is returned.

    Args:
        length: size of the axis to cover.
        cut_size: window size (must be positive).
        offset: stride between window starts (must be positive).

    Returns:
        (lefts, rights): parallel lists of start (inclusive) and end
        (exclusive) indices.

    Raises:
        ValueError: if ``cut_size`` or ``offset`` is not positive (either
            would leave parts of the axis uncovered).
    """
    if cut_size <= 0 or offset <= 0:
        raise ValueError(
            f"cut_size and offset must be positive, got {cut_size}, {offset}"
        )
    last_start = max(0, length - cut_size)
    # range() stops strictly before last_start, so appending it never
    # produces a duplicate start.
    lefts = list(range(0, last_start, offset)) + [last_start]
    rights = [min(start + cut_size, length) for start in lefts]
    return lefts, rights


def f(x):
    """Identity stand-in for the real model; lets the demo check round-trips."""
    return x


def get_pad(x, size):
    """Zero-pad a 4D tensor to ``size`` in its last two dims, centred.

    Args:
        x: tensor of shape (b, c, m, n).
        size: (target_x, target_y) with target_x >= m and target_y >= n.

    Returns:
        (s1, s2, padded): slices such that ``padded[..., s1, s2]`` equals
        ``x``; ``padded`` has ``x``'s dtype and device.

    Raises:
        ValueError: if ``x`` does not fit inside ``size``.
    """
    b, c, m, n = x.shape
    target_x, target_y = size
    if m > target_x or n > target_y:
        raise ValueError(f"patch of size {(m, n)} does not fit in {tuple(size)}")
    padded = x.new_zeros((b, c, target_x, target_y))
    x1 = (target_x - m) // 2
    y1 = (target_y - n) // 2
    s1 = slice(x1, x1 + m)
    s2 = slice(y1, y1 + n)
    padded[..., s1, s2] = x
    return s1, s2, padded


def predict_by_axis(model, data, res, idx, axis, rotate_tta=True, *,
                    cut_size=None, offset=None, device=None):
    """Run a 2D ``model`` slice-by-slice along one axis and accumulate.

    For every index ``i`` along ``axis`` (0 -> m, 1 -> n, any other value
    -> r), the 2D slice is tiled into overlapping windows of ``cut_size``
    (stride ``offset``) over the two remaining axes. Each window is
    centre-padded to (cut_size, cut_size) and passed through ``model``. The
    un-padded output is added into ``res``, and ``idx`` counts the
    contributions per voxel, so ``res / idx`` is the average once all
    desired axes have been processed.

    Args:
        model: callable applied to (b, c, cut_size, cut_size) tensors;
            assumed to return a tensor of the same shape (the demo uses
            the identity).
        data: input volume of shape (b, c, m, n, r).
        res: accumulator with ``data``'s shape; updated in place.
        idx: per-voxel contribution counter with ``data``'s shape; updated
            in place.
        axis: slicing axis (see above).
        rotate_tta: if True, also average over the h-flip, w-flip and
            hw-flip of each patch. Despite the name, this is flip TTA;
            no rotation is performed.
        cut_size: window size; defaults to the module-level ``CUT_SIZE``
            at call time.
        offset: window stride; defaults to the module-level ``OFFSET``
            at call time.
        device: device the model runs on; defaults to CUDA when available,
            otherwise CPU.
    """
    if cut_size is None:
        cut_size = CUT_SIZE
    if offset is None:
        offset = OFFSET
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if axis not in (0, 1):
        axis = 2  # matches the original: any other value slices along r
    flips = _TTA_FLIPS if rotate_tta else ((),)

    _, _, m, n, r = data.shape  # requires a 5D input
    sizes = (m, n, r)
    plane_axes = [a for a in range(3) if a != axis]
    # (start, stop) windows along each of the two in-plane axes.
    windows = [list(zip(*get_split(sizes[a], cut_size, offset)))
               for a in plane_axes]

    with torch.no_grad():
        for i in range(sizes[axis]):
            for lo1, hi1 in windows[0]:
                for lo2, hi2 in windows[1]:
                    # Index the last three dims of data: (m, n, r).
                    key = [Ellipsis, None, None, None]
                    key[1 + axis] = i
                    key[1 + plane_axes[0]] = slice(lo1, hi1)
                    key[1 + plane_axes[1]] = slice(lo2, hi2)
                    key = tuple(key)

                    patch = data[key]  # (b, c, hi1 - lo1, hi2 - lo2)
                    s1, s2, padded = get_pad(patch, (cut_size, cut_size))
                    padded = padded.to(device)

                    total = None
                    for dims in flips:
                        if dims:
                            # Flip input, predict, flip the prediction back.
                            pred = torch.flip(model(padded.flip(dims)), dims)
                        else:
                            pred = model(padded)
                        total = pred if total is None else total + pred

                    res[key] += total[..., s1, s2].detach().to(res.device)
                    idx[key] += len(flips)


if __name__ == '__main__':
    x = get_data()
    res = torch.zeros_like(x)
    idx = torch.zeros_like(x)
    for axis in (0, 1, 2):
        predict_by_axis(f, x, res, idx, axis, True)
    print(idx.unique())
    ans = res / idx
    print(torch.allclose(ans, x))