主要代码:
# forward the positve image-text pair
# 正向传播正面的图像文本对
output_pos = self.text_encoder.bert(encoder_embeds=text_embeds, attention_mask=text.attention_mask,encoder_hidden_states=image_embeds,encoder_attention_mask=image_atts, return_dict=True,mode='fusion',)
with torch.no_grad():bs = image.size(0) # 获取批量大小 weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) # 对image到text的相似度进行softmax,沿着第二个维度计算weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) # 对text到image的相似度进行softmax,沿着第二个维度计算weights_i2t.fill_diagonal_(0) # 将权重矩阵的对角线设为0weights_t2i.fill_diagonal_(0) # 将权重矩阵的对角线设为0# select a negative image for each text
# 为每个文本选择一个负面的图像
image_embeds_neg = []
for b in range(bs):neg_idx = torch.multinomial(weights_t2i[b], 1).item() # 根据权重选择负面图像的索引image_embeds_neg.append(image_embeds[neg_idx]) # 添加负面图像到列表
image_embeds_neg = torch.stack(image_embeds_neg, dim=0) # 将负面图像张量堆叠起来# select a negative text for each image
# 为每张图像选择一个负面的文本
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):neg_idx = torch.multinomial(weights_i2t[b], 1).item() # 根据权重选择负面文本的索引text_embeds_neg.append(text_embeds[neg_idx]) # 添加负面文本到列表text_atts_neg.append(text.attention_mask[neg_idx]) # 添加负面文本的注意力掩码到列表
text_embeds_neg = torch.stack(text_embeds_neg, dim=0) # 将负面文本张量堆叠起来
text_atts_neg = torch.stack(text_atts_neg, dim=0) # 将负面文本的注意力掩码张量堆叠起来text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) # 拼接所有的文本张量
text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) # 拼接所有的文本的注意力掩码张量image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) # 拼接所有的图像张量
image_atts_all = torch.cat([image_atts, image_atts], dim=0) # 拼接所有的图像的注意力掩码张量output_neg = self.text_encoder.bert(encoder_embeds=text_embeds_all, attention_mask=text_atts_all,encoder_hidden_states=image_embeds_all,encoder_attention_mask=image_atts_all, return_dict=True,mode='fusion',) vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0) # 拼接正负样本的嵌入表示
vl_output = self.itm_head(vl_embeddings) # 输入到信息论训练头部 itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], # 创建信息论训练标签dim=0).to(image.device) # 将标签转移到相同的设备上
loss_itm = F.cross_entropy(vl_output, itm_labels) # 计算信息论训练损失
参考:多模态text-image模型之ITM loss-CSDN博客
求Loss的代码:
loss_itm = F.cross_entropy(vl_output, itm_labels)
-
vl_output
是模型输出的分类得分,itm_labels
是每个样本的真实标签。 -
vl_output
:模型输出的是经过训练头部(self.itm_head
)的得分,这个头部是一个全连接层,用于将模型学到的特征映射到正面和负面类别的得分。 -
itm_labels
:模型对应的标签,包含了每个样本的真实标签。torch.ones(bs, dtype=torch.long)
是正面样本的标签,设为 1,torch.zeros(2 * bs, dtype=torch.long)
是负面样本的标签,设为 0。然后,使用torch.cat
函数将这些标签连接起来,形成一个完整的标签张量。 -
loss_itm
:通过调用F.cross_entropy
函数计算模型输出和真实标签之间的交叉熵损失。这个损失反映了模型预测和实际标签之间的差异,用于指导模型参数的更新,以便更好地区分正面和负面样本。