CLIP在Github上的使用教程

CLIP的github链接:https://github.com/openai/CLIP

CLIP

Blog,Paper,Model Card,Colab
CLIP(对比语言-图像预训练)是一个在各种(图像、文本)对上进行训练的神经网络。可以用自然语言指示它在给定图像的情况下预测最相关的文本片段,而无需直接对任务进行优化,这与 GPT-2 和 3 的零镜头功能类似。我们发现,CLIP 无需使用任何 128 万个原始标注示例,就能在 ImageNet "零拍摄 "上达到原始 ResNet50 的性能,克服了计算机视觉领域的几大挑战。

Usage用法

首先,安装 PyTorch 1.7.1(或更高版本)和 torchvision,以及少量其他依赖项,然后将此 repo 作为 Python 软件包安装。在 CUDA GPU 机器上,完成以下步骤即可:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

将上面的 cudatoolkit=11.0 替换为机器上相应的 CUDA 版本,如果在没有 GPU 的机器上安装,则替换为 cpuonly

import torch
import clip
from PIL import Imagedevice = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)with torch.no_grad():image_features = model.encode_image(image)text_features = model.encode_text(text)logits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

CLIP 模块提供以下方法:

clip.available_models()

返回可用 CLIP 模型的名称。例如下面就是我执行的结果。
在这里插入图片描述

clip.load(name, device=..., jit=False)

返回模型和模型所需的 TorchVision 变换(由 clip.available_models() 返回的模型名称指定)。它将根据需要下载模型。name参数也可以是本地检查点的路径。
可以选择指定运行模型的设备,默认情况下,如果有第一个 CUDA 设备,则使用该设备,否则使用 CPU。当 jitFalse 时,将加载模型的非 JIT 版本。

clip.tokenize(text: Union[str, List[str]], context_length=77)

返回包含给定文本输入的标记化序列的 LongTensor。这可用作模型的输入。

clip.load() 返回的模型支持以下方法:

model.encode_image(image: Tensor)

给定一批图像,返回 CLIP 模型视觉部分编码的图像特征。

model.encode_text(text: Tensor)

给定一批文本标记,返回 CLIP 模型语言部分编码的文本特征。

model(image: Tensor, text: Tensor)

给定一批图像和一批文本标记,返回两个张量,其中包含与每张图像和每个文本输入相对应的 logit 分数。这些值是相应图像和文本特征之间的余弦相似度乘以 100。

More Examples更多实例

Zero-Shot预测

下面的代码使用 CLIP 执行零点预测,如论文附录 B 所示。该示例从 CIFAR-100 数据集中获取一张图片,并预测数据集中 100 个文本标签中最有可能出现的标签。

import os
import clip
import torch
from torchvision.datasets import CIFAR100# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)# Calculate features
with torch.no_grad():image_features = model.encode_image(image_input)text_features = model.encode_text(text_inputs)# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

输出结果如下(具体数字可能因计算设备而略有不同):

Top predictions:snake: 65.31%turtle: 12.29%sweet_pepper: 3.83%lizard: 1.88%crocodile: 1.75%

请注意,本示例使用的 encode_image()encode_text() 方法可返回给定输入的编码特征。

Linear-probe evaluation线性探针评估

下面的示例使用 scikit-learn 对图像特征进行逻辑回归。

import os
import clip
import torchimport numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)def get_features(dataset):all_features = []all_labels = []with torch.no_grad():for images, labels in tqdm(DataLoader(dataset, batch_size=100)):features = model.encode_image(images.to(device))all_features.append(features)all_labels.append(labels)return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

请注意,C 值应通过使用验证分割进行超参数扫描来确定。

See Also

OpenCLIP:包括更大的、独立训练的 CLIP 模型,最高可达 ViT-G/14
Hugging Face implementation of CLIP:更易于与高频生态系统集成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/214235.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙HarmonyOS(ArkTS)语法 声明变量及注意事项

