一、CLIP模型主要内容讲解
CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的一种用于图像和文本联合表示学习的模型。CLIP的核心思想是通过对比学习来预训练一个模型,使其能够理解图像和文本之间的关系。以下是CLIP的工作原理和步骤:
1. 数据集
CLIP使用大规模的图像-文本对数据集进行预训练,例如从互联网上收集的4亿个图像-文本对。这些数据集包含了丰富的图像和对应的描述文本,使得模型能够学习到广泛的视觉概念和语言表达。
2. 模型架构
CLIP由两个主要部分组成:
- 图像编码器:用于将图像转换为特征向量。图像编码器可以是卷积神经网络(如ResNet)或Transformer模型(如ViT)。
- 文本编码器:用于将文本转换为特征向量。文本编码器通常是一个Transformer模型。
3. 对比学习
CLIP通过对比学习来训练模型。具体来说,对于一个批次中的每个图像-文本对,模型会计算图像和文本的特征向量,并使用对比损失函数来优化模型参数。对比损失函数的目标是使得匹配的图像-文本对的特征向量尽可能接近,而不匹配的图像-文本对的特征向量尽可能远离。
4. 损失函数
CLIP使用的损失函数是对称的对比损失函数。具体来说,对于每个图像-文本对,模型会计算两个方向的损失:
- 图像到文本的损失:计算图像特征向量和文本特征向量之间的相似度,并优化使得匹配的图像-文本对的相似度最大化。
- 文本到图像的损失:计算文本特征向量和图像特征向量之间的相似度,并优化使得匹配的文本-图像对的相似度最大化。
5. 推理阶段
在推理阶段,CLIP可以用于多种任务,例如:
- 图像分类:给定一个图像,模型可以将其特征向量与预定义的文本类别(如“猫”、“狗”等)的特征向量进行比较,选择相似度最高的类别作为预测结果。
- 文本到图像检索:给定一个文本描述,模型可以将其特征向量与图像库中的图像特征向量进行比较,检索出与文本描述最匹配的图像。
二、 动机
在计算机视觉领域,迁移学习的一种常见做法是先在如ImageNet这样的大规模数据集上进行预训练,然后在具体的下游任务上进行微调。这种预训练通常基于有监督学习,需要大量的数据标注,因此成本较高。近年来,随着自监督学习方法的兴起,这一局面得到了改变。自监督学习方法,包括基于对比学习的方法如MoCo和SimCLR,以及基于图像掩码的方法如MAE和BeiT,它们的优势在于不再依赖于数据标注。然而,无论是传统的监督学习还是新兴的自监督学习,它们在迁移到下游任务时,仍然需要进行有监督的微调,无法实现真正的零样本学习(Zero-shot)。
对于有监督模型而言,由于它们在预训练数据集上使用了固定类别数的分类器,因此在新的数据集上需要重新定义分类器并进行训练。而对于自监督模型,虽然代理任务有助于表征学习,但在迁移到其他数据集时,同样需要添加新的分类器进行有监督训练。
相比之下,在自然语言处理(NLP)领域,基于自回归或语言掩码的预训练方法已经相当成熟,并且预训练模型能够轻松实现零样本迁移到下游任务,例如OpenAI的GPT-3。这种差异一方面源于文本和图像这两种完全不同的模态,另一方面则是因为NLP模型可以利用互联网上丰富多样的文本数据。
这就引出了一个问题:我们能否利用互联网上的大量文本来预训练视觉模型,从而实现类似NLP领域的零样本迁移能力?这一问题的探讨,不仅涉及跨模态学习的深入研究,也为视觉模型的预训练和迁移学习开辟了新的可能性。
所以openai基于前面的工作,从文本信息获取监督信息的方式,做了以下两件事:1.足够大的数据集(爬取清洗了4亿对图像文本对),2.多模态(图像跟文本),统一用transformer架构,使用对比学习训练。
蓝色是回归预测,也就是根据图像去预测对应的文本标签,这个任务难度是巨大的,而且一个图像可以有多个文本表述范式;橘黄色是在特诊空间进行预测,根据图像特征去预测文本特征。绿色是图像-文本匹对的损失,这个是最快收敛的,任务相对简单。同时这样的设置可以更好地将 图像与文本的语义信息绑定到一起。
三、模型
1. 训练过程
伪代码
2. 模型细节
这里的image encoder可以是ResNet也可以是ViT;text encoder可以是CBOW也可以是Transformer。
图像、文本分别经过encoder之后得到的特征,会进行线性投影以及L2归一化操作。
L2 归一化(L2 normalization)是一种常见的数据预处理技术,用于将数据向量的长度缩放到单位范数(即 L2 范数为 1)。对于一个向量 v = [ v 1 , v 2 , … , v n ] \mathbf{v} = [v_1, v_2, \ldots, v_n] v=[v1,v2,…,vn],其 L2 范数定义为:
∥ v ∥ 2 = v 1 2 + v 2 2 + ⋯ + v n 2 \|\mathbf{v}\|_2 = \sqrt{v_1^2 + v_2^2 + \cdots + v_n^2} ∥v∥2=v12+v22+⋯+vn2
L2 归一化后的向量 v ′ \mathbf{v}' v′ 计算如下:
v ′ = v ∥ v ∥ 2 = [ v 1 ∥ v ∥ 2 , v 2 ∥ v ∥ 2 , … , v n ∥ v ∥ 2 ] \mathbf{v}' = \frac{\mathbf{v}}{\|\mathbf{v}\|_2} = \left[ \frac{v_1}{\|\mathbf{v}\|_2}, \frac{v_2}{\|\mathbf{v}\|_2}, \ldots, \frac{v_n}{\|\mathbf{v}\|_2} \right] v′=∥v∥2v=[∥v∥2v1,∥v∥2v2,…,∥v∥2vn]
这样处理后,向量 v ′ \mathbf{v}' v′ 的 L2 范数为 1:
∥ v ′ ∥ 2 = ( v 1 ∥ v ∥ 2 ) 2 + ( v 2 ∥ v ∥ 2 ) 2 + ⋯ + ( v n ∥ v ∥ 2 ) 2 = 1 \|\mathbf{v}'\|_2 = \sqrt{\left( \frac{v_1}{\|\mathbf{v}\|_2} \right)^2 + \left( \frac{v_2}{\|\mathbf{v}\|_2} \right)^2 + \cdots + \left( \frac{v_n}{\|\mathbf{v}\|_2} \right)^2} = 1 ∥v′∥2=(∥v∥2v1)2+(∥v∥2v2)2+⋯+(∥v∥2vn)2=1
3. 损失函数细节
对称性的损失设计
在CLIP的训练中,损失设计包括两个方向:从图像到文本和从文本到图像。这两个方向的交叉熵损失计算如下:
-
计算相似度矩阵:
- 假设有 n n n 对图像-文本对,图像编码表示为 I \mathbf{I} I,文本编码表示为 T \mathbf{T} T。
- 相似度矩阵 S \mathbf{S} S 的元素 s i j s_{ij} sij 表示第 i i i 个图像和第 j j j个文本之间的相似度。通常,使用余弦相似度计算:
S i j = cos ( I i , T j ) = I i ⋅ T j ∥ I i ∥ ∥ T j ∥ S_{ij} = \cos(\mathbf{I}_i, \mathbf{T}_j) = \frac{\mathbf{I}_i \cdot \mathbf{T}_j}{\|\mathbf{I}_i\| \|\mathbf{T}_j\|} Sij=cos(Ii,Tj)=∥Ii∥∥Tj∥Ii⋅Tj
-
定义标签:
- 标签向量为 t e x t l a b e l s = np.arange ( n ) text{labels} = \text{np.arange}(n) textlabels=np.arange(n),表示对角线上的元素是匹配的图像-文本对。
-
交叉熵损失:
-
图像到文本的交叉熵损失 loss i \text{loss}_i lossi:
loss i = cross_entropy_loss ( S , labels , axis = 0 ) \text{loss}_i = \text{cross\_entropy\_loss}(\mathbf{S}, \text{labels}, \text{axis}=0) lossi=cross_entropy_loss(S,labels,axis=0)
这里, axis = 0 \text{axis}=0 axis=0 表示对每一行(即每个图像对应的所有文本)计算损失。 -
文本到图像的交叉熵损失 ( \text{loss}_t ):
loss t = cross_entropy_loss ( S , labels , axis = 1 ) \text{loss}_t = \text{cross\_entropy\_loss}(\mathbf{S}, \text{labels}, \text{axis}=1) losst=cross_entropy_loss(S,labels,axis=1)
这里, axis = 1 \text{axis}=1 axis=1 表示对每一列(即每个文本对应的所有图像)计算损失。
-
-
综合损失:
- 最终的损失是这两个方向的交叉熵损失的平均值:
loss = loss i + loss t 2 \text{loss} = \frac{\text{loss}_i + \text{loss}_t}{2} loss=2lossi+losst
- 最终的损失是这两个方向的交叉熵损失的平均值:
其中,
-
图像到文本的交叉熵损失:
- 对于每个图像 I i \mathbf{I}_i Ii,计算它与所有文本的相似度 S i , : \mathbf{S}_{i,:} Si,:。
- 使用交叉熵损失,将这些相似度与标签(正确匹配的文本索引)进行比较,确保每个图像找到正确的文本描述。
-
文本到图像的交叉熵损失:
- 对于每个文本 T i \mathbf{T}_i Ti,计算它与所有图像的相似度 S : , i \mathbf{S}_{:,i} S:,i。
- 使用交叉熵损失,将这些相似度与标签(正确匹配的图像索引)进行比较,确保每个文本找到正确的图像。
这种设计的对称性确保了模型在训练过程中同时优化图像到文本和文本到图像的匹配效果。通过这种方式,模型不仅能够从图像中找到相应的文本描述,还能够从文本中找到对应的图像。这种双向优化使得CLIP模型在实际应用中表现更加鲁棒和准确。
4. zero shot tranfer
这个算是CLIP最大的创新点,之前的有监督学习或者说无监督学习,主要的目的是获取一个强大的特征抽取器(backbone),但是在应用到下游任务的时候,基本上还是要收集下游数据进行微调的。CLIP做到这样的两点:1.无需微调;2.文本信息引导模型迁移
5. promot engineering and ensembing
在做zero-shot推理的时候要用到。
- 标签单词文本多义性polysemy,所以只给一个标签单词并不合适。
- 训练的时候文本信息是句子,而预测的时候标签是单词,这有影响,存在distribution gap的问题,可以做promot template来巧妙解决,如设定好模板:"A photo of a {label}",同时可以再细化,比如已经知道数据集是宠物分类,那就可以使用"A photo of a {label},a type of pet."
- ensembing就是使用多个提示模板,然后综合多个模板的结果。
在CLIP模型中,文本和图像都会被提取为特征嵌入(feature embeddings)。这些特征嵌入的形状(shape)取决于模型的架构和输入数据的处理方式。
6. 图像特征、文本特征嵌入的形状
对于图像特征嵌入,通常的形状是 (batch_size, embedding_dim)
,其中:
batch_size
是输入图像的批次大小。embedding_dim
是图像特征嵌入的维度,这个维度取决于模型的架构。例如,对于CLIP的ViT-B/32模型,embedding_dim
通常是512。
对于文本特征嵌入,形状通常也是 (batch_size, embedding_dim)
,其中:
batch_size
是输入文本的批次大小。embedding_dim
是文本特征嵌入的维度,这个维度同样取决于模型的架构。对于CLIP的ViT-B/32模型,embedding_dim
通常也是512。
代码
这里使用PyTorch和Hugging Face的Transformers库来加载和使用CLIP模型。这个代码展示了如何使用CLIP模型进行图像分类和文本到图像的检索。
- 安装必要的库:
pip install torch transformers pillow
- 加载CLIP模型并进行图像分类:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel# 加载预训练的CLIP模型和处理器
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")# 定义图像路径和文本类别
image_path = "path_to_your_image.jpg"
text_labels = ["a photo of a cat", "a photo of a dog", "a photo of a bird"]# 加载图像
image = Image.open(image_path)# 处理图像和文本
inputs = processor(text=text_labels, images=image, return_tensors="pt", padding=True)# 计算特征向量并进行分类
with torch.no_grad():outputs = model(**inputs)logits_per_image = outputs.logits_per_image # 图像到文本的相似度得分
probs = logits_per_image.softmax(dim=1) # 概率分布# 输出分类结果
for i, label in enumerate(text_labels):print(f"{label}: {probs[0][i].item():.2%}")# 输出最高概率的类别
predicted_class = text_labels[probs[0].argmax()]
print(f"Predicted class: {predicted_class}")