多数度量学习的代码都需要进行挖掘,样本挖掘过程就是把一个Batch中的所有样本,根据标签来划分成正样本和负样本
这里我们只讨论多标签分类问题,标签是onehot编码,如果是单标签分类任务可以去看pytorch_metric_learning这个库有实现好的挖掘方法
比如输入样本为[Batch,Embedding],对应的标签是[Batch,Class]
对这些样本进行挖掘后得到以下三部分:
- Anchor :锚点样本,其实就是和输入的Batch一模一样
- Positive Sample : 挖掘的正样本
- Negative Sample : 挖掘的负样本
import torch
import torch.nn as nn
import torchvision# 损失函数
class HibCriterion(nn.Module):def __init__(self):super().__init__()def forward(self, z_samples, alpha, beta, indices_tuple):n_samples = z_samples.shape[1]if len(indices_tuple) == 3:a, p, n = indices_tupleap = an = aelif len(indices_tuple) == 4:ap, p, an, n = indices_tuplealpha = torch.nn.functional.softplus(alpha)loss = 0for i in range(n_samples):z_i = z_samples[:, i, :]for j in range(n_samples):z_j = z_samples[:, j, :]prob_pos = torch.sigmoid(- alpha * torch.sum((z_i[ap] - z_j[p])**2, dim=1) + beta) + 1e-6prob_neg = torch.sigmoid(- alpha * torch.sum((z_i[an] - z_j[n])**2, dim=1) + beta) + 1e-6# maximize the probability of positive pairs and minimize the probability of negative pairsloss += -torch.log(prob_pos) - torch.log(1 - prob_neg)loss = loss / (n_samples ** 2)return loss.mean()def get_matches_and_diffs(labels):matches = (labels.float() @ labels.float().T).byte()diffs = matches ^ 1 # 异或运算得到负标签的矩阵return matches, diffsdef get_all_triplets_indices_vectorized_method(all_matches, all_diffs):"""Args:all_matches (torch.Tensor): 相同标签all_diffs (torch.Tensor): 不相同标签Processing : all_matches.unsqueeze(2) -> [Batch,Batch,1]all_diffs.unsqeeeze(1) -> [Batch,1,Batch] Returns:torch.Tensor: _description_"""triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)return torch.where(triplets)class TripletMinner(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.sim_mat = get_matches_and_diffsself.selctor = get_all_triplets_indices_vectorized_methoddef forward(self,labels):a , b = self.sim_mat(labels)c = self.selctor(a,b)return c