【多模态大模型】 ALBEF in NeurIPS 2021

一、引言

论文: Align before Fuse: Vision and Language Representation Learning with Momentum Distillation
作者: Salesforce Research
代码: ALBEF
特点: 该方法使用ViT进行图像特征提取,提出将BERT分两部分,一部分进行文本特征提取,另一部分进行图像-文本交互的特征提取;提出使用image-text contrastive learning (ITC)损失、masked language modeling (MLM)损失、image-text matching (ITM)损失进行模型优化;提出Momentum Distillation策略以一个通过exponential moving average (EMA)的网络生成软伪标签提供另一个视角的优化方向。

⚠️ 在学习该方法前,建议补充ViT、BERT、CLIP、MoCo的相关知识。

二、详情

ALBEF的整体结构图如下:

可见,ALBEF在网络结构上主要包括1个图像编码器、1个文本编码器、1个多模态编码器和1个同样包含上述3个编码器的动量模型;ALBEF在损失上主要包括image-text contrastive learning (ITC)损失、masked language modeling (MLM)损失、image-text matching (ITM)损失;此外,ALBEF还引入了动量蒸馏。

2.1 网络结构

如图,ALBEF在网络结构上主要包括1个图像编码器、1个文本编码器、1个多模态编码器和1个同样包含上述3个编码器的动量模型

2.1.1 图像编码器

ALBEF的图像编码器使用ViT-B/16,共12个transformer模块,由在ImageNet-1k上进行预训练的权重初始化。输入图像转为token后会再扩充一个名为 [ CLS ] [\text{CLS}] [CLS]的token(初始化全0的可学习参数向量),用来表达图像的全局信息。最后输出的是经过12个transformer模块优化过的输入图像的token和 [ CLS ] [\text{CLS}] [CLS]的token,记为 { v cls , v 1 , ⋯ , v N } \{\boldsymbol{v}_{\text{cls}},\boldsymbol{v}_{1},\cdots,\boldsymbol{v}_N\} {vcls,v1,,vN}

关于ViT的详情,请参考我之前的博客Vision Transformer。

2.1.2 文本编码器

ALBEF的文本编码器使用6个transformer模块,由 BERT base \textbf{BERT}_{\textbf{base}} BERTbase的前6层初始化。输入文本会在最前面扩充一个名为 [ CLS ] [\text{CLS}] [CLS]的token(直接放在句子最前面,例如原文本是“I am very happy today.”,则新文本应为“ [ CLS ] [\text{CLS}] [CLS] I am very happy today.”),用来表达文本的全局信息。最后输出经Tokenizer和6个transformer模块优化过的输入文本的token和 [ CLS ] [\text{CLS}] [CLS]的token,记为 { w cls , w 1 , ⋯ , w N } \{\boldsymbol{w}_{\text{cls}},\boldsymbol{w}_{1},\cdots,\boldsymbol{w}_N\} {wcls,w1,,wN}

2.1.3 多模态编码器

ALBEF的多模态编码器使用6个transformer模块(含交叉注意力),由 BERT base \textbf{BERT}_{\textbf{base}} BERTbase的后6层初始化(BERT不包含交叉注意力,所以其中交叉注意力是随机初始化的)。输入为图像编码器和文本编码器输出的toekn,即 { v cls , v 1 , ⋯ , v N } \{\boldsymbol{v}_{\text{cls}},\boldsymbol{v}_{1},\cdots,\boldsymbol{v}_N\} {vcls,v1,,vN} { w cls , w 1 , ⋯ , w N } \{\boldsymbol{w}_{\text{cls}},\boldsymbol{w}_{1},\cdots,\boldsymbol{w}_N\} {wcls,w1,,wN},输出以图像token为指导经6个含交叉注意力的transformer模块优化过的文本token。

2.1.4 动量模型

ALBEF的动量模型是额外保存的一个模型,同样包括图像编码器、文本编码器、多模态编码器,且初始化参数也一致。但动量模型不通过梯度优化,而是通过指数移动平均进行参数更新。为方便讲解,我们将正常通过梯度优化更新参数的模型称为梯度模型,额外保存的通过指数移动平均更新参数的模型称为动量模型

