为每个用户抽取一定数量的困难负样本,然后计算 SSM(sampled softmax)损失
def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):
    """Sampled-softmax loss whose negatives are mined from the full item set.

    For every user in the batch, all items are scored by cosine similarity,
    items flagged in ``self.data.ui_adj`` are excluded, and a rank window of
    the highest-scoring remaining items is used as negatives.

    Args:
        user_idx: Indices of the batch users (rows of ``rec_user_emb``).
        rec_user_emb: User embedding table, indexed along dim 0.
        pos_idx: Index of the positive item of each batch user.
        item_emb: Item embedding table, indexed along dim 0.

    Returns:
        Scalar tensor: mean of ``-log(pos / (pos + sum(neg) + 1e-5))``.

    Reads ``self.temp2`` (softmax temperature), ``self.data.ui_adj`` and
    ``self.data.user_num``.
    """
    # Normalise once and reuse, so the positive score and the negative scores
    # are all cosine similarities (the positive used to be a raw dot product
    # compared against cosine negatives).
    user_emb = F.normalize(rec_user_emb[user_idx], dim=1)
    norm_item_emb = F.normalize(item_emb, dim=1)
    product_scores = torch.matmul(user_emb, norm_item_emb.transpose(0, 1))

    pos_score = (user_emb * norm_item_emb[pos_idx]).sum(dim=-1)
    pos_score = torch.exp(pos_score / self.temp2)

    # presumably ui_adj is a SciPy sparse (user+item) x (user+item) adjacency
    # whose item columns start at user_num -- verify against the data loader.
    train_mask = self.data.ui_adj[user_idx, self.data.user_num:].toarray()
    train_mask = torch.as_tensor(train_mask, device=product_scores.device) != 0
    # Cosine scores can be negative, so zeroing a score does not push it out
    # of the ranking; -inf does, and exp(-inf) == 0 keeps it out of the sum.
    product_scores = product_scores.masked_fill(train_mask, float('-inf'))

    # Original window: ranks 400..499 of the top 500. Clamp it so catalogues
    # with fewer than 500 items keep the lowest (up to) 100 ranks instead of
    # raising inside topk.
    top_k = min(500, product_scores.shape[1])
    skip = max(top_k - 100, 0)
    neg_score, _ = product_scores.topk(top_k, dim=1, largest=True, sorted=True)
    neg_score = torch.exp(neg_score[:, skip:] / self.temp2).sum(dim=-1)

    loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))
    return torch.mean(loss)
def batch_softmax_loss_neg(user_emb, pos_item_emb, neg_item_emb, temperature):
    """Sampled-softmax loss over one positive and ``m`` negatives per user.

    Args:
        user_emb: User embeddings, shape (n, d).
        pos_item_emb: Positive item embeddings, shape (n, d).
        neg_item_emb: Negative item embeddings, shape (n, m, d).
        temperature: Softmax temperature dividing the cosine scores.

    Returns:
        Scalar tensor: mean of ``-log(pos / (pos + sum(neg) + 1e-5))``.
    """
    # Normalise along the embedding axis. For the 3-D negatives that axis is
    # the last one; dim=1 would normalise across the m candidates instead and
    # the scores would no longer be cosine similarities.
    user_emb = F.normalize(user_emb, dim=-1)
    pos_item_emb = F.normalize(pos_item_emb, dim=-1)
    neg_item_emb = F.normalize(neg_item_emb, dim=-1)

    pos_score = (user_emb * pos_item_emb).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)

    # (n, 1, d) * (n, m, d) -> (n, m): one score per user/negative pair.
    neg_score = (user_emb.unsqueeze(1) * neg_item_emb).sum(dim=-1)
    neg_score = torch.exp(neg_score / temperature).sum(dim=-1)

    loss = -torch.log(pos_score / (pos_score + neg_score + 10e-6))
    return torch.mean(loss)
均匀性损失(错误案例)
# def cal_uniform_loss(user_emb, item_emb):
# user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
# distance = user_emb - item_emb # n*d
# gaussian_potential = torch.exp(-2 * torch.norm(distance,p=2,dim=1))
# E_gaussian_potential = gaussian_potential.mean()
# return torch.log(E_gaussian_potential)
DNS(Dynamic Negative Sampling,动态负采样)
def DNSbpr(user_emb, pos_item_emb, neg_item_emb):
    """BPR loss against the highest-scoring negative candidate of each user.

    Args:
        user_emb: User embeddings, shape (n, d).
        pos_item_emb: Positive item embeddings, shape (n, d).
        neg_item_emb: Candidate negative embeddings, shape (n, m, d).

    Returns:
        Scalar tensor: mean of ``-log(1e-5 + sigmoid(pos - max_neg))``.
    """
    positive = (user_emb * pos_item_emb).sum(dim=1)

    # Broadcast each user against its own m candidates: (n, 1, d) * (n, m, d).
    candidate_scores = (user_emb.unsqueeze(1) * neg_item_emb).sum(dim=-1)
    # Only the candidate with the largest dot product is kept as the negative.
    hardest, _ = candidate_scores.max(dim=1)

    per_user = -torch.log(10e-6 + torch.sigmoid(positive - hardest))
    return per_user.mean()
