从零开始搭建CLIP模型实现基于文本的图像检索

目录

  • CLIP原理简介
  • 代码实现
  • 参考链接

CLIP原理简介

论文链接,源码链接

CLIP模型由OpenAI在2021年提出,利用双Decoder(Dual Encoder)的架构来学习图像和文本之间的对应关系,是多模态大模型的开创之作,为后续许多高效的多模态模型的提出打下基础。CLIP是一个预训练模型(Pre-trained Model),在学习到图像–文本特征之间的关联后可以迁移到各种下游任务中,如图像分类,文本引导图像分割和目标检测,图像文本检索等。由于模型学习到的是文本语义和图像语义之间的关联,使得其zero-shot能力非常强大,根据论文中的描述,CLIP在很多数据集上zero-shot的结果甚至超越了许多训练好的模型的效果。CLIP的训练范式如下:

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/1d112d364a60434bba8dd07d42d2a1c6.png

CLIP的结构非常简单,数据集包含大量的图像文本对,图像经过图像编码器得到图像特征,文本经过文本编码器得到文本特征,将图像特征和文本特征按照数据集中的对应关系进行配对,不配对的特征给予惩罚,从上图中可以看出,我们希望矩阵中蓝色的值趋近于1,其余值趋近于0,采用对比学习的方式对模型进行训练,算法的伪代码如下:

在这里插入图片描述
从损失函数中可以看出,分别对特征对比矩阵的行和列进行交叉熵损失函数计算,并取平均得到最终的loss。图像编码器一般有两种选择:ResNet和ViT;文本编码器采用Transformer Encoder,均是各自领域中优秀的特征提取网络。
CLIP的推理范式如下:

在这里插入图片描述
在推理阶段,图像编码器中输入图像获取图像特征,文本编码器中输入文本获取文本特征,将图像特征向量和文本特征向量的转置相乘得到每张图像对每个文本的特征相似度,相似度最高的文本即描述了该图像中物体所属的类别。

代码实现

Flickr8k数据集下载,提取码:fbfz
DistilBert模型文件下载

我的运行环境:
CUDA 11.8
pytorch 2.2.2
transformers 4.44.0 # 用于从HuggingFace上加载预训练模型


数据集预览:
图片示例

图片示例

在这里插入图片描述

文本示例

由于作者的显卡算力有限,选取Flickr8k数据集进行模型训练,其中包含8k个图像文本对,其中一张图像对应5条文本。图像编码器采用ResNet50,直接从timm库中导入;文本编码器采用DistilBert,即轻量化的Bert模型,从HuggingFace上下载。闲话少说,小二,上菜!

