CLIP大模型图文检索——原理解读及代码实现

图片

一. 核心思想

通过自然语言处理获得的监督信号可用于训练迁移效果出色的视觉模型。本论文的作者团队构建了一个庞大的图像文本配对数据集,其中包含400 million个图片文本的配对。利用最大规模的ViT-large模型,他们提出了CLIP(Contrastive Language-Image Pre-training)方法,这是一种有效的从自然语言监督中学习的方法。研究团队在30个数据集上进行了实验,结果显示CLIP模型的性能与之前的有监督模型相当,甚至更好。

二. 模型实现

图片

(1)CLIP的训练过程

CLIP的训练过程是基于图像和文字配对的数据,其中图像输入经过图像编码器得到特征,而文本输入则经过文本编码器得到特征。每个训练批次包含n个图像-文本配对,从而获得n个图像特征和n个文本特征。随后,利用这些特征进行对比学习,其中对比学习的灵活性要求定义正样本和负样本。在这里,配对的图像-文本对即为正样本,因为它们描述的是同一物体。在特征矩阵中,对角线上的元素表示正样本,而非对角线上的元素则表示负样本。有了正负样本,模型便可以通过对比学习的方式进行无监督训练。这种无监督训练方式需要大量的训练数据支持。

(2)CLIP的推理过程

在预训练之后,CLIP模型只能得到图像和文本的特征,而没有分类头。为了进行分类,作者提出了一种利用自然语言的方法,即"prompt template"。例如,对于ImageNet的类别,可以将其转化为类似"A photo of a {object}"这样的句子,对于ImageNet的1000个类别,就可以生成1000个这样的句子。然后,通过之前预训练好的文本编码器,可以得到这1000个句子对应的文本特征。虽然也可以直接使用类别单词提取文本特征,但在预训练阶段,图像与文本的配对是以句子形式出现的,因此在推理阶段使用单词效果会下降。推理时,将需要分类的图像送入图像编码器以获取特征,然后计算图像特征与1000个文本特征的余弦相似度,选择最相似的文本特征对应的句子,从而完成分类任务。CLIP模型不仅局限于这1000个类别,任何类别都可以进行分类,因此彻底摆脱了分类标签的限制,无需在训练和推理阶段提前定义好标签列表。

(3)CLIP的损失函数

CLIP的损失函数使用了对称的损失函数,其中包括图像编码器、文本编码器、学习的投影矩阵、以及温度参数。具体步骤包括提取各模态的特征表示,计算它们之间的余弦相似度,然后应用交叉熵损失函数计算图像和文本的损失。通过将两个损失的平均值作为最终损失,得到了模型的整体损失。

# Image encoder - ResNet or Vision Transformer
# Text encoder - CBOW or Text Transformer
# Input: minibatch of aligned images I[n, h, w, c] and minibatch of aligned text T[n, 1]
# Parameters: W_i[d_i, d_e] learned projection of image to embed, W_t[d_t, d_e] learned projection of text to embed, t learned temperature parameter
# Extract feature representations of each modality
I_f = image_encoder(I)  # [n, d_i]
T_f = text_encoder(T)   # [n, d_t]
# Joint multimodal embedding [n, d_e]
I_e = 1/2 * normalize(np.dot(I_f, W_i), axis=1)
T_e = 1/2 * normalize(np.dot(T_f, W_t), axis=1)
# Scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# Symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t) / 2

模型接收两个输入,一个是图片,一个是文本。图片的维度为[n,h,w,c],文本的维度为[n,l],其中l是序列长度。这些输入分别通过图片编码器和文本编码器提取特征,然后经过一个投射层学习如何从单模态变为多模态。投射完成后,对特征进行l2范数归一化,得到最终用于对比的特征。接下来,计算余弦相似度得到对比学习的logits。最后,使用对称的损失函数计算loss,其中正样本为对角线上的元素。损失函数包括图片损失和文本损失,将两者加起来并取平均。这种操作在对比学习中很常见,是一种对称的目标函数。

三.API代码实现

论文地址:https://arxiv.org/pdf/2103.00020.pdf

代码地址:https://github.com/openai/CLIP

文末可快速免费获取论文和代码~~~

import torch
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
with torch.no_grad():image_features = model.encode_image(image)text_features = model.encode_text(text) logits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