好 今天我们来看一个基础的harmonyOS语法 变量声明 这里 我们还是用 ArkTS项目 我们声明变量的语法并不是ArkTS的 而是 javaScript 和 TypeScript的 可以看一下下面一张图 js是最初弱类型语言 于是TS作为js的副类 是一种更严谨的数据限定语法 而ArkTS 是TS的改良版 其实我们…

算法通关村第十八关 | 白银 | 回溯热门问题

1.组合总和问题 原题&#xff1a;力扣39. 元素可以重复拿取&#xff0c;且题目的测试用例保证了组合数少于 150 个。 class CombinationSum {List<List<Integer>> res new ArrayList<>();List<Integer> path new ArrayList<>();public List…

一篇文章教你快速弄懂 web自动化测试中的三种等待方式

前言 现在的网页很多都是动态加载的&#xff0c;如果页面的内容发生了改变&#xff0c;就需要时间来渲染。在咱们做web自动化测试的时候&#xff0c;由于代码是自动执行的&#xff0c;代码在执行的时候&#xff0c;有可能上一步操作而加载的元素还没加载出来&#xff0c;就会报…

配置本地端口镜像示例(1:1)

本地端口镜像简介 本地端口镜像是指观察端口与监控设备直接相连&#xff0c;观察端口直接将镜像端口复制来的报文转发到与其相连的监控设备进行故障定位和业务监测。 配置注意事项 观察端口专门用于镜像报文的转发&#xff0c;因此不要在上面配置其他业务&#xff0c;防止镜像…

建筑学VR虚拟仿真情景实训教学

首先&#xff0c;建筑学VR虚拟仿真情景实训教学为建筑学专业的学生提供了一个身临其境的学习环境。通过使用VR仿真技术&#xff0c;学生可以在虚拟环境中观察和理解建筑结构、材料、设计以及施工等方面的知识。这种教学方法不仅能帮助学生更直观地理解复杂的建筑理论&#xff0…

记录 | ubuntu源码编译安装/更新boost版本

一、卸载当前的版本 1、查看当前安装的boost版本 dpkg -S /usr/include/boost/version.hpp通过上面的命令&#xff0c;你就可以发现boost的版本了&#xff0c;查看结果可能如下&#xff1a; libboost1.54-dev: /usr/include/boost/version.hpp 2、删除当前安装的boost sudo …

记录 | 使用samba将ubuntu文件夹映射到windows实现共享文件夹

一、ubuntu配置 1. 安装 samba samba 是在 Linux 和 UNIX 系统上实现 SMB 协议的一个免费软件&#xff0c;由服务器及客户端程序构成。SMB&#xff08;Server Messages Block&#xff0c;信息服务块&#xff09;是一种在局域网上共享文件和打印机的一种通信协议。 sudo apt-…

Excel COUNT类函数使用

目录 一. COUNT二. COUNTA三. COUNTBLANK四. COUNTIF五. COUNTIFS 一. COUNT ⏹用于计算指定范围内包含数字的单元格数量。 基本语法 COUNT(value1, [value2], ...)✅统计A2到A7所有数字单元格的数量 ✅统计A2到A7&#xff0c;B2到B7的所有数字单元格的数量 二. COUNTA ⏹计…

大数据分析与应用实验任务十一

大数据分析与应用实验任务十一 实验目的 通过实验掌握spark Streaming相关对象的创建方法&#xff1b; 熟悉spark Streaming对文件流、套接字流和RDD队列流的数据接收处理方法&#xff1b; 熟悉spark Streaming的转换操作&#xff0c;包括无状态和有状态转换。 熟悉spark S…

Linux 驱动开发需要掌握哪些编程语言和技术?

Linux 驱动开发需要掌握哪些编程语言和技术&#xff1f; 在开始前我有一些资料&#xff0c;是我根据自己从业十年经验&#xff0c;熬夜搞了几个通宵&#xff0c;精心整理了一份「Linux从专业入门到高级教程工具包」&#xff0c;点个关注&#xff0c;全部无偿共享给大家&#xf…