### 模型参数配置 ###
import argparse
from dataclasses import dataclassparser = argparse.ArgumentParser(description="CLIP from zero")
parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder')  # 存放图像的文件路径
parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder')  # 存放文本的文件路径
parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight')  # 存放训练权重的文件路径
args = parser.parse_args()@dataclass
class CLIPConfig:image_path: str = args.image_dir  # 图像存放路径image_size: int = 224  # resize后的图像尺寸,便于构建Dataloadercaption_path: str = args.caption_dir  # 文本存放路径batch_size: int = 8  # 一个批次中的数据数量epochs: int = 3  # 训练世代image_encoder_model: str = "resnet50"  # 图像编码器的名称image_embedding_dim: int = 2048  # 图像特征的维度text_encoder_model: str = "distilbert-base-uncased"  # 文本编码器的名称text_embedding_dim: int = 768  # 文本特征的维度text_tokenizer: str = text_encoder_model  # 文本分词器模型的名称max_length: int = 200  # 文本编码器可输入的最长文本长度pretrained: bool = False  # 是否加载预训练好的编码器trainable: bool = True  # 在训练过程中是否更新编码器的参数temperature: float = 1.0  # 计算loss时的正则化系数proj_dim: int = 256  # 图像特征和文本特征统一后的维度dropout_rate: float = 0.1  # dropout系数,避免过拟合### 载入数据集并初始化 ###
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import pandas as pd
import cv2class CLIPDataset(Dataset):def __init__(self, config, image_path, caption_path, transforms=True):"""图片文件名和标题的长度必须相同如果一个图片对应多个标题,该图片文件名需要重复多次"""self.image_path = image_path  # 图像路径self.caption_path = caption_path  # 文本路径self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv")  # 读取文本self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)  # 载入分词器self.image_filenames = self.dataframe["image"].values  # 获取图像文件名self.captions = list(self.dataframe["caption"].values)   # 获取图像对应的描述文本self.encoded_captions = self.tokenizer(self.captions, padding=True, truncation=True, max_length=config.max_length)  # 文本分词self.transforms = transforms  # 对输入图像进行预处理def __getitem__(self, idx):  # 获取数据集中第idx个数据,其中包含图片名称和对应的标题(可能不止一个)item = {key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()}image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}")  # 获取原始图像image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transforms:image = self.get_transforms(mode="train")(image=image)["image"]  # 对图像进行预处理item["image"] = torch.tensor(image).permute(2, 0, 1).float()  # 将图片转换为tensor格式,并调整为RGB顺序item["caption"] = self.captions[idx]  # 获取标题return itemdef __len__(self):return len(self.captions)  # 获取文本长度def get_transforms(self, mode="train"):if mode == "train":return A.Compose([A.Resize(config.image_size, config.image_size, always_apply=True),  # 对图像进行resizeA.Normalize(max_pixel_value=255.0, always_apply=True)  # 对像素值进行归一化])### 图像编码器 ###
import torch.nn as nn
import timmclass ImageEncoder(nn.Module):"""图像编码器,采用ResNet50"""def __init__(self, config):super().__init__()self.model = timm.create_model(config.image_encoder_model, pretrained=config.pretrained, num_classes=0, global_pool="avg")  # 创建ResNet50for p in self.model.parameters():p.requires_grad = config.trainable  # 设置参数可训练def forward(self, x):image_encoded = self.model(x)  # 获得图像特征编码,形状为[batch_size, image_embedding_dim]return image_encoded### 文本编码器 ###
class TextEncoder(nn.Module):"""文本编码器,采用DistilBERT"""def __init__(self, config):super().__init__()if config.pretrained:self.model = DistilBertModel.from_pretrained(config.text_encoder_model)  # 导入下载好的模型文件else:self.model = DistilBertModel(DistilBertConfig())for p in self.model.parameters():p.requires_grad = config.trainable  # 设置参数可训练self.target_token_idx = 0# 提取出和图像对应的文本特征向量def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)text_encoded = output.last_hidden_state[:, self.target_token_idx, :]  # [batch_size, text_embedding_dim]return text_encoded### 投影层 (MLP) ###
class ProjectionHead(nn.Module):"""将图像编码和文本编码映射到相同维度"""def __init__(self, config, input_embedding_dim):super().__init__()self.proj = nn.Linear(input_embedding_dim, config.proj_dim)self.act_fn = nn.GELU()self.fc = nn.Linear(config.proj_dim, config.proj_dim)self.dropout = nn.Dropout(config.dropout_rate)self.layer_norm = nn.LayerNorm(config.proj_dim)def forward(self, x):x_proj = self.proj(x)x = self.act_fn(x_proj)x = self.fc(x)x = self.dropout(x)x = x + x_projx = self.layer_norm(x)return x### 定义损失函数 ###
def cross_entropy(logits, labels, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-labels * log_softmax(logits)).sum(dim=1)if reduction == 'mean':return loss.mean()else:return loss.sum()### 模型主体 ###
import torch.nn.functional as Fclass CLIP(nn.Module):def __init__(self, config):super().__init__()self.image_encoder = ImageEncoder(config)  # 实例化图像编码器self.text_encoder = TextEncoder(config)  # 实例化文本编码器self.image_proj = ProjectionHead(config, config.image_embedding_dim)  # 图像特征投影self.text_proj = ProjectionHead(config, config.text_embedding_dim)  # 文本特征投影self.temperature = config.temperaturedef forward(self, batch):image_features = self.image_encoder(batch["image"])  # 图像编码# 文本编码,tokenizer处理后的文本序列自带input_ids和attention_masktext_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])image_embeddings = self.image_proj(image_features)  # 图像特征投影text_embeddings = self.text_proj(text_features)  # 文本特征投影logits = (text_embeddings @ image_embeddings.T) / self.temperature  # tensor形状为[batch_size, batch_size]images_similarity = image_embeddings @ image_embeddings.T  # tensor形状为[batch_size, batch_size]text_similarity = text_embeddings @ text_embeddings.T  # tensor形状为[batch_size, batch_size]# 软标签,不配对的位置设置为较小的数,而非0labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1)  loss_T = cross_entropy(logits, labels)  # 计算文本损失loss_I = cross_entropy(logits.T, labels.T)  # 计算图像损失total_loss = (loss_T + loss_I) / 2  # 对比学习平均损失return total_loss, logits### 训练函数 ###
def train(model, optimizer, scheduler, train_loader, device):model.train()  # 模型设置为训练模式train_loss = 0train_loader = tqdm(train_loader, total=len(train_loader))  # 显示训练进度条cnt = 0for batch in train_loader:# print(batch.keys())cnt += 1batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}  # 将dataloader中一个batch的数据转换为字典形式loss, _ = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step(metrics=loss.item())  # 根据上次训练的损失更新学习率train_loss += loss.item()# 训练100个batch显示一次lossif cnt % 100 == 0:print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')return train_loss / len(train_loader)  # 平均训练损失### 测试函数 ###
def eval(model, val_loader, device):model.eval()  # 模型设置为测试模式val_loss = 0val_loader = tqdm(val_loader, total=len(val_loader))with torch.no_grad():for batch in val_loader:batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}loss, _ = model(batch)val_loss += loss.item()return val_loss / len(val_loader)  # 平均测试损失if __name__ == '__main__':config = CLIPConfig()  # 实例化配置信息model = CLIP(config)  # 实例化CLIP模型device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 查看模型的总参数量total_params = sum(p.numel() for p in model.parameters())print(f"Total parameters: {total_params / 1e6} M")optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)dataset = CLIPDataset(config, args.image_dir, args.caption_dir)  # 读取并预处理数据train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])  # 80%为训练数据,20%为测试数据dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)# 开始训练best_loss = float("inf")for epoch in range(config.epochs):print(f"Epoch: {epoch + 1}")train_loss_avg = train(model, optimizer, scheduler, train_loader, device)val_loss_avg = eval(model, val_loader, device)if val_loss_avg < best_loss:best_loss = val_loss_avgtorch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')print("Best model saved!")# 图像文本检索推理并可视化# dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")# tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)# model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))# model.eval()# # image_embeddings = []# with torch.no_grad():#     for batch in tqdm(dataloader):#         image_features = model.image_encoder(batch["image"].to(device))  # 获取图像特征#         cur_image_embeddings = model.image_proj(image_features)  # [batch_size, proj_dim]  # 图像特征投影#         image_embeddings.append(cur_image_embeddings)  # 将一个batch的图像特征保存# # image_embeddings = torch.cat(image_embeddings, dim=0)  # [image_number, proj_dim]# input_query = "two dogs sitting on the grass"  # 输入文本# image_filenames = dataframe["image"].values  # 待检索的图片# # encoded_query = tokenizer([input_query])  # 对输入文本进行分词# batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}# # with torch.no_grad():#     text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"])  # 获取文本特征#     text_embeddings = model.text_proj(text_features)  # 文本特征投影,与图像特征维度一致# # image_embeddings_n = F.normalize(image_embeddings, dim=-1)  # [image_number, proj_dim]# text_embeddings_n = F.normalize(text_embeddings, dim=-1)  # [1, proj_dim]# dot_similarity = text_embeddings_n @ image_embeddings_n.T  # 输入文本的特征和数据集中每张图像特征之间的相似度# # values, indices = torch.topk(dot_similarity.squeeze(0), k=45)  # 获取前45个相似度最高的图像# matches = [image_filenames[idx] for idx in indices[::5]]  # 获取对应的图像文件名(9张图像)# # f, axes = plt.subplots(3, 3, figsize=(10, 10))# f.suptitle(f"Retrieving text: {input_query}")  # 设置主标题# for match, ax in zip(matches, axes.flatten()):  # 显示检索出的图像#     image = cv2.imread(f"{args.image_dir}/{match}")#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)#     ax.imshow(image)#     ax.axis("off")# # plt.show()