指数移动平均的公式可表示为 θ t = λ θ t + ( 1 − λ ) θ g \theta_{t}=\lambda\theta_{t}+(1-\lambda)\theta_{g} θt=λθt+(1λ)θg,其中 θ t \theta_{t} θt θ g \theta_{g} θg分别为动量模型和梯度模型的参数, λ \lambda λ取的很大这里为0.995,这是为了保证动量模型的稳定使其不会因为噪声干扰而偏离原本正确的优化方向。
简单来说,就是通过梯度更新过梯度模型后,新的梯度模型和旧的动量模型参数加权求和即得到了新的动量模型参数。

2.2 损失

如图,ALBEF主要包括3个损失,分别为image-text contrastive learning (ITC)损失、masked language modeling (MLM)损失、image-text matching (ITM)损失。

2.2.1 image-text contrastive learning 损失

ITC损失旨在更好地学习两个模态的特征表达,使两个模态的特征能够对齐,即图像特征与对应文本描述的特征更相似。相关模块如下图红框所示:

如上图,ITC损失的计算需要用到梯度模型中图像编码器和文本编码器输出的 [ CLS ] [\text{CLS}] [CLS]token;还需要用到动量模型中图像编码器和文本编码器输出的一堆 [ CLS ] [\text{CLS}] [CLS]token。

这里用到了MoCo的自训练思想,关于MoCo 的详情请参考我之前的博客SAVC的第1.2.2节。

我们以一个批次中的一张图像为例讲解具体的计算过程:

假设一个批次包括 B B B个图像-文本对, I 1 I_1 I1是第一张图片, T 1 T_1 T1是与之配对的文本, T 2 , ⋯ , T B T_2,\cdots,T_B T2,,TB是与当前批次其他图片配对的文本; T 1 , T 2 , ⋯ , T B T_1,T_2,\cdots,T_B T1,T2,,TB Q B + 1 T , Q B + 2 T , ⋯ , Q M T Q^T_{B+1},Q^T_{B+2},\cdots,Q^T_{M} QB+1T,QB+2T,,QMT组成一个队列, M = 65536 M=65536 M=65536,初始是从训练集随机选出的 M M M个文本,新批次的文本数据到来后会被推入队列,队列另一头的数据会被推出。

I 1 I_1 I1会经过梯度模型的图像编码器输出 [ CLS ] [\text{CLS}] [CLS]token,即 v cls \boldsymbol{v}_{\text{cls}} vcls T 1 , T 2 , ⋯ , T B , Q B + 1 T , Q B + 2 T , ⋯ , Q M T T_1,T_2,\cdots,T_B,Q^T_{B+1},Q^T_{B+2},\cdots,Q^T_{M} T1,T2,,TB,QB+1T,QB+2T,,QMT会经过动量模型的文本编码器输出 [ CLS ] [\text{CLS}] [CLS]token,即 w cls , 1 ′ , w cls , 2 ′ , ⋯ , w cls , M ′ \boldsymbol{w}^{\prime}_{\text{cls},1},\boldsymbol{w}^{\prime}_{\text{cls},2},\cdots,\boldsymbol{w}^{\prime}_{\text{cls},M} wcls,1,wcls,2,,wcls,M

梯度模型在图像编码器后增加一个全连接映射使token维度从768降低至256,并施加归一化,映射+归一化操作记为 g v ( ⋅ ) g_v(\cdot) gv(),得到 q q q;动量模型在文本编码器后进行同样的操作,记为 g w ′ ( ⋅ ) g^{\prime}_w(\cdot) gw(),得到 k 0 , k 1 , ⋯ , k M k_0,k_1,\cdots,k_M k0,k1,,kM

q q q k 1 , k 2 , ⋯ , k M k_1,k_2,\cdots,k_{M} k1,k2,,kM一一计算相似度,记为 s ( I , T m ) , m = 1 , 2 , ⋯ , M s(I,T_m),m=1,2,\cdots,M s(I,Tm),m=1,2,,M。于是可以获取该图像的softmax-normalized image-to-text similarity:

M M M个概率值组成 p i2t ( I ) \boldsymbol{p}^{\text{i2t}}(I) pi2t(I),对应的真实标签则应该是 y i2t ( I ) = { 1 , 0 , ⋯ , 0 } \boldsymbol{y}^{\text{i2t}}(I)=\{1,0,\cdots,0\} yi2t(I)={1,0,,0},真实标签是one-hot形式的, I I I与队列中哪个文本对应,对应位置就应该为 1 1 1,其余为 0 0 0

相应地,以文本为基准利用梯度模型的文本编码器、动量模型的图像编码器、图像队列也可以获得 p t2i ( T ) \boldsymbol{p}^{\text{t2i}}(T) pt2i(T) y t2i ( T ) \boldsymbol{y}^{\text{t2i}}(T) yt2i(T)。最后便可得到ITC损失:

其中 H ( ⋅ , ⋅ ) H(\cdot,\cdot) H(,)为标准交叉熵。

2.2.2 masked language modeling 损失

MLM损失利用图像和文本上下文来预测被mask的单词,以此提升模型的理解能力。相关模块如下图红框所示:

对于一个图像-文本对,图像被完整送入图像编码器,文本会被随机mask,即由原来的单词替换为 [ MASK ] [\text{MASK}] [MASK],然后送入文本编码器。

mask的规则是每个单词有15%的概率被选中,被选中的单词中80%被替换为 [ MASK ] [\text{MASK}] [MASK],10%被随机替换成其他token,10%没有任何改变。

之所以不是直接选15%*90%的进行mask和替换,是因为15%*10%没有任何变化的单词也需要模型对其进行预测。

以“I am very happy today.”为例,讲解mask的过程。经过Tokenizer其变为:

tokens = [I, am, very, happy, today]

增加 [ CLS ] [\text{CLS}] [CLS] [ SEP ] [\text{SEP}] [SEP]得到:

tokens = [[CLS], I, am, very, happy, today, [SEP]]

假设我们以15%的概率选中I、happy、today,再以80%-10%-10%的概率进行调整后得到:

tokens = [[CLS], I, am, very, [MASK], good, [SEP]]

可见,I没有发生变化,happy被替换为 [ MASK ] [\text{MASK}] [MASK],today被替换为good。此时要求模型利用图像和 [ [ CLS ] , I , a m , v e r y , [ MASK ] , g o o d , [ SEP ] ] [[\text{CLS}], I, am, very, [\text{MASK}], good, [\text{SEP}]] [[CLS],I,am,very,[MASK],good,[SEP]]来预测出句子原本的单词。

图像信息图像编码器和交叉注意力与文本信息交互从而起到指导作用,文本信息经文本编码器和多模态解码器输出优化后的tokens。每个token后跟一个FFN和softmax进行当前位置对应单词的预测。 下图给出了一个“Paris is a beautiful city. I love Paris.”中city被替换为 [ MASK ] [\text{MASK}] [MASK]后模型的预测过程以帮助理解:

可见,该部分预测仍是一个概率分布,所以MLM损失同样使用标准交叉熵:

其中, I I I T T T是原始图像-文本对, T ^ \hat{T} T^是经mask后的文本; p msk ( I , T ^ ) \boldsymbol{p}^{\text{msk}(I,\hat{T})} pmsk(I,T^)是对一个被mask的单词预测的概率分布, y msk \boldsymbol{y}^{\text{msk}} ymsk是该单词真实的one-hot标签。

⚠️ 由于MLM损失与其它损失,例如ITC损失,的输入不同(有无mask),所以该损失会额外产生一次文本编码器和多模态编码器的forward。

2.2.3 image-text matching 损失

ITM损失用来预测输入的图像-文本对是否匹配,匹配为1,不匹配为0,是一个二分类损失。相关模块如下图红框所示:

对于一个批次的图像-文本对,该损失是较简单的,因为非原配的图像-文本对很容易被判定为否,所以ALBEF利用在ITC损失计算时得到的本批次图像-文本相似度来挑选hard的负例。