带 margin 的 InfoNCE
def InfoNCE_margin(view1, view2, temperature, margin, b_cos = True):
    """InfoNCE loss with an additive margin on the positive pair's logit.

    The margin is added to the diagonal of the similarity matrix that forms
    the denominator; the numerator uses the plain positive score.

    Args:
        view1: First view, shape (n, d).
        view2: Second view, shape (n, d); row i is the positive of row i.
        temperature: Softmax temperature.
        margin: Value added to the diagonal (positive-pair) similarities.
        b_cos: If true, L2-normalise both views so scores are cosines.

    Returns:
        Scalar tensor: mean of ``-log(pos / total + 1e-5)``.
    """
    if b_cos:
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = (view1 * view2).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)
    # Build the margin matrix next to the inputs rather than forcing it onto
    # cuda:0, so the loss also runs on CPU and on other GPUs.
    margin = margin * torch.eye(view1.shape[0], device=view1.device, dtype=view1.dtype)
    ttl_score = torch.matmul(view1, view2.transpose(0, 1))
    ttl_score = ttl_score + margin
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = -torch.log(pos_score / ttl_score+10e-6)
    return torch.mean(cl_loss)


def InfoNCE_tau(view1, view2, temperature):
    """InfoNCE loss on cosine similarities with in-batch negatives.

    Args:
        view1: First view, shape (n, d).
        view2: Second view, shape (n, d); row i is the positive of row i.
        temperature: Softmax temperature.

    Returns:
        Scalar tensor: mean of ``-log(pos / total + 1e-5)``.
    """
    view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = (view1 * view2).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)
    ttl_score = torch.matmul(view1, view2.transpose(0, 1))
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = -torch.log(pos_score / ttl_score+10e-6)
    return torch.mean(cl_loss)


def batch_bpr_loss(user_emb, item_emb):
    """BPR loss whose negative score is the mean score over the batch items.

    Args:
        user_emb: User embeddings, shape (n, d).
        item_emb: Item embeddings, shape (n, d); row i is user i's positive.

    Returns:
        Scalar tensor: mean of ``-log(1e-5 + sigmoid(pos - mean_batch_score))``.
    """
    pos_score = torch.mul(user_emb, item_emb).sum(dim=1)
    # The row mean runs over every item of the batch, the positive included.
    neg_score = torch.matmul(user_emb, item_emb.transpose(0, 1)).mean(dim=1)
    loss = -torch.log(10e-6 + torch.sigmoid(pos_score - neg_score))
    return torch.mean(loss)


def Dis_softmax(view1, view2, temperature, b_cos = True):
    """Softmax-style loss built on Euclidean distances instead of similarities.

    Args:
        view1: First view, shape (n, d).
        view2: Second view, shape (n, d); row i is paired with row i.
        temperature: Softmax temperature.
        b_cos: If true, L2-normalise both views before taking distances.

    Returns:
        Scalar tensor: mean of ``log(pos / total + 1e-5)``. There is no
        leading minus sign, so minimising it lowers the paired distance
        relative to the distances to all rows of ``view2``.
    """
    if b_cos:
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = (view1 - view2).norm(p=2, dim=1)
    pos_score = torch.exp(pos_score / temperature)
    # (n, 1, d) - (1, n, d) -> all pairwise differences, reduced to (n, n).
    ttl_score = (view1.unsqueeze(1) - view2.unsqueeze(0)).norm(p=2, dim=-1)
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = torch.log(pos_score / ttl_score+10e-6)
    return torch.mean(cl_loss)