理想结果:

在这里插入图片描述

参考链接

https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/77551.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

熊海cms代码审计

目录 sql注入 1. admin/files/login.php 2. admin/files/columnlist.php 3. admin/files/editcolumn.php 4. admin/files/editlink.php 5. admin/files/editsoft.php 6. admin/files/editwz.php 7. admin/files/linklist.php 8. files/software.php 9. files…

[Java微服务组件]注册中心P3-Nacos中的设计模式1-观察者模式

在P1-简单注册中心实现和P2-Nacos解析中&#xff0c;我们分别实现了简单的注册中心并总结了Nacos的一些设计。 本篇继续看Nacos源码&#xff0c;了解一下Nacos中的设计模式。 目录 Nacos 观察者模式 Observer Pattern观察者模式总结 Nacos 观察者模式 Observer Pattern 模式定…

电脑 访问 github提示 找不到网页,处理方案

1、找到 本机的 host文件 例如 windows 的 一般在 C:\Windows\System32\drivers\etc\hosts 用管理员身份打开 hosts 文件 如果文件中没有 github的配置&#xff0c;需要自己手动添加上去&#xff1b; 如果有&#xff0c;则需要 检查 github.com 与 github.global.ssl.fastly.…

Linux系统中的网络管理

1.RHEL9版本中&#xff0c;使用nm进行网络配置&#xff0c;ifcfg不再是网络配置文件的主存储&#xff0c;样式仍然可用&#xff0c;但它不再是NetworkManger存储新网络配置文件的默认位置&#xff0c;RHEL以key-file格式在etc/NetworkManger/system-connections/中存储新的网络…

