多模态text-image模型之ITM loss(blip)

主要代码:

# forward the positve image-text pair
# 正向传播正面的图像文本对
output_pos = self.text_encoder.bert(encoder_embeds=text_embeds, attention_mask=text.attention_mask,encoder_hidden_states=image_embeds,encoder_attention_mask=image_atts,      return_dict=True,mode='fusion',)            
with torch.no_grad():bs = image.size(0)  # 获取批量大小          weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)  # 对image到text的相似度进行softmax,沿着第二个维度计算weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)  # 对text到image的相似度进行softmax,沿着第二个维度计算weights_i2t.fill_diagonal_(0)  # 将权重矩阵的对角线设为0weights_t2i.fill_diagonal_(0)  # 将权重矩阵的对角线设为0# select a negative image for each text
# 为每个文本选择一个负面的图像
image_embeds_neg = []    
for b in range(bs):neg_idx = torch.multinomial(weights_t2i[b], 1).item()  # 根据权重选择负面图像的索引image_embeds_neg.append(image_embeds[neg_idx])  # 添加负面图像到列表
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)  # 将负面图像张量堆叠起来# select a negative text for each image
# 为每张图像选择一个负面的文本
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):neg_idx = torch.multinomial(weights_i2t[b], 1).item()  # 根据权重选择负面文本的索引text_embeds_neg.append(text_embeds[neg_idx])  # 添加负面文本到列表text_atts_neg.append(text.attention_mask[neg_idx])  # 添加负面文本的注意力掩码到列表
text_embeds_neg = torch.stack(text_embeds_neg, dim=0)  # 将负面文本张量堆叠起来
text_atts_neg = torch.stack(text_atts_neg, dim=0)  # 将负面文本的注意力掩码张量堆叠起来text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)  # 拼接所有的文本张量
text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)  # 拼接所有的文本的注意力掩码张量image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)  # 拼接所有的图像张量
image_atts_all = torch.cat([image_atts, image_atts], dim=0)  # 拼接所有的图像的注意力掩码张量output_neg = self.text_encoder.bert(encoder_embeds=text_embeds_all, attention_mask=text_atts_all,encoder_hidden_states=image_embeds_all,encoder_attention_mask=image_atts_all,      return_dict=True,mode='fusion',)                         vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)  # 拼接正负样本的嵌入表示
vl_output = self.itm_head(vl_embeddings)  # 输入到信息论训练头部            itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],  # 创建信息论训练标签dim=0).to(image.device)  # 将标签转移到相同的设备上
loss_itm = F.cross_entropy(vl_output, itm_labels)  # 计算信息论训练损失     

参考:多模态text-image模型之ITM loss-CSDN博客

求Loss的代码:

loss_itm = F.cross_entropy(vl_output, itm_labels)

 

  1. vl_output 是模型输出的分类得分,itm_labels 是每个样本的真实标签。

  2. vl_output:模型输出的是经过训练头部(self.itm_head)的得分,这个头部是一个全连接层,用于将模型学到的特征映射到正面和负面类别的得分。

  3. itm_labels:模型对应的标签,包含了每个样本的真实标签。torch.ones(bs, dtype=torch.long) 是正面样本的标签,设为 1,torch.zeros(2 * bs, dtype=torch.long) 是负面样本的标签,设为 0。然后,使用 torch.cat 函数将这些标签连接起来,形成一个完整的标签张量。

  4. loss_itm:通过调用 F.cross_entropy 函数计算模型输出和真实标签之间的交叉熵损失。这个损失反映了模型预测和实际标签之间的差异,用于指导模型参数的更新,以便更好地区分正面和负面样本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/798393.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【技术笔记】Ubuntu下VirtualBox不能识别USB解决办法(手把手解决)

环境说明 系统版本:Ubuntu 20.04 VirtualBox版本: 7.0.12 解决过程 扩展下载,进入VirtualBox 官方下载路径。选择本机安装版本,如下图所示,因笔者是7.0.x版本,因此点击第一条链接; 进入版本页…

机器学习(30)

文章目录 摘要一、文献阅读1. 题目2. abstract3. 网络架构3.1 Sequence Generative Adversarial Nets3.2 SeqGAN via Policy Gradient3.3 The Generative Model for Sequences3.4 The Discriminative Model for Sequences(CNN) 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过…

慢SQL问题排查

慢SQL问题排查是一个系统性的过程,它涉及到对数据库性能、查询优化以及系统资源的深入理解。 1. 收集慢查询日志 启用慢查询日志:大多数数据库系统(如MySQL、PostgreSQL等)都支持慢查询日志功能。启用该功能后,数据库…

关于npm和yarn的使用(自己的问题记录)

目录 一 npm 和 yarn 常用命令 二 package.json中 devDependencies 和 dependencies 的区别。 三 npm安装包时,加 --save和不加的区别 一 npm 和 yarn 常用命令 备注:以下命令以 axios 为例。 未完:待续。。。。 二 ​​​​​​​ …

【Vue3源码学习】— CH2.7 Computed: Vue 3 计算属性深入解析

Computed: Vue 3 计算属性深入解析 1.计算属性的基本用法2. ComputedRefImpl 类深入解析JavaScript 中的 getter 函数 3. 计算属性的创建:computed 方法解析3.1 源码解析3.2 使用示例 4. 计算属性的工作原理5. 手动实现简化的计算属性6. 结语 在 Vue 3 的响应式系统…