LightGCN+对比学习
def forward(self, perturbed=False):
    """Propagate embeddings over the graph and build a contrastive view.

    Args:
        perturbed: If true, add sign-aligned random noise of length
            ``self.eps`` to every layer's output and also return the
            contrastive-view embeddings.

    Returns:
        ``(user_emb, item_emb)`` when ``perturbed`` is false, otherwise
        ``(user_emb, item_emb, user_emb_cl, item_emb_cl)``. The first pair is
        the mean over the ``self.n_layers`` propagated layers; the ``_cl``
        pair is the input embedding shifted by ``self.eps`` along the
        normalised difference between the second and first layer outputs.

    Raises:
        IndexError: If ``self.layer_cl`` is 1, because the second layer
            output does not exist yet when the contrastive view is built.
    """
    ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)
    all_embeddings = []
    all_embeddings_cl = ego_embeddings
    for k in range(self.n_layers):
        ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)
        if perturbed:
            # rand_like already allocates on ego_embeddings' device; forcing
            # .cuda() broke CPU runs and models living on a non-default GPU.
            random_noise = torch.rand_like(ego_embeddings)
            # Out-of-place add: same values, without mutating a tensor that
            # is part of the autograd graph.
            ego_embeddings = ego_embeddings + torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.eps
        all_embeddings.append(ego_embeddings)
        # NOTE(review): the original indentation was lost; this branch is
        # assumed to run on every call, not only when perturbed is true.
        if k == self.layer_cl - 1:
            layer_shift = F.normalize(all_embeddings[1] - all_embeddings[0], dim=-1)
            all_embeddings_cl = all_embeddings_cl + layer_shift * self.eps
    final_embeddings = torch.stack(all_embeddings, dim=1)
    final_embeddings = torch.mean(final_embeddings, dim=1)
    user_all_embeddings, item_all_embeddings = torch.split(final_embeddings, [self.data.user_num, self.data.item_num])
    user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embeddings_cl, [self.data.user_num, self.data.item_num])
    if perturbed:
        return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
    return user_all_embeddings, item_all_embeddings
def train(self):
    """Run the full training loop and keep the best embeddings.

    Per epoch: build a fresh ``BetaMixture1D``, fit it through
    ``self.bmm_fit``, then for every batch from ``next_batch_pairwise``
    minimise ``weighted_SSM`` + L2 regularisation + the contrastive loss.
    After the batches, the embeddings are recomputed without perturbation,
    the hot/norm/cold diagnostics are evaluated and ``self.fast_evaluation``
    is called.

    NOTE(review): the original indentation of this method was lost; the
    nesting below (what sits inside the batch loop, the ``no_grad`` block and
    the epoch loop) is reconstructed from statement order -- confirm against
    the real file.
    """
    model = self.model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)
    # Fixed groups of 500 user/item indices per popularity mode, used only by
    # the per-epoch diagnostics further down.
    hot_uidx, hot_iidx = self.select_ui_idx(500, mode='hot')
    cold_uidx, cold_iidx = self.select_ui_idx(500, mode='cold')
    norm_uidx, norm_iidx = self.select_ui_idx(500, mode='norm')
    # Constructor arguments of the two-component beta mixture; `iters` is
    # presumably its number of fitting iterations -- TODO confirm.
    iters = 10
    alphas_init = torch.tensor([1, 2], dtype=torch.float64).to(device)
    betas_init = torch.tensor([2, 1], dtype=torch.float64).to(device)
    weights_init = torch.tensor([1 - 0.05, 0.05], dtype=torch.float64).to(device)
    for epoch in range(self.maxEpoch):
        # epoch_rec_loss = []
        # A new mixture model is created and fitted at the start of every
        # epoch, using all users and 100 randomly drawn item indices.
        bmm_model = BetaMixture1D(iters, alphas_init, betas_init, weights_init)
        rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)
        self.bmm_fit(rec_user_emb, rec_item_emb,torch.arange(self.data.user_num),np.random.randint(0,self.data.item_num, 100),bmm_model)
        for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):
            # rec_neg_idx is unpacked but not used by the active code below.
            user_idx, pos_idx, rec_neg_idx = batch
            rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)
            user_emb, pos_item_emb= rec_user_emb[user_idx], rec_item_emb[pos_idx]
            # rec_loss = self.batch_softmax_loss_neg(user_idx, rec_user_emb, pos_idx, rec_item_emb)
            # rec_neg_idx = torch.tensor(rec_neg_idx,dtype=torch.int64)
            # rec_neg_item_emb = rec_item_emb[rec_neg_idx]
            # Per-pair weights derived from the fitted mixture model feed the
            # weighted recommendation loss.
            weight = self.getWeightSim(user_emb, pos_item_emb, bmm_model)
            rec_loss = weighted_SSM(user_emb,pos_item_emb,self.temp2,weight)
            cl_loss = self.cl_rate * self.cal_cl_loss([user_idx,pos_idx],rec_user_emb,cl_user_emb,rec_item_emb,cl_item_emb)
            batch_loss = rec_loss + l2_reg_loss(self.reg, user_emb, pos_item_emb) + cl_loss
            # epoch_rec_loss.append(rec_loss.item()), epoch_cl_loss.append(cl_loss.item())
            # Backward and optimize
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Log every 100th batch, skipping batch 0.
            if n % 100==0 and n>0:
                print('training:', epoch + 1, 'batch', n, 'rec_loss:', rec_loss.item(), 'cl_loss', cl_loss.item())
        with torch.no_grad():
            # Unperturbed forward pass: only the two recommendation
            # embeddings are returned.
            self.user_emb, self.item_emb = self.model()
            hot_emb = torch.cat([self.user_emb[hot_uidx],self.item_emb[hot_iidx]],0)
            cold_emb = torch.cat([self.user_emb[cold_uidx],self.item_emb[cold_iidx]],0)
            self.eval_uniform(epoch, hot_emb, cold_emb)
            # The 'norm' calls are made for their side effects only; their
            # return values are discarded.
            hot_user_mag = self.cal_sim(epoch, hot_uidx, self.user_emb, self.item_emb,mode='hot')
            self.cal_sim(epoch, norm_uidx, self.user_emb, self.item_emb, mode='norm')
            cold_user_mag= self.cal_sim(epoch, cold_uidx, self.user_emb, self.item_emb, mode='cold')
            hot_item_mag = self.item_magnitude(epoch, hot_iidx, self.item_emb,mode='hot')
            self.item_magnitude(epoch, norm_iidx, self.item_emb, mode='norm')
            cold_item_mag = self.item_magnitude(epoch, cold_iidx, self.item_emb, mode='cold')
            print('training:',epoch + 1,'U_mag_ratio:',hot_user_mag/cold_user_mag, 'I_mag_ratio:',hot_item_mag/cold_item_mag)
            # self.getTopSimNeg(hot_uidx, self.user_emb,self.item_emb, 100)
            # self.getTopSimNeg(norm_uidx,self.user_emb,self.item_emb, 100)
            # self.getTopSimNeg(cold_uidx,self.user_emb,self.item_emb, 100)
            # epoch_rec_loss = np.array(epoch_rec_loss).mean()
            # self.loss.extend([epoch_rec_loss,epoch_cl_loss,hot_pair_uniform_loss.item(),random_item_uniform_loss.item()])
            # if epoch%5==0:
            # self.save_emb(epoch, hot_emb, mode='hot')
            # self.save_emb(epoch, random_emb, mode='random')
        self.fast_evaluation(epoch)
    # self.save_loss()
    # After the last epoch, swap in the best embeddings; presumably
    # fast_evaluation maintains best_user_emb / best_item_emb -- verify.
    self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb
    # self.save_emb(self.bestPerformance[0], hot_emb, mode='best_hot')
    # self.save_emb(self.bestPerformance[0], random_emb, mode='best_random')
