多模态(二)--- CoCa原理与源码解读

1 CoCa简介

CoCa代表Contrastive Captioner 的缩写,代表模型用两个目标函数训练出来的,一个是Contrastive Loss,一个是Captioning Loss。

2 CoCa训练流程

  1. 利用ViT对image进行encoder编码获得图像特征token
  2. 对图像特征进行attention pooling(multihead attention), 取第0位作为计算对比损失的cls-token,后255位作为计算生成损失的视觉token
  3. 对text进行embedding编码,在文本token末尾嵌入cls_token
  4. 生成相应的单词遮挡掩膜mask,给text-token加上位置编码
  5. 将text-token和mask-atten送入transformer学习获得文本cls_token(text_latent), 和其余单词token(token_emb)
    在这里插入图片描述

2.1 image encoder

    def _encode_image(self, images, normalize: bool = True):image_latent, tokens_embs = self.visual(images)image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent# image_latent:constractive_token, tokens_embs: caption_tokenreturn image_latent, tokens_embs#### self.visual(images):def forward(self, x: torch.Tensor):# [b, 3, 224, 224]--->[b, 1024, 16, 16]x = self.conv1(x)# [b, 1024, 16, 16]--->[b, 1024, 256]x = x.reshape(x.shape[0], x.shape[1], -1)# [b, 1024, 256]--->[b, 256, 1024]x = x.permute(0, 2, 1)# 在序列长度上给图像嵌入一个类别,x:[b, 256 + 1, 1024]x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)# 嵌入位置编码,x:[b, 256 + 1, 1024]x = x + self.positional_embedding.to(x.dtype)# patch_dropout, x:[b, 256 + 1, 1024]x = self.patch_dropout(x)# LayerNorm处理 x:[b, 256 + 1, 1024]x = self.ln_pre(x)# NLD -> LND [b, 256 + 1, 1024]---> [256 + 1, b, 1024]x = x.permute(1, 0, 2)# transformer网络处理x = self.transformer(x)# LND -> NLD  [256 + 1, b, 1024]--->[b, 256 + 1, 1024]x = x.permute(1, 0, 2)if self.attn_pool is not None:# this is the original OpenCLIP CoCa setup, does not match paper# x:[b, 257, 1024]--->[b, 256, 768]x = self.attn_pool(x)# ln归一化, [b, 256, 768]x = self.ln_post(x)# pooled: 类别token:[b, 768] tokens:图像token:[b, 255, 768]pooled, tokens = self._global_pool(x)# pooled: [b, 768]@[768, 768]--->[b, 768]if self.proj is not None:pooled = pooled @ self.proj# 同时返回cls-token和视觉tokenif self.output_tokens:return pooled, tokensreturn pooled
# self.attn_pool(x)
class AttentionalPooler(nn.Module):def __init__(self,d_model: int,context_dim: int,n_head: int = 8,n_queries: int = 256,norm_layer: Callable = LayerNorm):super().__init__()self.query = nn.Parameter(torch.randn(n_queries, d_model))self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)self.ln_q = norm_layer(d_model)self.ln_k = norm_layer(context_dim)def forward(self, x: torch.Tensor):# ln归一化,NLD -> LND [b, 257, 1024]--->[257, b, 1024]x = self.ln_k(x).permute(1, 0, 2)N = x.shape[1]# q: [256, 768]q = self.ln_q(self.query)# q: [256, 768]--->[256, 1, 768]--->[256,b, 768], k=v=x, x:[257, b, 1024]# out: [256, b, 768], MultiheadAttentionout = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]# out: [256, b, 768]--->[b, 256, 768]return out.permute(1, 0, 2)  # LND -> NLD