【教程】VOC数据集制作

语义分割任务中VOC数据集的制作,任务中只有一种标签:gas 文章目录 1、由黑白图像识别为txt标签2、txt转json3、数据集转VOC格式 1、由黑白图像识别为txt标签 由于使用CycleGAN网络进行风格迁移学习,生成了大量伪标签图像,因此需…

【递归与递推】数的计算|数的划分|耐摔指数

1.数的计算 - 蓝桥云课 (lanqiao.cn) 思路: 1.dfs的变量>每一次递归什么在变? (1)当前数的大小一直在变:sum (2)最高位的数:k 2.递归出口:最高位数字为1 3.注意&#…

鱼塘钓鱼(c++实现)

题目 有 N 个鱼塘排成一排,每个鱼塘中有一定数量的鱼,例如:N5 时,如下表: 即:在第 1 个鱼塘中钓鱼第 1 分钟内可钓到 10 条鱼,第 2 分钟内只能钓到 8 条鱼,……,第 5 分…

Codeforces Round 932 (Div. 2) ---- F. Andrey‘s Tree ---- 题解

F. Andreys Tree: 题目描述: 思路解析: 我们假设删除任意一个结点后,我们会将整个树切分为k个联通块,那么可以明确的知道我们只需要连接(k-1)条边就可以将这k个联通块重新连为一棵树。 那么最小代价是啥呢? 图解分…

0基础进入IT行业

0基础如何进入IT行业? 简介:对于没有任何相关背景知识的人来说,如何才能成功进入IT行业?是否有一些特定的方法或技巧可以帮助他们实现这一目标? 方向一:学习路径 对于零基础进入 IT 行业的人来说&#xff…

第十五题:最大距离

题目描述 在数列 a1,a2,⋯ ,an​中,定义两个元素 ai 和 aj​ 的距离为∣i−j∣∣ai−aj∣,即元素下标的距离加上元素值的差的绝对值,其中 ∣x∣ 表示 x 的绝对值。 给定一个数列,请问找出元素之间最大的元素距离。 输入描述 …

【网站项目】校园订餐小程序

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

鱼骨图功能实现

dom: <div class="module-content"><div class="title"><span>[</span><p>鱼骨图</p><span>]</span></div><div class="line-mian"></div><div :ref="module + i&q…

通过UDP实现参数配置

来讲讲UDP的一种常见应用 我们知道UDP是一种无连接的网络传输协议&#xff0c;在发送数据时指定目标IP及端口就可以将数据发送出去&#xff0c;因此特别适合用作网络设备发现。 我们可以自定义一个通信端口&#xff0c;假设为55555。我们再制定一个协议用于查询目标设备&#x…

2024-04-07 作业

作业要求&#xff1a; 1> 思维导图 2> 自由发挥应用场景实现一个登录窗口界面。 【可以是QQ登录界面、也可以是自己发挥的登录界面】 要求&#xff1a;尽量每行代码都有注释 作业1&#xff1a; 作业2&#xff1a; 运行代码&#xff1a; #include "myqwidget.h&quo…

hatch,现代化的 Python 项目管理和打包工具!

目录 前言 安装 特性 基本功能 项目创建 示例代码 虚拟环境管理 依赖管理 测试 打包和发布 高级功能 插件系统 配置环境管理 自定义构建选项 集成测试工具 实际应用场景 多环境管理 持续集成与持续部署&#xff08;CI/CD&#xff09; 项目原型化 依赖与包管理 总结 前言…

Q1剧集市场复盘:2024爱优腾谁在领跑国产剧市场?

2024年Q1剧集市场的成绩单出炉了。 复盘2024年第一季度剧集市场&#xff0c;可以用“生机勃勃”四个字来形容&#xff0c;虽然和去年相比&#xff0c;今年的第一季度缺少了《狂飙》这样的头部大爆款&#xff0c;但市场大盘走势向好。 根据灯塔专业版统计&#xff0c;2024Q1剧…

4.文件上传下载

一、配置文件 Spring#上传文件使用servlet:multipart:#单个文件最大上传大小max-file-size: 10MB#每次请求上传文件大小最大值max-request-size: 30MB #自定义参数 define:nginx:path: D:\uploadFile\ 二、service层 public interface FileService {void saveFile(byte[] f…

nginx配置实例-反向代理

目录 一、目标-反向代理实现效果 二、安装tomcat 三、配置nginx服务 四、配置反向代理 一、目标-反向代理实现效果 访问过程分析&#xff1a; 二、安装tomcat 1、安装jdk环境 新建/export/server目录 解压jdk 查看是否解压成功 配置jdk软连接 进入jdk的bin目录中&#x…

echart 折线图或散点图当横坐标为小数位时,若想显示整数该如何处理?

如图当前是这样的&#xff1a; 横坐标刻度目前是小数位&#xff0c;如果直接将小数位取整则会失去精度&#xff0c;所以我们要做的是刻度即是整数&#xff0c;又能显示小数位对应的数值&#xff1b; 思路就是直接手动设置刻度&#xff1a;设置xAxis的min,max,splitNumber,同时不…