自定义数据 微调CLIP (结合paper)

CLIP 是 Contrastive Language-Image Pre-training 的缩写,是一个擅长理解文本和图像之间关系的模型,下面是一个简单的介绍:

优点: CLIP 在零样本学习方面特别强大,它可以(用自然语言)给出图像的描述,并在基于该描述对新图像进行分类方面表现良好,例如,您可以将图像描述为“a”。猫的黑白照片”,CLIP 可以准确地对猫的新照片进行分类,即使它以前没有见过这些特定图像。
训练: CLIP 在从互联网收集的大量文本图像对数据集上进行训练,这使得它能够学习视觉概念及其描述之间的联系。
局限性: CLIP 也有缺点,训练的计算成本可能很高,并且在需要非常具体或抽象概念的任务上,或者对于与训练所用的文本描述非常不同的数据时,可能表现不佳。训练可能会将社会偏见引入模型中。

paper:Learning Transferable Visual Models From Natural Language Supervision

本文用CLIP做一个零样本分类,
CLIP训练的时候用的是图片和文本描述对,并没有分类的标签,那如何让CLIP做零样本分类?
我们需要给出标签的文本,让图像和所有的文本标签进行匹配,得分高的就是匹配到的标签文本。

paper中提到预测哪个文本整体与哪个图像配对,而不是该文本的准确单词。

在这里插入图片描述

下面通过一个kaggle数据集来具体说明。

这里选用indo fashion dataset, 它有15种印度服饰。

在这里插入图片描述
类别如下:
在这里插入图片描述

数据集结构:
其中images文件夹下又有train, val, test文件夹。

在这里插入图片描述

再看一下json文件,
image_path指的是上面images文件夹下的路径,
product_title是和图片对应的文本描述,训练的时候就是用图片和这个文本进行匹配。
class_label训练的时候不需要,最后验证分类是否正确时会用到。

在这里插入图片描述

import需要的库,定义数据集的文件夹,读取json数据

import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor,CLIPModel
from tqdm import tqdmjson_path = 'your_path/train_data.json'
image_path = 'your_path/images/train/'input_data = []
with open(json_path, 'r') as f:for line in f:obj = json.loads(line)input_data.append(obj)

CLIP模型,如果不能download, 手动下载走offline模式。

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Setting our device to GPU (Cuda) and loading the pre-trained CLIP model.device = "cuda:0" if torch.cuda.is_available() else "cpu" model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

定义Dataloader

# Define a custom dataset
class image_title_dataset():def __init__(self, list_image_path, list_txt):self.image_path = list_image_path# Tokenize text using CLIP's tokenizerself.title = clip.tokenize(list_txt)def __len__(self):# Define the length of the datasetreturn len(self.title)def __getitem__(self, idx):image = preprocess(Image.open(self.image_path[idx]))title = self.title[idx]return image, title

这里的dataset需要传入list_image_path和list_txt,
格式是这种:
list_image_path = [‘folder/image1.jpg’,‘folder2/image2.jpg’]
list_txt = [‘description for image1.jpg’ , ‘description for image2.jpg’]
所以要把image_path和product_title都装进list里面。

注意,CLIP的最大序列长度限制在76, 而有些文本描述非常长,需要截掉一部分,
当然截到76长度也有很多种方法,这里简单粗暴就从开头取长度76.

实际代码中,indo数据集不限制长度会报错,而博主觉得这个76可能是text被tokenize之后的token的长度,而不是原文本的长度,
因为把文本截到长度>77也是可以的。
而token的长度是由tokenize的算法决定的。具体最大极限文本长度是多少没测,这里简单地截取到77.

在这里插入图片描述

list_image_path = []
list_txt = []
for item in input_data:img_path = image_path + item['image_path'].split('/')[-1]caption = item['product_title'][:77]list_image_path.append(img_path)list_txt.append(caption)dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=100, shuffle=True) # Function to convert model's parameters to FP32 format
#转精度省内存.
def convert_models_to_fp32(model): for p in model.parameters(): p.data = p.data.float() p.grad.data = p.grad.data.float() if device == "cpu":model.float()  # Convert the model's parameters to float if using CPU

optimizer用Adam,参数按paper中的设置.
不过博主的机器容纳不了这么大的batch_size, 具体batch_size设多少合适,需要自己去验证。

在这里插入图片描述
由于数据集比较小,lr设得更小一些。

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6 ,weight_decay=0.2) 

训练

paper中的训练是这样的
在这里插入图片描述

    for epoch in range(num_epochs):pbar = tqdm(train_dataloader, total=len(train_dataloader))for batch in pbar:optimizer.zero_grad()images, texts = batchimages = images.to(device)texts = texts.to(device)logits_per_image, logits_per_text = model(images, texts)ground_truth = torch.arange(len(images), dtype=torch.long, device=device)total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2total_loss.backward()if device == "cpu":optimizer.step()else:convert_models_to_fp32(model)optimizer.step()clip.model.convert_weights(model)pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")if torch.isnan(total_loss).any():print("epoch {} loss is NaN".format(epoch))epoch = num_epochsbreak