2.2 Unimodal text decoder

    def _encode_text(self, text, normalize: bool = True):# text_latent:[b, 768], token_emb:[b, 76, 768]text_latent, token_emb = self.text(text)text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latentreturn text_latent, token_embdef forward(self, text):cast_dtype = self.transformer.get_cast_dtype()seq_len = text.shape[1]# x:[b, 76, 768], 将text:[b, 76]进行embeding, F.embedding(text, weight=[40408, 768])49408---一共49408个单词,每个单词维度768x = self.token_embedding(text).to(cast_dtype)attn_mask = self.attn_maskif self.cls_emb is not None:seq_len += 1# 在文本token末尾嵌入cls_token, x:[b, 76, 768] ---> [b, 76+1, 768]x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)# cls_mask: [12b, 77, 77], text:[b, 76]cls_mask = self.build_cls_mask(text, cast_dtype)# 将单词有序遮挡mask, attn_mask: [[0, -inf, -inf,...-inf], [0, 0, -inf, ..., -inf],...[0, 0, 0,...,0,-inf], [0, 0, 0, ...,0]]if attn_mask is not None:# attn_mask: [1,77, 77] + cls_mask[12b, 77, 77] ===> 获得最终的attn_mask: [12b, 77, 77], 有单词的位置为0, 被遮挡以及没单词的位置为-infattn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]# 加上位置编码, x: [b, 77, 768]x = x + self.positional_embedding[:seq_len].to(cast_dtype)# x: [b, 77, 768]--->[77, b, 768]x = x.permute(1, 0, 2)  # NLD -> LND# 进入transformer学习, x:[77, b, 768]x = self.transformer(x, attn_mask=attn_mask)# x: [77, b, 768]--->[b, 77, 768]x = x.permute(1, 0, 2)  # LND -> NLD# x.shape = [batch_size, n_ctx, transformer.width]if self.cls_emb is not None:# presence of appended cls embed (CoCa) overrides pool_type, always take last token# pooled: cls_token:[b, 768] tokens:图像token:[b, 76, 768]pooled, tokens = text_global_pool(x, pool_type='last')# layernormpooled = self.ln_final(pooled)  # final LN applied after pooling in this case# [b, 768] @ 【768, 768】---> [b, 768]pooled = pooled @ self.text_projection# pooled:[b, 768], tokens:[b, 76, 768]if self.output_tokens:return pooled, tokensreturn pooled
    def build_cls_mask(self, text, cast_dtype: torch.dtype):# 找到text中存在单词的cls_mask,值设为True,text:[b, 76], cls_mask: [b, 1, 76]cls_mask = (text != self.pad_id).unsqueeze(1)# cls_mask: [b, 1, 76]--->[b, 77, 77]cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)# 随机一个[b, 77, 77]的maskadditive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)# 全部填充为0 additive_mask:[b, 77, 77]additive_mask.fill_(0)# 不满77长度的单词中,0填充的位置换为-infadditive_mask.masked_fill_(~cls_mask, float("-inf"))# 将additive_mask在batch维度上重复self.heads(12)次,[b, 77, 77]--->[12b, 77, 77]additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)return additive_mask

2.3 Multimodal text decoder

# logits: [b, 76, 49408], image_embs:caption_embedings[b, 255, 768], token_embs:文本embedings [b, 76, 768]
logits = self.text_decoder(image_embs, token_embs)
# self.text_decoder forward
def forward(self, image_embs, text_embs):# [b, 76, 768]--->[76, b, 768]text_embs = text_embs.permute(1, 0, 2)# [b, 255, 768]--->[255, b, 768]image_embs = image_embs.permute(1, 0, 2)# 76seq_len = text_embs.shape[0]# cross-attention: q=text_embs, k_x=image_embs, v_x=image_embsfor resblock, cross_attn in zip(self.resblocks, self.cross_attn):text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])# q=text_embs, k_x=image_embs, v_x=image_embstext_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)# x: [76, b, 768]--->[b, 76, 768]x = text_embs.permute(1, 0, 2)  # LND -> NLD# layer_normx = self.ln_final(x)# x: [b, 76, 768] @ [768, 49408] ---> [b, 76, 49408]if self.text_projection is not None:x = x @ self.text_projection# [b, 76, 49408]return x

