torch.where()中一般有三个参数。
第一个参数是一个判断条件。
第二个参数是条件成立时的值。
第三个参数是条件不成立时的值。
for batch in range(2):for i in range(256):for j in range(256):output[batch][i][j] = 0 if tensor_count_0[A_arg[batch,i,j]][B_arg[batch,i,j]].item() >= tensor_count_1[A_arg[batch,i,j]][B_arg[batch,i,j]].item() else 1
output,A_arg,B_arg尺寸为[2,256,256] tensor_count_0和tensor_count_1的尺寸为[15,15],它们都是tensor数据,且都在GPU上。所以可以改为并行方式:
output = torch.where(tensor_count_0[A_arg, B_arg] >= tensor_count_1[A_arg, B_arg], torch.zeros_like(output),torch.ones_like(output))