多模态(二)--- CoCa原理与源码解读

1 CoCa简介

CoCa代表Contrastive Captioner 的缩写,代表模型用两个目标函数训练出来的,一个是Contrastive Loss,一个是Captioning Loss。

2 CoCa训练流程

  1. 利用ViT对image进行encoder编码获得图像特征token
  2. 对图像特征进行attention pooling(multihead attention), 取第0位作为计算对比损失的cls-token,后255位作为计算生成损失的视觉token
  3. 对text进行embedding编码,在文本token末尾嵌入cls_token
  4. 生成相应的单词遮挡掩膜mask,给text-token加上位置编码
  5. 将text-token和mask-atten送入transformer学习获得文本cls_token(text_latent), 和其余单词token(token_emb)
    在这里插入图片描述

2.1 image encoder

    def _encode_image(self, images, normalize: bool = True):image_latent, tokens_embs = self.visual(images)image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent# image_latent:constractive_token, tokens_embs: caption_tokenreturn image_latent, tokens_embs#### self.visual(images):def forward(self, x: torch.Tensor):# [b, 3, 224, 224]--->[b, 1024, 16, 16]x = self.conv1(x)# [b, 1024, 16, 16]--->[b, 1024, 256]x = x.reshape(x.shape[0], x.shape[1], -1)# [b, 1024, 256]--->[b, 256, 1024]x = x.permute(0, 2, 1)# 在序列长度上给图像嵌入一个类别,x:[b, 256 + 1, 1024]x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)# 嵌入位置编码,x:[b, 256 + 1, 1024]x = x + self.positional_embedding.to(x.dtype)# patch_dropout, x:[b, 256 + 1, 1024]x = self.patch_dropout(x)# LayerNorm处理 x:[b, 256 + 1, 1024]x = self.ln_pre(x)# NLD -> LND [b, 256 + 1, 1024]---> [256 + 1, b, 1024]x = x.permute(1, 0, 2)# transformer网络处理x = self.transformer(x)# LND -> NLD  [256 + 1, b, 1024]--->[b, 256 + 1, 1024]x = x.permute(1, 0, 2)if self.attn_pool is not None:# this is the original OpenCLIP CoCa setup, does not match paper# x:[b, 257, 1024]--->[b, 256, 768]x = self.attn_pool(x)# ln归一化, [b, 256, 768]x = self.ln_post(x)# pooled: 类别token:[b, 768] tokens:图像token:[b, 255, 768]pooled, tokens = self._global_pool(x)# pooled: [b, 768]@[768, 768]--->[b, 768]if self.proj is not None:pooled = pooled @ self.proj# 同时返回cls-token和视觉tokenif self.output_tokens:return pooled, tokensreturn pooled
# self.attn_pool(x)
class AttentionalPooler(nn.Module):def __init__(self,d_model: int,context_dim: int,n_head: int = 8,n_queries: int = 256,norm_layer: Callable = LayerNorm):super().__init__()self.query = nn.Parameter(torch.randn(n_queries, d_model))self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)self.ln_q = norm_layer(d_model)self.ln_k = norm_layer(context_dim)def forward(self, x: torch.Tensor):# ln归一化,NLD -> LND [b, 257, 1024]--->[257, b, 1024]x = self.ln_k(x).permute(1, 0, 2)N = x.shape[1]# q: [256, 768]q = self.ln_q(self.query)# q: [256, 768]--->[256, 1, 768]--->[256,b, 768], k=v=x, x:[257, b, 1024]# out: [256, b, 768], MultiheadAttentionout = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]# out: [256, b, 768]--->[b, 256, 768]return out.permute(1, 0, 2)  # LND -> NLD