2.4 Loss计算

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):clip_loss = torch.tensor(0)# constractive lossif self.clip_loss_weight:# image_features: [b, 768], text_features:[b, 768], logit_scale:温度系数clip_loss = super().forward(image_features, text_features, logit_scale)clip_loss = self.clip_loss_weight * clip_loss# caption loss, self.caption_loss:CE losscaption_loss = self.caption_loss(logits.permute(0, 2, 1), # [b, 76, 49408]labels, # [b, 76])caption_loss = caption_loss * self.caption_loss_weightif output_dict:return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}return clip_loss, caption_loss# clip_lossdef forward(self, image_features, text_features, logit_scale, output_dict=False):device = image_features.device# 假设有N个图像-文本对: logits_per_image: [N, N], logits_per_text: [N, N]logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)# 假设有N个图像-文本对:labels=[0, 1, 2,....N]labels = self.get_ground_truth(device, logits_per_image.shape[0])# 总损失 = (图像维度的损失 + 文本维度的损失)/ 2total_loss = (F.cross_entropy(logits_per_image, labels) +    # 图像维度的损失F.cross_entropy(logits_per_text, labels)       # 文本维度的损失) / 2return {"contrastive_loss": total_loss} if output_dict else total_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/678748.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第四篇:SQL语法-DDL-数据定义语言

大年初一限定篇😀 (祝广大IT学习者、工作者0 error 0 warning!) 一,DDL数据库操作 (一)库的查询操作 1.列出所有已定义数据库 show databases; 2.查询当前所处数据库 select database(); &…

基于Linux的HTTP代理服务器搭建与配置实战

在数字化世界中,HTTP代理服务器扮演着至关重要的角色,它们能够帮助我们管理网络请求、提高访问速度,甚至在某些情况下还能保护我们的隐私。而Linux系统,凭借其强大的功能和灵活性,成为了搭建HTTP代理服务器的理想选择。…

【C语言——打印乘法口诀表】

乘法表: 我们可以定义一个i控制行的变化,外加看上图的表得知我们需要用到循环结构,i是行需要不停的加加,因此,for循环比较好用,可以用两个嵌套的循环,外层循环即用到的i表示的是每一行的打印&am…

【从Python基础到深度学习】4. Linux 常用命令