AI技术深度解析:从移动芯片到AIoT的全面突破

作为全球无线通信技术和半导体解决方案的重要参与者,高通始终将技术创新作为核心驱动力,在移动通信、物联网(IoT)、汽车电子、AI计算等领域占据关键地位。本文将从其核心产品线、技术突破、应用场景及未来布局四个维度,客观解析高通的技术积累与行业角色。 一、核心产品线…

使用CS Roofline Toolkit测量带宽

使用CS Roofline Toolkit测量带宽 工程下载&#xff1a;使用CS Roofline Toolkit测量带宽-案例工程文件&#xff0c;也可以按照下面的说明使用git clone下载 目录 使用CS Roofline Toolkit测量带宽0、Roofline模型理解1、CS Roofline Toolkit下载1.1、设置代理1.2、git clone下…

EAGLE代码研读+模型复现

要对代码下手了&#xff0c;加油(ง •_•)ง 作者在他们自己的设备上展现了推理的评估结果&#xff0c;受第三方评估认证&#xff0c;EAGLE为目前最快的投机方法&#xff08;虽然加速度是评估投机解码方法的主要指标&#xff0c;但其他点也值得关注。比如PLD和Lookahead无需额…

基于SFC的windows修复程序,修复绝大部分系统损坏

效果:可以自动修复大部分由系统文件损坏而导致的错误 例如:系统应用无法打开 系统窗口(例如开始菜单)无法使用 电脑蓝屏或者卡死.....文章 01技术背景 Windows自带了一个SFC命令行应用程序,可以检查大部分的系统文件错误,以及复这些文件 其中自动检查所有系统文件&#x…

liunx日志问题

一、日志定向 Linux 系统的日志配置文件&#xff08;如/etc/syslog.conf或/etc/rsyslog.conf &#xff09;中&#xff0c;用于定义系统日志的记录规则&#xff0c;决定哪些类型的日志消息会被记录到特定的日志文件中。 *.info;mail.none;authpriv.none;cron.none /va…

2.凸包优化求解

1.减而治之(Decrease and Conquer) 插入排序 典型的减而治之算法就是插入排序方法 插入排序法: 在未排序中选择一个元素&#xff0c;插入到已经排序号的序列中 将凸包也采用减而治之的方法 2.In-Convex-Polygon Test 怎么判断引入的极点存在于多边形里面还是外面&#xff1…