2.2 Unimodal text decoder

    def _encode_text(self, text, normalize: bool = True):# text_latent:[b, 768], token_emb:[b, 76, 768]text_latent, token_emb = self.text(text)text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latentreturn text_latent, token_embdef forward(self, text):cast_dtype = self.transformer.get_cast_dtype()seq_len = text.shape[1]# x:[b, 76, 768], 将text:[b, 76]进行embeding, F.embedding(text, weight=[40408, 768])49408---一共49408个单词,每个单词维度768x = self.token_embedding(text).to(cast_dtype)attn_mask = self.attn_maskif self.cls_emb is not None:seq_len += 1# 在文本token末尾嵌入cls_token, x:[b, 76, 768] ---> [b, 76+1, 768]x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)# cls_mask: [12b, 77, 77], text:[b, 76]cls_mask = self.build_cls_mask(text, cast_dtype)# 将单词有序遮挡mask, attn_mask: [[0, -inf, -inf,...-inf], [0, 0, -inf, ..., -inf],...[0, 0, 0,...,0,-inf], [0, 0, 0, ...,0]]if attn_mask is not None:# attn_mask: [1,77, 77] + cls_mask[12b, 77, 77] ===> 获得最终的attn_mask: [12b, 77, 77], 有单词的位置为0, 被遮挡以及没单词的位置为-infattn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]# 加上位置编码, x: [b, 77, 768]x = x + self.positional_embedding[:seq_len].to(cast_dtype)# x: [b, 77, 768]--->[77, b, 768]x = x.permute(1, 0, 2)  # NLD -> LND# 进入transformer学习, x:[77, b, 768]x = self.transformer(x, attn_mask=attn_mask)# x: [77, b, 768]--->[b, 77, 768]x = x.permute(1, 0, 2)  # LND -> NLD# x.shape = [batch_size, n_ctx, transformer.width]if self.cls_emb is not None:# presence of appended cls embed (CoCa) overrides pool_type, always take last token# pooled: cls_token:[b, 768] tokens:图像token:[b, 76, 768]pooled, tokens = text_global_pool(x, pool_type='last')# layernormpooled = self.ln_final(pooled)  # final LN applied after pooling in this case# [b, 768] @ 【768, 768】---> [b, 768]pooled = pooled @ self.text_projection# pooled:[b, 768], tokens:[b, 76, 768]if self.output_tokens:return pooled, tokensreturn pooled
    def build_cls_mask(self, text, cast_dtype: torch.dtype):# 找到text中存在单词的cls_mask,值设为True,text:[b, 76], cls_mask: [b, 1, 76]cls_mask = (text != self.pad_id).unsqueeze(1)# cls_mask: [b, 1, 76]--->[b, 77, 77]cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)# 随机一个[b, 77, 77]的maskadditive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)# 全部填充为0 additive_mask:[b, 77, 77]additive_mask.fill_(0)# 不满77长度的单词中,0填充的位置换为-infadditive_mask.masked_fill_(~cls_mask, float("-inf"))# 将additive_mask在batch维度上重复self.heads(12)次,[b, 77, 77]--->[12b, 77, 77]additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)return additive_mask

2.3 Multimodal text decoder

# logits: [b, 76, 49408], image_embs:caption_embedings[b, 255, 768], token_embs:文本embedings [b, 76, 768]
logits = self.text_decoder(image_embs, token_embs)
# self.text_decoder forward
def forward(self, image_embs, text_embs):# [b, 76, 768]--->[76, b, 768]text_embs = text_embs.permute(1, 0, 2)# [b, 255, 768]--->[255, b, 768]image_embs = image_embs.permute(1, 0, 2)# 76seq_len = text_embs.shape[0]# cross-attention: q=text_embs, k_x=image_embs, v_x=image_embsfor resblock, cross_attn in zip(self.resblocks, self.cross_attn):text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])# q=text_embs, k_x=image_embs, v_x=image_embstext_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)# x: [76, b, 768]--->[b, 76, 768]x = text_embs.permute(1, 0, 2)  # LND -> NLD# layer_normx = self.ln_final(x)# x: [b, 76, 768] @ [768, 49408] ---> [b, 76, 49408]if self.text_projection is not None:x = x @ self.text_projection# [b, 76, 49408]return x