对于一个批次中的一张图像来说, { p 1 i2t ( I ) , p 2 i2t ( I ) , ⋯ , p B i2t ( I ) } \{p_1^{\text{i2t}}(I),p_2^{\text{i2t}}(I),\cdots,p_B^{\text{i2t}}(I)\} {p1i2t(I),p2i2t(I),,pBi2t(I)}就是计算ITC损失时得到的相似度,其中非原配的最高相似度所对应的文本即为hard的负例文本。同样地,对于每个文本来说,也可以选出自己的hard负例图像。

图像和文本分别经过各自的编码器再通过相似度选出各自的hard负例之后每个图像或文本都有1个正例和1个负例与之对应,将它们的 [ CLS ] [\text{CLS}] [CLS]token送入多模态编码器即可得到优化后的 [ CLS ] [\text{CLS}] [CLS]token。在后面跟一个全连接映射和softmax即可进行二元预测判断图像-文本是否匹配。所以ITM损失也可以使用标准交叉熵:

其中,一个图像-文本对可以产生3项损失,包括1个原配的图像-文本对、2个hard的图像-文本对(因为分别是梯度模型和动量模型的输出之间计算相似度,不一定是同一对图像-文本互为hard)。

3 动量蒸馏

动量模型不仅在计算ITC损失时发挥作用,ALBEF还用它来应对从网络爬虫下来的图像-文本对富含噪声的问题。

首先,我们需要知道网络数据的噪声是什么样的。一般我们看到一个蛋糕图片后希望获取的是它的店铺位置从而去购买,所以我们从网上下载的数据很可能是一个蛋糕图片和一个对商铺的描述;但实际我们是希望与图片匹配的文本描述应该是针对图片中内容的描述,例如这个蛋糕的外观,如下图:

可能还有些数据的图像-文本是匹配的,但是明显有更合适的描述,如下图:

事实上,网络上很多都是这种噪声数据,如果我们使用one-hot形式的标签进行模型训练和学习,就会很大程度上被这些数据误导。于是,ALBEF使用动量模型来生成软伪标签约束和指导模型的学习。

软标签是相对one-hot形式的硬标签而言的。例如三分类问题中,one-hot只有一个值是1,其余均为0,例如 { 1 , 0 , 0 } \{1,0,0\} {1,0,0} { 0 , 1 , 0 } \{0,1,0\} {0,1,0} { 0 , 0 , 1 } \{0,0,1\} {0,0,1};软标签则只要求各个值的和为1,允许多个类别上有值,例如 { 0.6 , 0.3 , 0.1 } \{0.6,0.3,0.1\} {0.6,0.3,0.1} { 0.3 , 0.7 , 0 } \{0.3,0.7,0\} {0.3,0.7,0}等等。
伪标签是相对真实标签而言的,非原始的真实标签,而是通过其它手段生成的标签都称为伪标签。

其次,就是如何生成软伪标签。ALBEF是将动量模型的预测作为软伪标签。

例如,ITC损失原本是将图像或文本输入梯度网络然后将队列输入动量网络再计算相互间的相似度得到预测,如果原本图像和文本是一对,则标签值为1,否则为0。动量蒸馏是将图像或文本以及队列均输入到动量网络中,然后计算动量网络输出间的相似度。

下图说明了两者相似度计算的差异:

可见,主要区别就是图像-文本对是送入梯度网络(原始)还是动量网络(动量蒸馏)。有了新的相似度之后,再通过softmax-normalized image-to-text similarity即可得到动量网络的预测,即伪软标签 q i2t ( I ) \boldsymbol{q}^{\text{i2t}}(I) qi2t(I) q t2i ( T ) \boldsymbol{q}^{\text{t2i}}(T) qt2i(T)。于是得到ITC损失的动量蒸馏损失:

其中, α = 0.4 \alpha=0.4 α=0.4。由于 q i2t ( I ) \boldsymbol{q}^{\text{i2t}}(I) qi2t(I) q t2i ( T ) \boldsymbol{q}^{\text{t2i}}(T) qt2i(T)不是one-hot形式的,所以这里用KL散度衡量梯度网络的预测与动量网络的预测的一致性。