四.论文实验总结

在CLIP预训练完成后,系统具备两个编码器:一个用于图像,一个用于文本。在推理过程中,给定一张图片,通过图像编码器可得到该图片的特征。对于文本方面的输入,则包括用户感兴趣的标签,例如"plane"、"car"、"dog"等。这些标签经过prompt工程处理,转换成对应的句子,如"A photo of a plane"、"A photo of a dog"等。一旦得到这些句子,它们就会被送入文本编码器,以获取相应的文本特征。假设有三个标签,分别是"plane"、"car"、"dog",那么通过文本编码器得到对应的文本特征。接下来,将这三个文本特征与图片的特征进行余弦相似度计算,得到相似度后再经过softmax处理,得到一个概率分布。其中概率最大的那个句子,就是最可能描述这张图片的句子。

👇👇👇

免费领取方式

在下方公众号内回复关键词:CLIP

如果你想要进一步了解更多的相关知识,可以关注下面公众号联系~会不定期发布相关设计内容包括但不限于如下内容:信号处理、通信仿真、算法设计、matlab appdesigner,gui设计、simulink仿真......希望能帮到你!

5a8015ddde1e41418a38e958eb12ecbd.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/812214.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

两个链表的交集(力扣349)

题目如下: 给定两个数组 nums1 和 nums2 ,返回 它们的 交集 。输出结果中的每个元素一定是 唯一 的。我们可以 不考虑输出结果的顺序 。 示例 1: 输入:nums1 [1,2,2,1], nums2 [2,2] 输出:[2]示例 2:…

go 利用channel实现定时任务