2.4 Loss计算

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):clip_loss = torch.tensor(0)# constractive lossif self.clip_loss_weight:# image_features: [b, 768], text_features:[b, 768], logit_scale:温度系数clip_loss = super().forward(image_features, text_features, logit_scale)clip_loss = self.clip_loss_weight * clip_loss# caption loss, self.caption_loss:CE losscaption_loss = self.caption_loss(logits.permute(0, 2, 1), # [b, 76, 49408]labels, # [b, 76])caption_loss = caption_loss * self.caption_loss_weightif output_dict:return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}return clip_loss, caption_loss# clip_lossdef forward(self, image_features, text_features, logit_scale, output_dict=False):device = image_features.device# 假设有N个图像-文本对: logits_per_image: [N, N], logits_per_text: [N, N]logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)# 假设有N个图像-文本对:labels=[0, 1, 2,....N]labels = self.get_ground_truth(device, logits_per_image.shape[0])# 总损失 = (图像维度的损失 + 文本维度的损失)/ 2total_loss = (F.cross_entropy(logits_per_image, labels) +    # 图像维度的损失F.cross_entropy(logits_per_text, labels)       # 文本维度的损失) / 2return {"contrastive_loss": total_loss} if output_dict else total_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/678748.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java回忆性记录5

java回忆性记录5 for循环 循环语句,属于重复结构中的流程控制语句,一旦条件符合就会执行循环,反之则不会。假如让计算机再屏幕打印500次我们不可能傻傻的把打印语句输出500次。依稀记得在没有学习循环语句的时候自己傻傻的ctrlc 、ctrlv了几…

第四篇:SQL语法-DDL-数据定义语言

大年初一限定篇😀 (祝广大IT学习者、工作者0 error 0 warning!) 一,DDL数据库操作 (一)库的查询操作 1.列出所有已定义数据库 show databases; 2.查询当前所处数据库 select database(); &…

基于Linux的HTTP代理服务器搭建与配置实战

在数字化世界中,HTTP代理服务器扮演着至关重要的角色,它们能够帮助我们管理网络请求、提高访问速度,甚至在某些情况下还能保护我们的隐私。而Linux系统,凭借其强大的功能和灵活性,成为了搭建HTTP代理服务器的理想选择。…

【C语言——打印乘法口诀表】

乘法表: 我们可以定义一个i控制行的变化,外加看上图的表得知我们需要用到循环结构,i是行需要不停的加加,因此,for循环比较好用,可以用两个嵌套的循环,外层循环即用到的i表示的是每一行的打印&am…

【从Python基础到深度学习】4. Linux 常用命令

1.配置root用户密码 root用户为系统默认最高权限用户,其他用户密码修改命令与root用户修改密码命令相同 sudo passwd root 2.添加用户(henry) sudo useradd -m henry -s /bin/bash 3.配置henry用户密码 Xshell下连接新用户(hen…

15.Swift闭包

Swift 闭包 在 Swift 中,闭包是一种自包含的函数代码块,可以在代码中被传递和使用。闭包可以捕获并存储其所在上下文中的任意变量和常量的引用,这就是所谓的闭包的特性。闭包在 Swift 中被广泛用于函数式编程和异步编程,具有灵活…

二级C语言笔试9

(总分89.5,考试时间90分钟) 一、选择题 1. 下列对队列的叙述正确的是 。 A) 队列属于非线性表 B) 队列按“先进后出”原则组织数据 C) 队列在队尾删除数据 D) 队列按“先进先出”原则组织数据 2. 下列关于栈的描述中错误的是( )。 A) 栈是先进后出的…

weilai8游戏爬虫

#!/usr/bin/python # -*- coding: UTF-8 -*- #!/usr/bin/python # -*- coding: UTF-8 -*- import os,csv import re import random import time import requests from lxml import etreefrom urllib.parse import quote, unquotepage=98 sess = requests.Session()#创建一个ses…

linux系统上tomcat简介以及安装tomcat

tomcat简介以及安装 Tomcat简介安装环境安装jdk安装tomcat浏览器访问 Tomcat简介 Tomcat是一个开源的Web服务器和servlet容器,由Apache软件基金会开发和维护。它是一种流行的Java Web应用服务器,用于运行Java编写的Web应用程序。 Tomcat提供了一个轻量级…