当真实图像-文本不太匹配时,这种操作能允许模型将图像或文本与其它文本或图像做匹配。但是我们又不希望随机找一个进行匹配,所以用比较稳定的动量网络提供一个合适的匹配。

类似地,将被mask后的文本输入动量网络,也能得到动量网络的预测 q msk ( I , T ^ ) \boldsymbol{q}^{\text{msk}}(I,\hat{T}) qmsk(I,T^),即软伪标签。于是,得到MLM损失的动量蒸馏损失:

⚠️ 因为ITM损失就是根据原标签进行0和1的分配的,所以不太适合采用该策略,ALBEF没有对其进行修改。

作者还提供了一些例子,来证明软伪标签有时是更好更合适的:

上面3幅中,被mask的部分真实单词没有伪标签的单词合适;下面2幅中,原描述没有伪标签的描述合适。

致谢:

本博客仅做记录使用,无任何商业用途,参考内容如下:
是时候彻底弄懂BERT模型了
多模态论文串讲·上

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51615.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cocos Creator2D游戏开发(3)-飞机大战(1)-背景动起来

资源见: https://pan.baidu.com/s/1cryYNdBOry5A4YEEcLwhDQ?pwdzual 步骤 1, 让背景动起来 2, 玩家飞机显现,能操控,能发射子弹 3.敌机出现 4. 碰撞效果(子弹和敌机,敌机和玩家) 5. 积分和游戏结束 6. 游戏存档,对接微信小游戏,保存历史最高分 7. cocos发布到微信小游戏 资源…

探索Python的进度条神器:tqdm

文章目录 探索Python的进度条神器:tqdm一、背二、tqdm简介三、安装tqdm四、tqdm的五个简单使用示例五、tqdm在不同场景下的应用六、常见问题及解决方案七、总结 探索Python的进度条神器:tqdm 一、背 景:为什么选择tqdm? 在Python…

苦学Opencv的第十四天:人脸检测和人脸识别

Python OpenCV入门到精通学习日记:人脸检测和人脸识别 前言 经过了十三天的不懈努力,我们终于也是来到了人脸检测和人脸识别啦!相信大家也很激动吧。接下来我们开始吧! 人脸识别是基于人的脸部特征信息进行身份识别的一种生物识…

Spring 常用的三种拦截器详解

前言 在开发过程中,我们常常使用到拦截器来处理一些逻辑。最常用的三种拦截器分别是 AOP、 Interceptor 、 Filter,但其实很多人并不知道什么时候用AOP,什么时候用Interceptor,什么时候用Filter,也不知道其拦截顺序&am…

spring —— 事务管理器

事务管理主要针对数据源进行操作:在数据库方面,通过 TransactionManager 事务管理器进行管理,表明一旦出现错误,该数据源的所有数据全部复原。那么数据库如何判断是否发生了错误呢?这就需要在代码方面,通过…

抖音直播弹幕数据逆向:websocket和JS注入

🔍 思路与步骤详解 🕵️‍♂️ 思路介绍 首先,我们通过抓包工具进入的直播间,捕获其网络通信数据,重点关注WebSocket连接。发现直播弹幕数据通过WebSocket传输,这种方式比传统的HTTP更适合实时数据的传输。…

前端基于 axios 实现批量任务调度管理器 demo

一、背景介绍 这是一个基于 axios 实现的批量任务调度管理器的 demo。它使用了axios、promise 等多种技术和原理来实现批量处理多个异步请求,并确保所有请求都能正确处理并报告其状态。 假设有一个场景:有一个任务列表,有单个任务的处理功能…

【Qt】QLCDNumberQProgressBarQCalendarWidget

目录 QLCDNumber 倒计时小程序 相关属性 QProgressBar 进度条小程序 相关设置 QLCDNumber QLCDNumber是Qt框架中用于显示数字或计数值的小部件。通常用于显示整数值,例如时钟、计时器、计数器等 常用属性 属性说明intValueQLCDNumber显示的初始值(int类型)va…

企业版邮箱适用哪些企业

企业邮箱适合哪些企业呢?企业版邮箱为企业提供安全、稳定、集成的邮件服务,支持初创、中小、大型企业及特定行业需求。ZohoMail作为优质提供商,提供多层安全措施、移动访问、集成能力及定制化服务,满足不同规模企业需求。 一、企…

