以下代码来自一个半监督场景：weak_output_ul 是弱扰动输入经网络得到的 logits，strong_output_ul 是强扰动输入经网络得到的 logits，两者尺寸均可视为 [8, 2, 256, 256]（即 [B, C, H, W]：批大小、类别数、高、宽）。
CE:
# Unsupervised cross-entropy consistency loss (FixMatch-style): pseudo-labels come
# from the weakly perturbed view and supervise the strongly perturbed view.
# Output logits are assumed to be [B, C, H, W] (e.g. [8, 2, 256, 256]) — TODO confirm.

# Weak view: encoder + main decoder on the weakly perturbed unlabeled pair.
weak_x_ul = self.encoder(A_ul, B_ul)
weak_output_ul = self.main_decoder(weak_x_ul)
# Class probabilities over dim 1. The logits are detached so no gradient flows
# back through the pseudo-label source.
weak_targets = F.softmax(weak_output_ul.detach(), dim=1)

# Strong view: same network on the strongly perturbed pair.
strong_x_ul = self.encoder(s_A_ul, s_B_ul)
strong_output_ul = self.main_decoder(strong_x_ul)

# One reduction over the class dim gives both the max probability and its index.
# The index is the pseudo-label ([B, H, W], int64).
max_probs, pseudo_labels = torch.max(weak_targets, dim=1)

# Boolean [B, H, W] mask: only pixels whose weak prediction is confident enough
# contribute to the loss.
confidence_threshold = 0.95
confidence_mask = max_probs > confidence_threshold

# reduction='none' keeps the per-pixel loss map so it can be masked before averaging.
loss = F.cross_entropy(strong_output_ul, pseudo_labels, reduction='none')
loss = loss * confidence_mask.float()

# Mean over confident pixels. The clamp keeps the denominator >= 1: when no pixel
# passes the threshold the loss is 0 instead of 0/0 = NaN. Otherwise the result
# is unchanged.
loss_unsup = loss.sum() / confidence_mask.sum().clamp(min=1)
MSE:
import torch
import torch.nn.functional as F


def masked_mse_loss(strong_probs, weak_probs, threshold=0.95):
    """Return the confidence-masked MSE consistency loss between two views.

    Only pixels whose weak-view max class probability exceeds ``threshold``
    contribute. The result is the mean squared error over the confident
    elements: the squared-error sum divided by (number of confident pixels x
    number of classes). Each confident pixel is weighted equally, and the value
    does not depend on batch size.

    Args:
        strong_probs: Class probabilities of the strongly perturbed view. Assumed
            [B, C, H, W] and already softmaxed over dim 1 — TODO confirm at call site.
        weak_probs: Class probabilities of the weakly perturbed view, same shape.
            Used as the (detached) target.
        threshold: Confidence threshold on the weak view's max probability.

    Returns:
        Scalar tensor. It is 0 when no pixel passes the threshold, rather than
        NaN from 0/0.
    """
    # Targets act as fixed pseudo-targets, as in the cross-entropy branch.
    weak_probs = weak_probs.detach()
    max_probs, _ = torch.max(weak_probs, dim=1)
    confidence_mask = max_probs > threshold  # [B, H, W]

    per_elem = F.mse_loss(strong_probs, weak_probs, reduction='none')
    # Broadcast the pixel mask across the class dim.
    per_elem = per_elem * confidence_mask.unsqueeze(1).float()

    # Normalise once, by the global count of confident elements. The original
    # per-sample sum / global count followed by .mean() over the batch divided
    # by B twice.
    num_classes = weak_probs.size(1)
    denom = (confidence_mask.sum() * num_classes).clamp(min=1)
    return per_elem.sum() / denom


if __name__ == "__main__":
    # Demo with random data shaped like the real outputs [8, 2, 256, 256].
    # torch.rand logits lie in [0, 1), which caps a 2-class softmax at about 0.73,
    # so the 0.95 mask would always be empty. Wider-spread logits give some
    # confident pixels.
    weak_targets = torch.softmax(torch.randn(8, 2, 256, 256) * 5, dim=1)
    strong_output_ul = F.softmax(torch.randn(8, 2, 256, 256) * 5, dim=1)
    loss_unsup = masked_mse_loss(strong_output_ul, weak_targets)
    print(loss_unsup)