1. mycat入门

1、mycat介绍 Mycat 是一个开源的分布式数据库系统&#xff0c;但是由于真正的数据库需要存储引擎&#xff0c;而 Mycat 并没有存 储引擎&#xff0c;所以并不是完全意义的分布式数据库系统。MyCat是目前最流行的基于Java语言编写的数据库中间件&#xff0c;也可以理解为是数据…

鸿蒙HarmonyOS4.0 入门与实战

一、开发准备: 熟悉鸿蒙官网安装DevEco Studio熟悉鸿蒙官网 HarmonyOS应用开发官网 - 华为HarmonyOS打造全场景新服务 应用设计相关资源: 开发相关资源: 例如开发工具 DevEco Studio 的下载 应用发布: 开发文档:

3易懂AI深度学习算法:长短期记忆网络(Long Short-Term Memory, LSTM)生成对抗网络 优化算法进化算法

继续写&#xff1a;https://blog.csdn.net/chenhao0568/article/details/134920391?spm1001.2014.3001.5502 1.https://blog.csdn.net/chenhao0568/article/details/134931993?spm1001.2014.3001.5502 2.https://blog.csdn.net/chenhao0568/article/details/134932800?spm10…

LeetCode 1631. 最小体力消耗路径:广度优先搜索BFS

【LetMeFly】1631.最小体力消耗路径&#xff1a;广度优先搜索BFS 力扣题目链接&#xff1a;https://leetcode.cn/problems/path-with-minimum-effort/ 你准备参加一场远足活动。给你一个二维 rows x columns 的地图 heights &#xff0c;其中 heights[row][col] 表示格子 (ro…

视频如何提取文字?这四个方法一键提取视频文案

视频如何提取文字&#xff1f;你用过哪些视频提取工具&#xff1f;视频转文字工具&#xff0c;又称为语音识别软件&#xff0c;是一款能够将视频中的语音或对话转化为文字的实用工具。它运用了尖端的声音识别和语言理解技术&#xff0c;能精准地捕捉视频中的音频&#xff0c;并…

弧形导轨的工作原理

弧形导轨是一种能够将物体沿着弧形轨道运动的装置&#xff0c;它由个弧形轨道和沿着轨道运动的物体组成&#xff0c;弧形导轨的工作原理是利用轨道的形状和物体的运动方式来实现运动&#xff0c;当物体处于轨道上时&#xff0c;它会受到轨道的引导&#xff0c;从而沿着轨道的弧…

Nginx正则表达式

目录 1.nginx常用的正则表达式 2.location location 大致可以分为三类 location 常用的匹配规则 location 优先级 location 示例说明 优先级总结 3.rewrite rewrite功能 rewrite跳转实现 rewrite执行顺序 语法格式 rewrite示例 实例1&#xff1a; 实例2&#xf…

生活小记录

上个月项目总算上线了&#xff0c;节奏也慢慢调整正常。发现自己好久没有记录生活点滴了&#xff0c;正好写写。其实&#xff0c;最近这段日子发生的事情还是挺多的。 流感 媳妇11.24得流感&#xff0c;这件事情特别好笑&#xff0c;大晚上她和我妹妹想喝酒试试&#xff0c;结…

【Python必做100题】之第六题(求圆的周长)

圆的周长公式&#xff1a;C 2 * pi * r 代码如下&#xff1a; pi 3.14 r float(input("请输入圆的半径&#xff1a;")) c 2 * pi *r print(f"圆的周长为{c}") 运行截图&#xff1a; 总结 1、圆周长的公式&#xff1a;C 2 * pi * r 2、输出结果注意…

webrtc 工具类

直接上代码&#xff1b;webrtc 工具类 package com.example.mqttdome;import android.app.Activity; import android.content.Context; import android.content.Intent; import android.media.projection.MediaProjection; import android.media.projection.MediaProjectionMa…