起因:
在DA-CLIP的开源库的DA-CLIP.md中自述该项目基于CLIP 和open_clip,在之前的退化类型检测中 我一度以为仓库只是使用了CLIP 的源码, 然而当发现缺少da-clip的模型名称时,我发现DA-CLIP使用的完全是open_clip的代码版本,专门配置了da-clip.json在open_clip的model_configs
This repository is based on the OpenAI's CLIP (Contrastive Language-Image Pre-training) and open_clip.
We extend the CLIP to a degradation-aware version (DA-CLIP) which predicts both degradation embedding and clean content embedding from corrupted images.
Then we can use the embeddings to improve image restoration performance and assist unified image restoration.
Moreover, we use the pretrained ViT CLIP model (ViT-B/32 on LAION-2B) and add an additional controller to control the image encoder.
该库基于OpenAI的CLIP(对比语言图像预训练)和open_clip。
我们将CLIP扩展到退化感知版本(DA-CLIP),该版本预测退化嵌入和从损坏图像中嵌入干净内容。然后,我们可以利用该嵌入来提高图像恢复性能,帮助统一图像恢复。
此外,我们使用预训练的ViT CLIP模型(LAION-2B上的ViT- b /32),并添加一个额外的控制器来控制图像编码器。
纵观DA-CLIP代码,关于CLIP模块,基本是从open_clip库上进行扩展。对于clip代码使用已有博主进行说明
CLIP模型原理与代码实现详解-CSDN博客https://blog.csdn.net/weixin_38252409/article/details/133828294
背景:
然而CSDN上对open_clip项目的介绍寥寥,以下项目虽有涉猎但都未对其仓库源码进行解析。
ImageNet零样本准确率首次超过80%!OpenCLIP:性能最强的开源CLIP模型-CSDN博客https://blog.csdn.net/amusi1994/article/details/129036171?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171120819416800227464964%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=171120819416800227464964&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_click~default-1-129036171-null-null.142%5Ev99%5Epc_search_result_base6&utm_term=OpenCLIP&spm=1018.2226.3001.4449
项目解读:open_clip
结构
docs/
- 包含项目文档和使用说明。src/
- 包含项目的主要源代码。tests/
- 包含项目测试代码。setup.py
- 用于安装和项目的 Python 包设置文件。
src下的文件夹,包含open_clip和training,
由于DA-CLIP与open_clip数据集、模型结构等不同,
其中DA-CLIP 主要使用了open_clip文件夹
该文件夹下:
model.py
- 定义了 CLIP 模型的结构,包括图像编码器和文本编码器。这是模型训练和推理的核心部分。config.py
- 包含了模型配置的类和函数,用于设置模型的不同参数,如学习率、批次大小等。tokenizer.py
- 包含了文本处理的分词器,用于将文本转换为模型可以理解的格式。trainer.py
- 包含了模型训练的主要逻辑,如设置优化器、损失函数和训练循环。evaluator.py
- 包含了评估模型性能的代码,通常用于在验证集或测试集上计算指标。utils.py
- 包含了一些实用工具函数,如数据处理、日志记录等。transform.py
- 包含了数据预处理和增强的转换函数,用于准备输入数据。- model_configs:包含各类模型配置,如coca、convnext、EVA、Vit、RN50等json文件
在 open_clip
目录下,每个文件都有其独特的价值和作用,但如果要挑选一个文件进行深入解读,model.py
可能是最值得关注的。这是因为 model.py
定义了 CLIP 模型的核心结构,它直接关联到模型的性能和功能。以下是对 model.py
文件的详细解读:
model.py 的作用
model.py
文件定义了 CLIP 模型的架构,包括图像编码器和文本编码器的设计。CLIP 模型是一个双流(two-stream)模型,它分别处理图像和文本输入,并通过对比学习(contrastive learning)的方式,学习图像和文本之间的对应关系。
主要组件
-
图像编码器(Image Encoder): 通常使用预训练的卷积神经网络(如 ResNet 或 Vision Transformer)作为基础,用于提取图像特征。
-
文本编码器(Text Encoder): 通常基于 Transformer 架构,用于处理文本输入并提取文本特征。文本编码器可能使用 BERT 或类似的 Transformer 变体。
-
投影头(Projection Heads): 用于将图像和文本编码器的输出映射到共同的特征空间,以便进行相似性比较。
-
损失函数(Loss Function): CLIP 模型使用对比损失函数来训练,这要求模型能够区分匹配的图像-文本对和不匹配的对。
关键概念
-
对比学习(Contrastive Learning): 一种自监督学习方法,模型通过比较正样本对和负样本对来学习特征表示。
-
多模态学习(Multimodal Learning): 涉及处理和理解多种类型数据(如图像和文本)的机器学习方法。
-
零样本学习(Zero-Shot Learning): 模型能够在没有见过特定类别的样本的情况下进行分类或识别。
代码结构
model.py
文件包含以下部分:
-
类定义(Class Definitions): 定义了图像编码器、文本编码器和整个 CLIP 模型的结构。
-
前向传播(Forward Pass): 描述了数据如何通过模型,以及如何计算图像和文本的特征表示。
-
初始化方法(Initialization Methods): 描述了模型权重的初始化过程,这对于训练的稳定性和收敛速度至关重要。
-
损失计算(Loss Computation): 实现了对比损失函数,用于训练过程中的优化。
通过对 model.py
文件的深入理解,可以更好地把握 CLIP 模型的工作原理,以及如何修改和扩展模型以适应不同的应用场景。这个文件是进行模型训练和评估的基础,对于任何想要深入了解或贡献于 OpenCLIP 项目的人来说都是必读的。
该文件的CLIP类如下
class CLIP(nn.Module):output_dict: torch.jit.Final[bool]def __init__(self,embed_dim: int,vision_cfg: CLIPVisionCfg,text_cfg: CLIPTextCfg,quick_gelu: bool = False,init_logit_scale: float = np.log(1 / 0.07),init_logit_bias: Optional[float] = None,cast_dtype: Optional[torch.dtype] = None,output_dict: bool = False,):super().__init__()self.output_dict = output_dictself.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)self.transformer = text.transformerself.context_length = text.context_lengthself.vocab_size = text.vocab_sizeself.token_embedding = text.token_embeddingself.positional_embedding = text.positional_embeddingself.ln_final = text.ln_finalself.text_projection = text.text_projectionself.text_pool_type = text.pool_typeself.register_buffer('attn_mask', text.attn_mask, persistent=False)self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)if init_logit_bias is not None:self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)else:self.logit_bias = Nonedef lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):# lock image tower as per LiT - https://arxiv.org/abs/2111.07991self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)@torch.jit.ignoredef set_grad_checkpointing(self, enable=True):self.visual.set_grad_checkpointing(enable)self.transformer.grad_checkpointing = enabledef encode_image(self, image, normalize: bool = False):features = self.visual(image)return F.normalize(features, dim=-1) if normalize else featuresdef encode_text(self, text, normalize: bool = False):cast_dtype = self.transformer.get_cast_dtype()x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]x = x + self.positional_embedding.to(cast_dtype)x = x.permute(1, 0, 2) # NLD -> LNDx = self.transformer(x, attn_mask=self.attn_mask)x = x.permute(1, 0, 2) # LND -> NLDx = self.ln_final(x) # [batch_size, n_ctx, transformer.width]x, _ = text_global_pool(x, text, self.text_pool_type)if self.text_projection is not None:if isinstance(self.text_projection, nn.Linear):x = self.text_projection(x)else:x = x @ self.text_projectionreturn F.normalize(x, dim=-1) if normalize else xdef get_logits(self, image, text):image_features = self.encode_image(image, normalize=True)text_features = self.encode_text(text, normalize=True)image_logits = self.logit_scale.exp() * image_features @ text_features.Tif self.logit_bias is not None:image_logits += self.logit_biastext_logits = image_logits.Treturn image_logits, text_logitsdef forward(self,image: Optional[torch.Tensor] = None,text: Optional[torch.Tensor] = None,):image_features = self.encode_image(image, normalize=True) if image is not None else Nonetext_features = self.encode_text(text, normalize=True) if text is not None else Noneif self.output_dict:out_dict = {"image_features": image_features,"text_features": text_features,"logit_scale": self.logit_scale.exp()}if self.logit_bias is not None:out_dict['logit_bias'] = self.logit_biasreturn out_dictif self.logit_bias is not None:return image_features, text_features, self.logit_scale.exp(), self.logit_biasreturn image_features, text_features, self.logit_scale.exp()
这个类定义了 CLIP 模型的结构和行为。
初始化
def __init__(self,embed_dim: int,vision_cfg: CLIPVisionCfg,text_cfg: CLIPTextCfg,quick_gelu: bool = False,cast_dtype: Optional[torch.dtype] = None,output_dict: bool = False,):super().__init__()self.output_dict = output_dictself.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)self.transformer = text.transformerself.context_length = text.context_lengthself.vocab_size = text.vocab_sizeself.token_embedding = text.token_embeddingself.positional_embedding = text.positional_embeddingself.ln_final = text.ln_finalself.text_projection = text.text_projectionself.register_buffer('attn_mask', text.attn_mask, persistent=False)self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
参数解释
self
: 指向类的实例。embed_dim
: 嵌入维度,这是模型中嵌入层的维度。vision_cfg
: 图像配置对象,它包含了构建图像处理部分(视觉塔)所需的配置信息。text_cfg
: 文本配置对象,它包含了构建文本处理部分(文本塔)所需的配置信息。quick_gelu
: 布尔值,指示是否使用快速的GELU(Gaussian Error Linear Unit)激活函数。cast_dtype
: 可选参数,指定数据类型,用于将模型参数转换为指定的数据类型。output_dict
: 布尔值,指示模型输出是否应该是一个字典。
方法体解释
super().__init__()
: 调用父类的构造函数。self.output_dict
: 存储传入的output_dict
参数,这可能影响模型输出的格式。self.visual
: 通过调用一个内部函数_build_vision_tower
来构建视觉塔,并存储结果。text
: 通过调用一个内部函数_build_text_tower
来构建文本塔,并存储结果。self.transformer
: 从文本塔中提取变换器(transformer)模块。self.context_length
: 存储文本塔的上下文长度。self.vocab_size
: 存储文本塔的词汇表大小。self.token_embedding
: 存储文本塔的词嵌入层。self.positional_embedding
: 存储文本塔的位置嵌入层。self.ln_final
: 存储文本塔的最终层归一化(Layer Normalization)。self.text_projection
: 存储文本塔的文本投影层。self.register_buffer('attn_mask', text.attn_mask, persistent=False)
: 注册一个缓冲区,用于存储文本塔的注意力掩码(attention mask),这个掩码在自注意力机制中用于指示哪些位置应该被模型关注。self.logit_scale
: 创建一个可学习的参数,用于缩放模型的输出(logits),初始化为一个全1的向量,乘以一个基于经验的对数缩放因子。
锁定参数
下面两个方法的目的是为了在训练过程中冻结图像塔和文本塔的参数,这与 LiT 方法的核心思想相符,即在对比学习(contrastive learning)过程中,只更新文本模型的参数,而保持图像模型的参数不变。这样做可以利用预训练图像模型的强大特征提取能力,同时通过文本模型来适应新任务,实现零样本(zero-shot)迁移学习。
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):# lock image tower as per LiT - https://arxiv.org/abs/2111.07991self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):for param in self.transformer.parameters():param.requires_grad = Falseself.token_embedding.requires_grad = Falseself.positional_embedding.requires_grad = Falseself.text_projection.requires_grad = False
这段代码定义了两个方法,lock_image_tower
和 lock_text_tower
,它们似乎是用于控制神经网络模型中的图像塔(image tower)和文本塔(text tower)的参数更新机制。CLIP类实现了类似于 LiT(Locked-image Tuning)的机制,这是一种在图像和文本模型对齐时锁定预训练图像模型参数的技术,如在论文 "LiT: Zero-Shot Transfer with Locked-image text Tuning" 中所描述的。
lock_image_tower
方法
unlocked_groups
参数:指定哪些层组应该保持未锁定(即可训练的)。默认值为 0,意味着所有层组都被锁定。freeze_bn_stats
参数:决定是否冻结批量归一化(Batch Normalization, BN)层的统计数据。如果设置为True
,则BN层的运行时统计数据(均值和方差)不会在训练过程中更新。
在这个方法中,调用了 self.visual.lock()
,这是一个自定义的方法,用于锁定图像塔中的参数。这个方法可能会根据 unlocked_groups
参数来决定哪些层或层组应该保持可训练状态。
lock_text_tower
方法
unlocked_layers
参数:指定文本塔中应该保持未锁定的层的数量。默认值为 0,意味着所有层都被锁定。freeze_layer_norm
参数:决定是否冻结层归一化(Layer Normalization, LN)的参数。如果设置为True
,则LN层的参数不会在训练过程中更新。
在这个方法中,遍历了 self.transformer
中的所有参数,并将它们的 requires_grad
属性设置为 False
,这意味着在训练过程中这些参数不会更新。此外,还冻结了文本塔中的词嵌入(token_embedding
)、位置嵌入(positional_embedding
)和文本投影(text_projection
)的参数。
set_grad_checkpointing
方法设置梯度检查点
def set_grad_checkpointing(self, enable=True):self.visual.set_grad_checkpointing(enable)self.transformer.grad_checkpointing = enable
enable
: 布尔值,指示是否启用梯度检查点。当启用时,可以减少模型训练过程中的内存消耗,但可能会增加计算成本。
这个方法在模型的两个主要组件,visual
(视觉塔)和transformer
(文本塔)上设置梯度检查点。梯度检查点是一种内存优化技术,它允许模型在前向传播过程中保存中间激活的梯度,从而在反向传播时减少内存使用。
encode_image
方法
def encode_image(self, image, normalize: bool = False):features = self.visual(image)return F.normalize(features, dim=-1) if normalize else features
image
: 输入的图像数据。normalize
: 布尔值,指示是否对编码后的特征进行归一化。
这个方法使用模型的视觉塔来编码输入的图像数据。如果normalize
参数为True
,则使用F.normalize
函数对特征进行归一化处理,否则直接返回特征。
encode_text
方法
def encode_text(self, text, normalize: bool = False):cast_dtype = self.transformer.get_cast_dtype()x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]x = x + self.positional_embedding.to(cast_dtype)x = x.permute(1, 0, 2) # NLD -> LNDx = self.transformer(x, attn_mask=self.attn_mask)x = x.permute(1, 0, 2) # LND -> NLDx = self.ln_final(x) # [batch_size, n_ctx, transformer.width]# take features from the eot embedding (eot_token is the highest number in each sequence)x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projectionreturn F.normalize(x, dim=-1) if normalize else x
text
: 输入的文本数据。normalize
: 布尔值,指示是否对编码后的文本特征进行归一化。
这个方法首先将文本数据通过词嵌入层和位置嵌入层来获取嵌入表示,然后将这些嵌入表示转换为适合transformer的格式。接下来,使用transformer处理这些嵌入,并通过层归一化(self.ln_final
)进行归一化。最后,通过文本投影层(self.text_projection
)将嵌入映射到最终的特征空间,并进行归一化处理(如果normalize
为True
)。
前向传播函数
def forward(self,image: Optional[torch.Tensor] = None,text: Optional[torch.Tensor] = None,):image_features = self.encode_image(image, normalize=True) if image is not None else Nonetext_features = self.encode_text(text, normalize=True) if text is not None else Noneif self.output_dict:return {"image_features": image_features,"text_features": text_features,"logit_scale": self.logit_scale.exp()}return image_features, text_features, self.logit_scale.exp()
段代码定义了一个名为 forward
的方法,它是神经网络模型中的前向传播函数。该方法接收图像和文本作为输入,并输出它们的特征表示以及用于缩放 logits 的比例因子。这个方法是模型核心功能的一部分,负责将输入数据转换为模型可以处理的嵌入向量。
参数解释
self
: 指向类的实例。image
: 输入的图像数据,类型为torch.Tensor
。如果为None
,则表示没有图像输入。text
: 输入的文本数据,类型为torch.Tensor
。如果为None
,则表示没有文本输入。
方法体解释
image_features
: 如果提供了图像输入,使用self.encode_image
方法对图像进行编码,并在返回前进行归一化处理。如果没有图像输入,则设置为None
。text_features
: 如果提供了文本输入,使用self.encode_text
方法对文本进行编码,并在返回前进行归一化处理。如果没有文本输入,则设置为None
。if self.output_dict
: 判断是否以字典格式输出。如果output_dict
属性为True
,则将图像特征、文本特征和 logits 缩放因子封装成一个字典返回。否则,将这三个值作为独立的返回值。
返回值
- 如果
self.output_dict
为True
,则返回一个包含图像特征、文本特征和 logits 缩放因子的字典。 - 如果
self.output_dict
为False
,则返回一个包含图像特征、文本特征和 logits 缩放因子的元组。
总结
forward
方法是模型的入口点,它根据输入的图像和文本数据,通过模型的编码器生成对应的特征表示。这些特征表示可以用于后续的多模态任务,例如图像-文本匹配、联合嵌入学习或零样本分类等。此外,该方法还提供了 logits 缩放因子,这在某些情况下(如对比学习或分类任务)可能是必需的。通过灵活的输出格式,该方法可以适应不同的使用场景和后处理需求。