从零实现CLIP模型

1. 引言

CLIP代表语言图像对比预训练模型,是OpenAI于2021年开发的一个深度学习模型。CLIP模型中图像和文本嵌入共享相同的潜在特征空间,从而能够在两种模式之间直接进行对比学习。这是通过训练模型使相关的图像和文本更紧密地结合在一起,同时将不相关的图像在特征空间距离分开来实现的。

闲话少说,我们直接开始吧!

2. 相关应用

关于CLIP模型的一些应用总结如下:

  • 图像分类和检索:CLIP可以通过将图像与自然语言文本描述关联起来进而可用于图像分类任务。它允许更通用和灵活的图像检索系统,用户可以使用文本查询来在数据库里搜索图像。

  • 内容调节:CLIP可用于通过分析图像和附带文本来识别和过滤不适当或有害的内容,从而调节在线平台上的展示内容。

3. 核心思想

CLIP模型旨在预测一个batchN×N个潜在(img,text)配对具体哪些是实际匹配的。为了实现这一点,CLIP通过图像编码器和文本编码器的联合训练建立了一个多模态嵌入空间。CLIP的损失函数旨在最大化批处理中N个真实配对的图像和文本嵌入之间的余弦相似性,同时最小化N²−N个错误配对的余弦相似度。以下是伪代码(取自原始论文),概述了CLIP的核心实现。
在这里插入图片描述
接着我们将伪代码中每一行的逐步描述,将其转化为使用PyTorch来实现。

4. 网络结构

在进行代码实现之前,我们先来简单回顾下clip模型具体的网络结构:
在这里插入图片描述

ClIP模型使用两种独立的网络结构来作为图像编码和文本编码的主干,其中:

  • image_encoder:负责编码图像的神经网络主干(eg,ResNetVision Transformer等)。
  • text_encoder:表示负责编码文本信息的神经网络架构(eg,CBOWBERT等)。

原始CLIP模型是从零开始训练的,而没有使用预训练的权重来初始化图像编码器和文本编码器,因为它们用于训练其CLIP模型的数据集体量很大(4亿个图像-文本对)。在这篇博客文章的例子中,我们将采取一些不同的做法。我们将从resnet(用于图像)和distilbert(用于文本)模型的预训练权重开始初始化这些部分。

5. 数据输入

该模型每个批次以n个图像和文本对作为输入,其中:

  • I[n,h,w,c]:表示对齐的图像的小批次输入,其中n是batch大小,h是图像高度,w是图像宽度,c是通道数。
  • T[n,l]:表示对齐文本的小批次输入,其中n是batch大小,l是文本序列的长度。

我们的实现中,我们默认batch的大小为128,如下所示:
在这里插入图片描述

6. 特征提取

关于文本和图像的特征提取,这里使用resnet34distilbert来分别提取图像和文本的特征,如下:

  • I_f = image_encoder(I) : 从图像编码器中获取的图像特征表示I_fI_f的大小为[n,d_I],其中d_I是图像特征的维度。
  • T_f=text_encoder(T):从文本编码器中获取的文本特征表示T_fT_f的大小为[n,d_T],其中d_T是文本特征的维度。

在本文实现中,相应的代码如下:

# for encoding images
I_f = models.resnet34(pretrained=True)      
# for encoding captions
T_f= AutoModel.from_pretrained("distilbert-base-multilingual-cased") 

7. 特征映射

接着,我们将相应的文本和图像特征,映射到同一嵌入特征空间,如下:

  • W_i[d_i,d_e]:表示用于将图像特征i_f映射到嵌入特征空间i_e的投影矩阵。W_i的形状大小是[d_i,d_e],其中d_e表示的是联合嵌入特征空间的维度。
  • W_t[d_t,d_e]:表示用于将文本特征t_f映射到相同嵌入空间t_e的投影矩阵。W_t的形状大小是[d_t,d_e]

投影操作可以使用具有两个线性层的神经网络进行编码,其权重是学习的投影矩阵。在大多数情况下,投影权重是唯一可以在新数据集上需要训练的权重。此外,投影层在对齐图像和文本嵌入的尺寸方面发挥着至关重要的作用,确保它们具有相同的维度。

相应的代码实现如下:
在这里插入图片描述

8. 组合

在上一节中,我们将文本和图像特征分别统一到相同的维度,接着我们将上述相关组件进行整合:

  • I_e = l2_normalize(np.dot(I_f, W_i), axis=1) :在联合嵌入空间I_e中嵌入并归一化图像特征
  • T_e = l2_normalize(np.dot(T_f, W_t), axis=1) :在联合嵌入空间T_e中嵌入并归一化文本特征

接着我们使用以下Pytorch代码来描述图像和文本数据的处理次序。首先,相应的数据通过基本编码器进行处理,然后通过投影层进行处理。最后,为两种模态特征进行嵌入归一化化并返回。如下:

在这里插入图片描述

9. 余弦相似度