hard_neg buffer
def getHardNeg(self, user_idx, pos_idx, rec_user_emb, rec_item_emb, temperature):
    """In-batch InfoNCE loss that also records each user's top-scoring items.

    Every user in the batch is scored (cosine similarity) against the
    positive items of all users in the batch. The ids of the highest-scoring
    items are written to ``self.hardNeg`` for the batch users.

    Args:
        user_idx: Indices of the batch users.
        pos_idx: Positive item index of each batch user.
        rec_user_emb: User embedding table, indexed along dim 0.
        rec_item_emb: Item embedding table, indexed along dim 0.
        temperature: Softmax temperature.

    Returns:
        Scalar tensor: mean of ``-log(pos / total + 1e-5)``.

    Side effects:
        Overwrites the first ``k`` columns of ``self.hardNeg[user_idx]``,
        where ``k = min(100, batch size)``. This assumes ``self.hardNeg`` is
        a 2-D integer buffer indexed by user -- TODO confirm.
    """
    u_emb = F.normalize(rec_user_emb[user_idx], dim=1)
    i_emb = F.normalize(rec_item_emb[pos_idx], dim=1)
    pos_score = (u_emb * i_emb).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)

    # One (B, B) matmul gives the same scores as broadcasting a (B, B, d)
    # product and summing, without materialising the 3-D tensor.
    ttl_score = torch.matmul(u_emb, i_emb.transpose(0, 1))

    # A batch with fewer than 100 rows (e.g. the final partial batch) has
    # fewer than 100 candidates; clamp k so topk does not raise.
    k = min(100, ttl_score.size(1))
    # NOTE(review): the ranking still includes each user's own positive
    # (the diagonal), so it can end up stored as a "hard negative".
    indices = torch.topk(ttl_score, k=k)[1].detach()

    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    rec_loss = -torch.log(pos_score / ttl_score + 10e-6)

    # Map in-batch column positions back to item ids. The ids live on the
    # same device as the ranking instead of relying on a global `device`.
    # A disabled variant (commented out in the original) also appended the
    # previously buffered negatives to the candidate pool.
    candidate_ids = torch.as_tensor(pos_idx, dtype=torch.long, device=indices.device)
    chosen_hardNeg = candidate_ids[indices]
    self.hardNeg[user_idx, :k] = chosen_hardNeg
    return torch.mean(rec_loss)