基于javaEE的ssm仓库管理系统

仓库管理系统的重中之重是进销存分析这一板块,在这一板块中,顾名思义能够查询到近期的进货记录,包括每日的进货单据,单品推移(即某一商品的库存变化),方便我们核对库存差异。同时也需要查询到每日的销售数据&#xff0…

hexo部署到gitee(码云)

引言 Hexo 是一个基于Node.js的静态博客框架,而 Gitee(也被称为码云)是一个国内的代码托管平台,支持 Git 版本控制系统,与 GitHub 类似。将 Hexo 部署到 Gitee Pages 可以让你的博客受益于 Gitee 的国内服务器&#xf…

Java多态原理

参考 虚方法 JVM杂记:对多态实现原理、虚方法表、虚方法、静态解析、动态链接的一些思考_多态和方法表的关系-CSDN博客 静态分派与动态分派 (JVM)Java虚拟机:静态分派 & 动态分派 原理解析 - 掘金 虚方法表 JVM 栈帧&am…

假期作业8

线程和进程服务器 线程 #include <myhead.h>#define SIP "192.168.0.114" #define SPORT 8888void *task(void *arg){printf("客户端连接\n");sleep(1);pthread_exit(NULL); }int main(int argc, const char *argv[]) {int sfd socket(AF_INET, S…

16.1 Spring框架_SpringIoC容器与Bean管理(❤❤)

16.1 Spring框架_SpringIoC容器与Bean管理 1. Spring1.1 SpringIoC1. IoC控制反转2. DI依赖注入1.2 Spring概念1. Spring含义2. 传统开发与SpringIoC开发模式比较2. IoC基础实现案例(❤❤)1. 传统方式2. IoC与DI方式3. bean管理1. xml方式(❤❤)1. bean的实例化方式

C++——二叉树

引入 map和set特性需要先铺垫二叉搜索树&#xff0c;而二叉搜索树也是一种树形结构 二叉搜索树的特性了解&#xff0c;有助于更好的理解map和set的特性 1.二叉搜索树的概念及优缺点 1.1二叉搜索树的概念 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或…

12.4 OpenGL顶点后处理:图元裁剪

图元裁剪 Primitive Clipping Primitive Clipping&#xff08;图元裁剪&#xff09;是图形渲染管线中的一个重要步骤&#xff0c;发生在顶点处理之后、光栅化之前。这个阶段主要目的是去除位于视体&#xff08;View Volume&#xff09;之外或者被用户自定义裁剪平面&#xff0…

【Spring和Spring Boot的区别——详细讲解】

Spring和Spring Boot的区别 1. 介绍2. Spring框架3. Spring Boot4. 结论 1. 介绍 Spring和Spring Boot都是现代Java开发中常用的技术和框架&#xff0c;它们之间的关系紧密&#xff0c;Spring Boot是建立在Spring之上的&#xff0c;它简化了Spring应用的创建和开发过程。下面是…

Python中使用opencv-python进行人脸检测

Python中使用opencv-python进行人脸检测 之前写过一篇VC中使用OpenCV进行人脸检测的博客。以数字图像处理中经常使用的lena图像为例&#xff0c;如下图所示&#xff1a; 使用OpenCV进行人脸检测十分简单&#xff0c;OpenCV官网给了一个Python人脸检测的示例程序&#xff0c;…

Backtrader 文档学习- Plotting - Plotting Date Ranges

Backtrader 文档学习- Plotting - Plotting Date Ranges 1.概述 1.9.31.x版本增加了绘制部分图形的功能。 可以使用策略实例中保留完整长度的时间戳数组的索引或者使用实际的datetime.date 或datetime.datetime 实例来限制需要绘制的内容。 仍然可以使用标准的cerebro.plot…

TCP Server工具类,BIO阻塞和非阻塞NIO

启动自定义代码的方式 WebServer Initialized Event //Component//ApplicationContext context 注入//PostConstruct//AsyncEventListener(ApplicationReadyEvent.class)Componentpublic class TcpServerListener implements ApplicationListener<WebServerInitializedEven…