1.配置root用户密码 root用户为系统默认最高权限用户,其他用户密码修改命令与root用户修改密码命令相同 sudo passwd root 2.添加用户(henry) sudo useradd -m henry -s /bin/bash 3.配置henry用户密码 Xshell下连接新用户(hen…

基于javaEE的ssm仓库管理系统

仓库管理系统的重中之重是进销存分析这一板块,在这一板块中,顾名思义能够查询到近期的进货记录,包括每日的进货单据,单品推移(即某一商品的库存变化),方便我们核对库存差异。同时也需要查询到每日的销售数据&#xff0…

hexo部署到gitee(码云)

引言 Hexo 是一个基于Node.js的静态博客框架,而 Gitee(也被称为码云)是一个国内的代码托管平台,支持 Git 版本控制系统,与 GitHub 类似。将 Hexo 部署到 Gitee Pages 可以让你的博客受益于 Gitee 的国内服务器&#xf…

Java多态原理

参考 虚方法 JVM杂记:对多态实现原理、虚方法表、虚方法、静态解析、动态链接的一些思考_多态和方法表的关系-CSDN博客 静态分派与动态分派 (JVM)Java虚拟机:静态分派 & 动态分派 原理解析 - 掘金 虚方法表 JVM 栈帧&am…

C++——二叉树

引入 map和set特性需要先铺垫二叉搜索树,而二叉搜索树也是一种树形结构 二叉搜索树的特性了解,有助于更好的理解map和set的特性 1.二叉搜索树的概念及优缺点 1.1二叉搜索树的概念 二叉搜索树又称二叉排序树,它或者是一棵空树,或…

Python中使用opencv-python进行人脸检测

Python中使用opencv-python进行人脸检测 之前写过一篇VC中使用OpenCV进行人脸检测的博客。以数字图像处理中经常使用的lena图像为例,如下图所示: 使用OpenCV进行人脸检测十分简单,OpenCV官网给了一个Python人脸检测的示例程序,…

Backtrader 文档学习- Plotting - Plotting Date Ranges

Backtrader 文档学习- Plotting - Plotting Date Ranges 1.概述 1.9.31.x版本增加了绘制部分图形的功能。 可以使用策略实例中保留完整长度的时间戳数组的索引或者使用实际的datetime.date 或datetime.datetime 实例来限制需要绘制的内容。 仍然可以使用标准的cerebro.plot…

静态时序分析:建立时间分析

静态时序分析https://blog.csdn.net/weixin_45791458/category_12567571.html?spm1001.2014.3001.5482 在静态时序分析中,建立时间检查约束了触发器时钟引脚(时钟路径)和输入数据引脚(数据路径)之间的时序关系&#x…

android中实现设备尺寸适配

1、引言 设备尺寸适配的重要性想必就不用我多说了,在我发布的历史文章中我曾谈过Qt中的设备尺寸适配问题,那这里我就来教大家如何在android中做设备尺寸适配。在android中设备尺寸适配的方式有好几种,但我喜欢的还是使用获取设备真实尺寸然后…

c语言游戏实战(4):人生重开模拟器

前言: 人生重开模拟器是前段时间非常火的一个小游戏,接下来我们将一起学习使用c语言写一个简易版的人生重开模拟器。 网页版游戏: 人生重开模拟器 (ytecn.com) 1.实现一个简化版的人生重开模拟器 (1) 游戏开始的时…

PLC在物联网中位置—承上启下,与上位机下位机的关联。

谈到物联网,就绕不开PLC,本文着重介绍PLC的定义、与单片机的区分,价值、物联网中的位置,以及和上位机、下位机的关联,让友友们对PLC有个全面的认知。 一、什么是PLC PLC是可编程逻辑控制器(Programmable L…

UI自动刷新大法:DataBinding数据绑定

之前我们讲了DataBinding在Activity、Fragment、RecyclerView中的基础使用,而那些常规使用方法里,每当绑定的变量发生数据变化时,都需要ViewDataBinding重新设值才会刷新对应UI。而DataBinding通过内部实现的观察者模式来进行自动刷新UI&…

go消息队列RabbitMQ - 订阅模式-direct

1.发布订阅 在Fanout模式中,一条消息,会被所有订阅的队列都消费。但是,在某些场景下,我们希望不同的消息被不同的队列消费。这时就要用到Direct类型的Exchange。 在Direct模型下: 队列与交换机的绑定,不能…

第 384 场 LeetCode 周赛题解

A 修改矩阵 模拟 class Solution { public:vector<vector<int>> modifiedMatrix(vector<vector<int>> &matrix) {int m matrix.size(), n matrix[0].size();vector<int> mx(n, INT32_MIN);for (int i 0; i < m; i)for (int j 0; j &l…

Java微服务学习Day1

文章目录 认识微服务服务拆分及远程调用服务拆分服务远程调用提供者与消费者 Eureka注册中心介绍构建EurekaServer注册user-serviceorder-service完成服务拉取 Ribbon负载均衡介绍原理策略饥饿加载 Nacos注册中心介绍配置分级存储负载均衡环境隔离nacos注册中心原理 认识微服务…

Python : 使用python实现学生管理系统的功能,详细注释

一、学生管理系统 学生描述&#xff1a;姓名、年龄、成绩 学生管理系统功能&#xff1a;添加学生信息、删除学生信息、根据姓名修改学生信息、根据姓名查询学生信息、显示所有学生信息、退出系统 二、代码说明 1. 将每一个学生的信息放一个元组中&#xff0c;再把元组添加到列表…

单片机基础入门:简单介绍51单片机的工作原理

在电子技术领域&#xff0c;单片机是实现智能化控制不可或缺的关键元件。它们集成了许多功能于一身&#xff0c;成为了各种电子系统的心脏。为了更好地理解单片机如何工作&#xff0c;本文将重点介绍51单片机的基本组成和工作原理。 51单片机是一种广泛使用的微控制器&#xf…