系统思考:危机中的转型机遇

“危机不仅是挑战&#xff0c;更是转型的机会” 每当大事发生&#xff0c;很多企业老板常常被眼前的困境压得喘不过气&#xff0c;焦虑与压力让人难以思考长远。特别是在危机面前&#xff0c;大家忙于应对眼前的风险&#xff0c;却忽略了背后隐藏的机遇。而危机&#xff0c;恰…

大模型Rag - 如何评估Rag

一.RAG流程与评估标准补充 RAG&#xff08;Retrieval-Augmented Generation&#xff09;是一种结合检索与生成的问答架构。为了确保系统效果&#xff0c;需要从以下三个角度对其评估&#xff1a; 回顾RAG流程 用户提出问题 → 系统检索相关上下文 → 基于上下文由大语言模型…

Linux RT RT RT

RT的最终目的是尽可能多的让原来系統不可抢占的部分变成可抢占&#xff0c;让高优先级的程序先跑。这里的rt引入了一个deadline的说法&#xff0c;此时的实时性是保证在最大一个时间间隔内&#xff0c;程序被执行。比如每100ms算法做一次决策。 所以此时面临着几座大山…

演员柳琦正式加入创星演员出道计划,开创演艺事业新天地

4月18日&#xff0c;演员柳琦正式加入“创星演员出道计划”&#xff0c;不仅得到参演都市爱情喜剧《和我结婚吧》角色的机会&#xff0c;还获得文旅精品网剧《醉梦灵州》的出演机会&#xff0c;自此开启全新影视之路。对表演艺术极具天赋的柳琦&#xff0c;相信未来可以凭借自身…

16.Chromium指纹浏览器开发教程之WebGPU指纹定制

WebGPU指纹概述 WebGPU是下一代的Web图形和计算API&#xff0c;旨在提供高性能的图形渲染和计算能力。它是WebGL的后继者&#xff0c;旨在利用现代GPU的强大功能&#xff0c;使得Web应用能够实现接近原生应用的图形和计算性能。而且它是一个低级别的API&#xff0c;可以直接与…

HTTP:九.WEB机器人

概念 Web机器人是能够在无需人类干预的情况下自动进行一系列Web事务处理的软件程序。人们根据这些机器人探查web站点的方式,形象的给它们取了一个饱含特色的名字,比如“爬虫”、“蜘蛛”、“蠕虫”以及“机器人”等!爬虫概述 网络爬虫(英语:web crawler),也叫网络蜘蛛(…

Vue3+TS中svg图标的使用

安装依赖 pnpm i vite-plugin-svg-icons -D配置引入 vite.config.ts ... import { createSvgIconsPlugin } from vite-plugin-svg-icons import path from node:pathconst svgIconsPlugin createSvgIconsPlugin({iconDirs: [path.resolve(process.cwd(), src/assets/icons)]…

【java实现+4种变体完整例子】排序算法中【堆排序】的详细解析,包含基础实现、常见变体的完整代码示例,以及各变体的对比表格

以下是堆排序的详细解析&#xff0c;包含基础实现、常见变体的完整代码示例&#xff0c;以及各变体的对比表格&#xff1a; 一、堆排序基础实现 原理 基于二叉堆结构&#xff08;最大堆&#xff09;&#xff0c;通过以下步骤实现排序&#xff1a; 构建最大堆&#xff1a;将…

论文阅读笔记:Generative Modeling by Estimating Gradients of the Data Distribution

1、参考来源 论文《Generative Modeling by Estimating Gradients of the Data Distribution》 来源&#xff1a;NeurIPS 2019 论文链接&#xff1a;https://arxiv.org/abs/1907.05600 参考链接&#xff1a; 【AI知识分享】真正搞懂扩散模型Score Matching一定要理解的三大核心…

Kubernetes相关的名词解释CNI插件(1)

&#xff08;一&#xff09;什么是CNI插件&#xff1f; 在 Kubernetes 中&#xff0c;CNI 插件&#xff08;Container Network Interface Plugin&#xff09; 是一种用于配置容器网络接口的标准工具&#xff0c;负责为 Pod 分配网络资源&#xff08;如 IP 地址&#xff09;并建…