训练中,遇到了这些问题:
loss出现了NaN, 调整batch_size能解决,batch_size不要太小。
loss降不下去了,看看paper中的参数,有哪些需要调整。

训练完之后,找来一张图片测试。
这里又有一些注意事项,
请看paper.
因为训练的时候是图片和一段文本描述匹配的,而不是图片和一个单词。
所以你做零样本分类时,类别文本最好不要只写一个单词,比如只写"Saree"。
你要写"A photo of Saree", 这就成了一个句子,效果就会好一些。

在这里插入图片描述

model, preprocess = clip.load("ViT-B/32", device=device)checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])clothing_items = ["Saree","Lehenga","Women Kurta","Dupatta","Gown","Nehru Jacket","Sherwani","Men Kurta","Men Mojari","Leggings and Salwar","Blouse","Palazzo","Dhoti Pants","Petticoat","Women Mojari"
]

这里你可能要问,那json文件里面的标签不是这么写的,比如"Women Kurta",json文件的标签是"women_kurta",
为什么不写成"women_kurta"。
这个博主是测试过的,写成json文件里面的标签形式准确率会降低,可能是因为"Women Kurta"更接近自然语言,更贴合训练数据吧。

把15个类别的标签都写成"A photo of {label}" 进行测试。

#你想测的第几张图片
index_ = 500
image_json = input_data[index_]
image_path = os.path.join("indo-fashion-dataset", image_json['image_path'])
image_class = image_json['class_label']
image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in clothing_items]).to(device)with torch.no_grad():# Encode image and textimage_features = model.encode_image(image)text_features = model.encode_text(text)# Calculate similarity scores between image and textlogits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()# Normalize image and text features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)# Calculate similarity scores
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)# Print the top predictions
print("\nTop predictions:\n")
for value, index in zip(values, indices):print(f"{clothing_items[index]:>16s}: {100 * value.item():.2f}%")# Display the image with its class label
plt.imshow(plt.imread(image_path))
plt.title(f"Image for class: {image_class}")
plt.axis('off')
plt.show()

请添加图片描述
请添加图片描述

训练中并没有精调参数,也没有训练很多epoch. 效果如下。
统计了一下测试集中7450张图片的top1和top3准确率。
top1: 77.7%, top3: 93.57%

请添加图片描述

paper中说CLIP 模型的 Top-5 准确率明显高于其 Top-1 准确率, 本文虽测的是top3, 但也是明显高于top1的。

在这里插入图片描述

又试了一下这种方法,这里效果并没有变好。

在这里插入图片描述

参考资料1
参考资料2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/2185.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】行人跌倒行为检测软件系统

行人跌倒检测系统在各个领域的应用都对社会的整体健康、安全和福祉产生积极影响,为人们的生活和工作提供了更加安全和可靠的环境, 本文主要使用YOLOV8深度学习框架自训练了一个“行人跌倒检测模型”,基于此模型使用PYQT5实现了一款界面软件用…

Visual Studio2022中使用水晶报表

1.创建水晶报表项目 选择需要的表 自动生成连接 选项:可跳过 后续还有一些 都能跳过 看你自己的需求 自己选的样式

39. 【Android教程】触摸事件分发