接着在嵌入空间,我们来计算文本图像嵌入特征的相似度:

  • logits = np.dot(I_e, T_e.T) * np.exp(t):用以计算图像和文本对在联合嵌入空间的特征余弦相似度,通过可学习的参数t进行缩放。

在我们的例子中,我们考虑暂不使用参数t,代码如下:

logits = T_e @ T_e.T

10. 损失函数

CLIP使用对比损失用以将相关图像和文本在嵌入特征空间拉近,同时将不相关的图像和文本距离拉远。

  • labels = np.arange(n): 用以生成表示batch索引的真值标签。
  • loss_i = cross_entropy_loss(logits, labels, axis=0):用以计算图像特征和真值标签的损失
  • loss_t = cross_entropy_loss(logits, labels, axis=1):用以计算文本特征和真值标签的损失
  • loss = (loss_i + loss_t)/2:计算图像和文本损失的加权平均值。

代码实现如下:

在这里插入图片描述

11. 构建完整模型

将所有不同的部件组合在一起,最终的自定义CLIP模型如下所示:

在这里插入图片描述

12. 构建数据集

我们的自定义CLIP模型将使用flickr30k数据集进行训练。该数据集包括31000多张图像,每张图像至少有5个独立的人工生成文本描述。在这个例子中,我们将为每个图像使用两个标题,总共有62000个图像和文本对用于训练。 代码实现如下:
在这里插入图片描述
上述模型关键常数包括用于学习表示特征空间的维度embed_dim, 用于transformer特征维度的transformer_embed_dim和用于文本输入长度的max_len。所选的text_modeldistilbert base multilanguage-cased。用以训练的模型的epoch为3,同时batch_size的大小为128,这些常数将用于模型构建和训练。如下所示:
在这里插入图片描述

13. 数据集测试用例

DataLoader是为训练期间的高效迭代而设置的,提供图像文本对的迭代访问。调用代码如下:

# Create the DataLoader
clip_dataloader = DataLoader(flickr30k_custom_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

以下是数据集中一个批次中的图像文本对的示例:

import numpy as np
import matplotlib.pyplot as plt
# Create an iterator from the dataloader
data_iter = iter(clip_dataloader)# Get one batch
batch = next(data_iter)image = batch["image"][0]  # get one image from the batch
caption = batch["caption"][0]  # get one text from the batch# Convert the image tensor to a NumPy array and permute dimensions
image_np = np.transpose(image.numpy(), (1, 2, 0))# Display the image and caption
plt.imshow(image_np)
plt.title(f"Caption: {caption}")
plt.show()

运行结果如下:
在这里插入图片描述

14. 优化器选择

此外,我们还需要指定在整个训练过程中需要优化的参数。上文中我们已经固定了文本和图像编码器的特征提取层,那么只有与投影层相关的参数才会在新的数据集上进行训练。

# Create an instance of your model
model = CustomModel().to(device)# Define optimizer
optimizer = torch.optim.Adam([{'params': model.vision_encoder.parameters()},{'params': model.caption_encoder.parameters()}
], lr=model.lr)

15. 模型训练

我们使用Tesla T4的GPU机器进行3个epoch的训练,相应的训练代码如下:
在这里插入图片描述

执行上述训练代码,可以得到训练过程如下:
在这里插入图片描述

16. 总结

总之,这篇博客文章探讨了CLIP模型,揭示了其广泛应用的潜力。随着我们对CLIP应用的了解,很明显,它的影响远远超出了最初的预期,为不同领域的创新解决方案铺平了道路。

您学废了嘛?

完整代码:戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/603199.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

浅谈 JVM 类加载过程

🚗🚗🚗今天给大家分享的是HTTPS加密的工作过程。 清风的CSDN博客 🛩️🛩️🛩️希望我的文章能对你有所帮助,有不足的地方还请各位看官多多指教,大家一起学习交流! ✈️✈…

四. 基于环视Camera的BEV感知算法-BEVDistill

目录 前言0. 简述1. 算法动机&开创性思路2. 主体结构3. 损失函数4. 性能对比总结下载链接参考 前言 自动驾驶之心推出的《国内首个BVE感知全栈系列学习教程》,链接。记录下个人学习笔记,仅供自己参考 本次课程我们来学习下课程第四章——基于环视Cam…

Transformer架构和对照代码详解

1、英文架构图 下面图中展示了Transformer的英文架构,英文架构中的模块名称和具体代码一一对应,方便大家对照代码、理解和使用。 2、编码器 2.1 编码器介绍 从宏观⻆度来看,Transformer的编码器是由多个相同的层叠加⽽ 成的,每个…

Auto tokenizer和Bert tokenizer的区别

"AutoTokenizer" 和 "BERT Tokenizer" 是两个不同概念,而不是两种不同的tokenizer。让我为您解释它们的含义: AutoTokenizer: "AutoTokenizer" 不是一个具体的tokenizer,而是 Hugging Face Transformers 库中提…

10款有趣的前端源码分享(附效果图及在线演示)

分享10款非常有趣的前端特效源码 其中包含css动画特效、js原生特效、svg特效以及小游戏等 下面我会给出特效样式图或演示效果图 但你也可以点击在线预览查看源码的最终展示效果及下载源码资源 自毁按钮动画特效 自毁按钮动画特效 点击打开盒子可以点击自毁按钮 进而会出现自毁…

【网络层】网际控制报文协议ICMP(湖科大慕课自学笔记)

网际控制报文协议ICMP 1:网际控制报文协议ICMP基本概述 ICMP报文被封装在IP数据报中发送 1:ICMP报文格式 ICMP报文作为IP数据报的数据载荷,IP协议为其添加一个首部使之成为IP数据报 2:ICMP报文类型 ICMP报文分为两大类&#x…

NGUI基础-三大基础组件之Event System(Uicameras)

目录 主要作用 相关参数 (建议:红色是重点,黑色的了解即可) Event Type Events go to Process Events in Event Mask​编辑 Debug Command Click Allow Multi Touch Auto Hide Cursor Sticky ToolTip/Long press ToolTip/ToolTip…

vue实现点击复制功能方法封装demo。

源码如下 copyTextToClipboard(text, that) { const textArea document.createElement("textarea"); textArea.value text; document.body.appendChild(textArea); 在子节点末尾添加元素 textArea.select(); select方法讲解可以了解一下 JavaScri…

工业协议转换网关:打破通信壁垒,实现设备互联

在工业自动化领域,各种设备和系统间的通信协议不尽相同,这给不同设备间的集成和数据交互带来了挑战。工业协议转换网关作为一种解决这一问题的关键设备,能够实现不同协议间的转换和数据传输,打破通信壁垒,提高设备的协…

【PostgreSQL】模式Schema

PostgreSQL 数据库集群包含一个或多个命名数据库。角色和一些其他对象类型在整个集群中共享。与服务器的客户端连接只能访问单个数据库中的数据,该数据库在连接请求中指定。 数据库包含一个或多个命名schema,而这些schema又包含表。schema还包含其他类型…

Java 基础知识点1 (含面试题)

本次Java 知识点主要是关于SE的相关基础,同时也包含了数据结构中的一些API,例如Set,List,Map等,最后也附上了相关重要的面试题,可供大家学习与参考! 目录 重要知识点数据结构API面试题 重要知识点 Java 是一门面向对象…

税法相关的基础知识

文章目录 税法原则1.税法基本原则2.税法适用原则3.税收收入划分 来和大家聊聊税法相关的基础知识 税法原则 1.税法基本原则 2.税法适用原则 3.税收收入划分

Flume基础知识(十):Flume 聚合实战

1)案例需求: hadoop100上的 Flume-1 监控文件/opt/module/group.log, hadoop101上的 Flume-2 监控某一个端口的数据流, Flume-1 与 Flume-2 将数据发送给 hadoop102 上的 Flume-3,Flume-3 将最终数据打印 到控制台。…

数据库事务的特性

数据库事务具有 ACID 特性,其中 ACID 是指原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)和持久性(Durability)。这些特性是为了确保数据库在事务处理中的可…

Android 13.0修改recovery 菜单项字体大小

1.概述 在13.0的系统rom定制化开发中,在系统进入recovery模式后,界面会g_menu_actions 菜单选项和 提示文字,而这些文字的大小不像上层一样是通过设置属性来表示大小的 而它确是通过字体png图片的大小来计算文字的宽和高的,然后可以修改字体大小,接下来就实现菜单项字体大…

RocketMQ详细介绍及核心问题解释(很全)

1. RocketMq是什么 一个纯Java、分布式队列模型的消息中间件,具有高可用、高可靠、高实时、低延迟的特点。(记住这句就行了) 2. RocketMq有什么功能 1、业务解耦:这也是发布订阅的消息模型。生产者发送指令到MQ中,然…

python中parsel模块的css解析

一、爬虫页面分类 1.想要爬取的内容全部在标签中,可以使用xpath去进行解析如下图 2.想要爬取的内容呈现json的数据特征,用.json()转换为字典格式 3.页面不规则,标签中包含大括号,如下面想要获取键值内容怎么做,先用re正…

Binius:基于binary fields的SNARKs(Part 2)

1. 引言 前序博客有: Binius:基于binary fields的SNARKs(Part 1)Binius:助力ZK行业发展 本文重点关注: 1)concatenated codes:可扩展对small fields的多项式承诺方案2&#xff0…

Docker学习与应用(六)-Docker网络

1、Docker网络 Docker有多种网络模式可以选择,可以根据应用场景和需求选择合适的网络模式。 桥接模式(Bridge Mode):默认情况下,Docker使用桥接模式创建一个虚拟网络,所有容器会连接到这个虚拟网络中。每个…

回家用go?还是go to?

语法 go 副词 go to 名词 home比较特殊 前面无修饰词就是副词 前面有修饰词就是名词 案例