package mainimport ("fmt""net/http""time" )func main() {// 创建一个定时器,每隔1秒钟执行一次ticker : time.NewTicker(1 * time.Second)done : make(chan bool)//设置3s超时,避免请求时间过长client : http.Client{T…

Postgresql获取指定时间前的时间

1、获取指定时间前12小时数据 SELECT* FROMdispatch_team_real WHEREto_timestamp( start_time, YYYY-MM-DD HH24:MI:SS ) ( to_timestamp( 2023-09-17 06:00:00, YYYY-MM-DD HH24:MI:SS ) - INTERVAL 12 HOUR ) AND to_timestamp( end_time, YYYY-MM-DD HH24:MI:SS ) ( t…

机器学习和深度学习 -- 李宏毅(笔记与个人理解)Day 13

Day13 Error surface is rugged…… Tips for training :Adaptive Learning Rate critical point is not the difficult Root mean Square --used in Adagrad 这里为啥是前面的g的和而不是直接只除以当前呢? 这种方法的目的是防止学习率在训练过程中快速衰减。如果只用当前的…

自然语言处理NLP关键知识点

大家好,在人工智能出现之前,机器智能处理结构化的数据,例如 Excel 里的数据。但是网络中大部分的数据都是非结构化的,例如文章、图片、音频、视频等。在非结构数据中,文本的数量是最多的,他虽然没有图片和视…

信息系统项目管理师——第27章管理科学基础知识

1 最大流量问题[简单] 百度百科:最大流问题,一种组合最优化问题,就是要讨论如何充分利用装置的能力,使得运输的流量最大,以取得最好的效果。 教材P869:在起点和终点之间可能存在多条运输路径,总的最大流量就是求出各…

智能EDM邮件营销推广工具哪个好?

有效且精准的客户沟通已经成为企业成功的关键要素之一,云衔科技以其尖端的智能EDM邮件营销系统解决方案脱颖而出,为全球各行业的企业提供了一个强有力的竞争优势和业绩增长引擎。 云衔科技深谙市场营销的艺术与科学,凭借多年积累的专业技术研…

C#:判断一个数是不是水仙花数

任务描述 本关任务:编写一个程序,判断从键盘输入的数是不是水仙花数。 水仙花数是指一个3位数字,它各位数字的3次幂之和等于它本身。如153是一个水仙花数,因为: 1531 3 5 3 3 3 相关知识 为了完成本关任务&am…

SPI 机制

一、简述 本文介绍 SPI 机制。 二、什么是 SPI 机制 SPI(Service Provider Interface)机制是 Java 编程语言中的一种机制,用于实现组件之间的解耦和扩展。SPI 允许开发者编写服务接口(Service Interface)&#xff0…

Python基础教程

随着科技的快速发展,编程已成为一项重要的技能。在众多编程语言中,Python因其简洁、易读、强大的功能库而备受青睐。无论你是编程新手,还是希望了解Python的开发者,本文都将为你提供一个Python基础教程,带你走进Python…

计算机网络 路由器基本配置

一、实验内容 1、按照下表配置好PC机IP地址和路由器端口IP地址 2、配置好路由器特权密文密码“abcd+两位班内序号”和远程登录密码“star” 3、验证测试 a.验证各个接口的IP地址是否正确配置和开启 b.PC1 和 PC2 互ping c.验证PC1通过远程登陆到路由器上&#…

目前深圳嵌入式单片机就业环境如何?

深圳作为中国的科技创新中心之一,嵌入式行业的就业环境相对较好。我这里有一套嵌入式入门教程,不仅包含了详细的视频讲解,项目实战。如果你渴望学习嵌入式,不妨点个关注,给个评论222,私信22,我在…

docker 上达梦导入dump文件报错:本地编码:PG GBK,导入女件编码:PGGB18030

解决方案: 第一步进入达梦数据容器内部 docker exec -it fc316f88caff /bin/bash 第二步:在容器中 /opt/dmdbms/bin目录下 执行命令 cd /opt/dmdbms/bin./dimp USERIDSYSDBA/SYSDBA001 FILE/opt/dmdbms/ZFJG_LJ20240407.dmp SCHEMASZFJG_LJUSERIDSYSD…

Lua语法(三)——元表与元方法

参考链接: 系列链接: Lua语法(一) 系列链接: Lua语法(二)——闭包/日期和时间 系列链接: Lua语法(三)——元表与元方法 系列链接: Lua语法(四)——协程 系列链接: Lua语法(五)——垃圾回收 系列链接: Lua语法(六)——面相对象编程 元表与元方法目录 简介正文元表元方法表相关常…

linux安装

1、解压vm ware压缩包 2双击安装 3点击自定义硬件 4双击cd/dvd,给虚拟光驱里放虚拟光盘 5记得启动时链接勾上,勾上起点系统时 虚拟光驱才会一起启动 6点击确认即可! 开机 选择第一个 7进入图形化安装界面 8设置时区 9选择硬盘 10网络配置 开启以太网&am…

C语言进阶课程学习记录-数组指针和指针数组分析

C语言进阶课程学习记录-数组指针和指针数组分析 实验-数组指针的大小实验-指针数组小结 本文学习自狄泰软件学院 唐佐林老师的 C语言进阶课程&#xff0c;图片全部来源于课程PPT&#xff0c;仅用于个人学习记录 实验-数组指针的大小 #include <stdio.h>typedef int(AINT…

简述Java中synchronized关键字的底层工作原理

在Java中&#xff0c;synchronized 关键字是一个重要的同步机制&#xff0c;用于控制多线程对共享资源的访问&#xff0c;以防止并发问题。了解 synchronized 的底层工作原理&#xff0c;可以帮助我们更好地编写线程安全的代码。synchronized 关键字可以应用于方法或者代码块&a…

【MoS2】应变增强的单层MoS2光电探测器

这篇文章的标题是《Strain-Enhanced Large-Area Monolayer MoS2 Photodetectors》&#xff0c;作者是Borna Radatovic等人&#xff0c;发表在《ACS Applied Materials & Interfaces》期刊的2024年第16卷。文章主要研究了应变增强的大面积单层MoS2光电探测器的性能和应用潜力…

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题1

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题1 信息安全管理与评估 网络系统管理 网络搭建与应用 云计算 软件测试 移动应用开发 任务书&#xff0c;赛题&#xff0c;解析等资料&#xff0c;知识点培训服务 添加博主wx&#xff1a;liuliu548…

Testng测试框架(3)-数据驱动TestNG@DataProvider

TestNG 是一个强大的 Java 测试框架&#xff0c;它提供了许多高级功能&#xff0c;如参数化测试、依赖注入、分组等。其中&#xff0c;DataProvider 是 TestNG 中一个非常有用的注解&#xff0c;用于为测试方法提供数据。 DataProvider 的作用 使用 DataProvider 注解的方法可…