用户在使用 Andriod 系统的时候会不断的和我们的 App 进行各种类型的交互(类似点击、滑动等等),“事件”就是一个非常有效的用来收集用户行为的方式。在前面章节有提到过:Android 系统采用一个先进先出(FIFO&#xff0…

PostgreSQL 免费的对象-关系数据库

目录 一、什么是数据库 二、ORDBMS 的一些术语 三、PostgreSQL 概述 四、PostgreSQL数据库优点和缺点 4.1PostgreSQL数据库的优点 4.2PostgreSQL数据库的缺点 4.3PostgreSQL 特征 五、Linux 上安装 PostgreSQL 5.1Yum 安装 PostgreSQL 5.1.1安装postgreSQL的官方yum仓…

54、图论-实现Trie前缀树

思路: 主要是构建一个trie前缀树结构。如果构建呢?看题意,应该当前节点对象下有几个属性: 1、next节点数组 2、是否为结尾 3、当前值 代码如下: class Trie {class Node {boolean end;Node[] nexts;public Node(…

如何在PostgreSQL中使用索引覆盖扫描提高查询性能?

文章目录 解决方案1. 创建合适的索引2. 确保查询能够使用索引覆盖扫描3. 调整查询以利用索引覆盖扫描4. 监控和调优 示例代码1. 创建索引2. 编写查询3. 检查是否使用索引覆盖扫描4. 调整索引 总结 在PostgreSQL中,索引是提高查询性能的关键工具之一。索引允许数据库…

文章生成器免费版有哪些,哪个好用?

作为一个长期需要写作的人,对文章生成器自然是非常了解,如果搜文章生成器互联网上多到让人应接不暇,但小编今天要谈的是文章生成器免费版,因为看到很多写手朋友都想找一个免费的文章生成器来用,但是大家在网上搜可能很…

GITHUB的VB代码无法加载的问题解决

GITHUB里有不少好的VB代码,但是下载之后,经常出现工程加载出错的问题,例如: LOG文件为: 不能加载 0 行 0: 不能加载文件 D:\xxxx\Semi VB API Loader\frmMain.frm 。 原因其实很简单,github里的换行符是u…

Promise.all 的方法还没执行完就执行了.then

碰见一个问题,接盘了一个有问题的页面修改。 改变日期后 查询很多数据再去重新加载页面上的数据显示相关的组件。 问题就来了。 加载异常捏…… 最后我一通查: 重点来了 是因为这个Promise.all(数组),里边这个数组的问题。现在是在数据中…

【机器学习】分类与预测算法的评价与优化

以实际案例解析F1值与P-R曲线的应用 一、分类算法与性能评价的重要性二、F1值与P-R曲线的概念与意义三、实例解析:以垃圾邮件检测为例四、代码实现与结果分析五、结论与展望 在数据驱动的时代,机器学习算法以其强大的数据处理和分析能力,成为…

Linux - tar (tape archive)

tar 的全称是 Tape Archive。它最初是在 Unix 系统中用于将数据写入磁带的工具,但现在它通常用于创建、维护、修改和提取文件的归档文件。尽管 tar 可以用于压缩和解压缩文件,但它本身并不进行压缩,而是通常与 gzip 或 bzip2 等压缩工具一起使…

【圆桌论坛】个人作为嘉宾参与问答环节的总结,Create 2024百度AI开发者大会之AI智能体开发与应用论坛

目录 ⭐前言⭐讨论话题✨本质和价值✨端侧部署✨应用商业模式✨商业模式 ⭐主题总结⭐有趣分享 ⭐前言 首先,非常荣幸和开心作为开发者和创业者代表参加百度Create AI大会分论坛圆桌论坛的问答环节。 在分论坛活动开始前,参加了文心智能体平台&#xff…

vi编辑器的用法linux中的vim编辑器大全

vim的介绍 vi 和 vim 命令是linux中强⼤的⽂本编辑器, 由于Linux系统⼀切皆⽂件,⽽配置⼀个服务就是在修改其配置⽂件的参数。 vim 编辑器是运维⼯程师必须掌握的⼀个⼯具, 没有它很多⼯作都⽆法完成。 其中有vi和vim两种 vi和vim的区别 Vim是Vi的升级版本&#…

【QT学习】9.绘图,三种贴图,贴图的转换

一。绘图的解释 Qt 中提供了强大的 2D 绘图系统,可以使用相同的 API 在屏幕和绘图设备上进行绘制,它主要基于QPainter、QPaintDevice 和 QPaintEngine 这三个类。 QPainter 用于执行绘图操作,其提供的 API 在 GUI 或 QImage、QOpenGLPaintDev…

【Linux】学习记录_14_线程

14 线程 14.1 线程和进程 进程是资源管理的最小单位,每个进程都有数据段、代码段和堆栈段,进程切换时都有复杂的上下文切换等动作。进程切换上下文时, 需要重新映射虚拟地址空间、进出OS内核、寄存器切换,还会干扰处理器的缓存机…

11408知识点集合

文章目录 一、数学(一) 高数0.初等数学补充1.函数、极限、连续2.导数3.中值定理4.积分5.微分方程6.空间解析几何7.多元微分8.重积分9.曲线曲面积分10.无穷级数11.其他杂记(二) 线代0.串联各章的等价条件1.行列式、矩阵的秩、矩阵的初等变换2.向量3.方程组、矩阵方程AXB4.特征值…

第G8周:ACGAN任务

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 参考论文 这周主要任务就是根据之前GAN,CGAN,SGAN网络架构搭建…

python_django中小学家校互动系统vue_flask家校联系

实现了一个完整的家校互动系统,其中主要有作业信息模块、学校管理员模块、学生学籍模块、学生成绩模块、学科模块、系统新闻模块、系统公告模块、校内新闻模块、校内公告模块、用户表模块、token表模块、关于我们模块、收藏表模块、年级模块、家长模块、教师模块、互…

Spark Standalone模式部署

准备至少2台虚拟机,装好linux系统,我装的是Ubuntu20.04。 1.修改主机名(每台) 1)修改/etc/hostsname内容,主节点改为master,子节点改为slaver1 sudo vim /etc/hostname 2)在/etc/…

如何通过外发文件控制,保障企业对核心业务数据的控制力?

外发文件控制是企业数据安全管理的重要组成部分,它涉及到对从企业内网向外发送的文件进行严格控制和管理,以防止敏感或机密信息的泄露。以下是常见的一些外发手段及问题: (1)IM通讯工具 如微信、QQ、企业微信、钉钉、…