2023年系统架构设计师考试总结

原文链接:https://www.cnblogs.com/zhaotianff/p/17812187.html 上周六参加了2023年系统架构设计师考试,这次考试与以前有点区别,是第一次采用电子化考试,也是教材改版后的第一次考试。 说说考前准备:为了准备这次考试…

基于微信小程序的校园警务系统/校园安全管理系统/校园出入管理系统

摘要 伴随着社会以及科学技术的发展,小程序已经渗透在人们的身边,小程序慢慢的变成了人们的生活必不可少的一部分,紧接着网络飞速的发展,小程序这一名词已不陌生,越来越多的学校机构等都会定制一款属于自己个性化的小程…

《通讯世界》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问:《通讯世界》是不是核心期刊? 答:不是,是知网收录的第一批认定学术期刊。 问:《通讯世界》级别? 答:国家级。主管单位:科学技术部 主办单位:中国科学技…

关于虚拟机在桥接模式下连接网络问题的记录

2024年7月28日03:49:19 环境:ubuntu22.04 desktop 虚拟机 问题:使用wget下载nginx安装包时出现问题,443端口持续无连接成功回复。 随后在确定配置ip无问题,检查了其正常访问互联网,随后试图ping niginx网站&#xff…

基于OSS前端直传的分片上传以及断点续传

一、大文件分片上传 原型 大文件如果直接上传的话由于nginx的限制会导致响应500报错,或者响应时间过长导致响应超时 并且大文件上传有如下缺点 上传时间长: 对于大文件,直接上传可能需要较长时间,特别是在网络速度较慢或不稳定的情况下。这…

ChatGPT秘籍:如何用AI阅读文献,提升你的学术效率

在当今信息泛滥的时代,迅速高效地搜集与处理信息显得尤为关键。本文将聚焦于如何利用ChatGPT高效阅读文献与文档,并提供详尽的技巧、心得以及实用的指令和插件解析,助你充分发挥ChatGPT的潜能。无论你是学生、科研人员还是行业从业者&#xf…

雪花算法的一些问题解析

前言 最近做项目,有些老旧项目,需要生成分布式唯一ID,不允许重复,此时如果要对其他中间件和数据库依赖小,那么就需要一套固定的ID生成规则,雪花算法就正当合适,当时Twitter就是用来存储数据库I…

JSP基础语法与指令

任何语言都有自己的语法&#xff0c;在java中有&#xff0c;JSP作为java技术的一种应用&#xff0c;它拥有一些自己扩充的语法(了解知道即可&#xff01;&#xff01;&#xff01;)&#xff0c; Java所有语法都支持&#xff01; JSP表达式 <html><head><title…

【Redis 初阶】初识 Redis

一、了解 Redis Redis 官网&#xff1a;Redis - The Real-time Data Platform Redis 是一种基于键值对&#xff08;key-value&#xff09;的 NoSQL 数据库。与很多键值对数据库不同的是&#xff0c;Redis 中的 key 都是 string&#xff08;字符串&#xff09;&#xff0c;值&a…

计算机毕业设计LSTM+Tensorflow股票分析预测 基金分析预测 股票爬虫 大数据毕业设计 深度学习 机器学习 数据可视化 人工智能

|-- 项目 |-- db.sqlite3 数据库相关 重要 想看数据&#xff0c;可以用navicat打开 |-- requirements.txt 项目依赖库&#xff0c;可以理解为部分技术栈之类的 |-- data 原始数据文件 |-- data 每个股票的模型保存位置 |-- app 主要代码文件夹 | |-- mod…

汽车辐射大?技术来救它:整车辐射抗扰发射天线仿真建模及性能预测

摘要 针对车辆电磁辐射抗扰度测试条件要求高、预测难度大的问题&#xff0c;通过仿真软件建立电磁抗扰度测试发射天线&#xff08;简称抗扰发射天线&#xff09;模型及无车情况下的电磁抗扰试验场强环境&#xff0c;为整车电磁辐射抗扰性能的预测搭建了一个仿真平台。 验证试验…