与AMD GPU上的对比语言-图像预训练(CLIP)模型交互

Interacting with Contrastive Language-Image Pre-Training (CLIP) model on AMD GPU — ROCm Blogs

2024年4月16日,由Sean Song撰写.

引言 

对比语言-图像预训练(CLIP)是一种多模态深度学习模型,连接视觉和自然语言。它在OpenAI的论文“通过自然语言监督学习可转移的视觉模型” (2021) 中被介绍,并在大量(4亿)网页抓取的数据图像-字幕对上进行了对比训练(这是最早进行此类训练的模型之一)。

在预训练阶段,CLIP被训练去预测批次中图像和文本之间的语义关联。这包括确定哪些图像-文本对彼此之间最相关或最密切。这一过程涉及图像编码器和文本编码器的同时训练。其目标是最大化批次中图像和文本对嵌入间的余弦相似度,同时最小化错误对嵌入之间的相似度。通过这种方式,该模型学习到一个多模态的嵌入空间。对这些相似度分数使用对称交叉熵损失进行优化。

图片来源: 通过自然语言监督学习可转移的视觉模型.

In the subsequent sections of the blog, we will leverage the PyTorch framework along 在随后的博客部分,我们将利用*PyTorch*框架与*ROCm*一起运行CLIP模型,以计算任意图像和文本输入之间的相似度。

设置

此演示使用以下设置创建。有关全面的支持详情,请参阅ROCm 文档。

  • 硬件和操作系统:

    • AMD Instinct GPU

    • Ubuntu 22.04.3 LTS

  • 软件:

    • ROCm 5.7.0+

    • Pytorch 2.0+

任意图像和文本输入之间的相似度计算

步骤1:入门

首先,确认GPU的可用性。

!rocm-smi --showproductname
========== ROCm System Management Interface =========================
==================== Product Info ===================================GPU[0]      : Card series:      AMD INSTINCT MI250 (MCM) OAM AC MBAGPU[0]      : Card model:       0x0b0cGPU[0]      : Card vendor:      Advanced Micro Devices, Inc. [AMD/ATI]GPU[0]      : Card SKU:         D65209====================================================================================
====================== End of ROCm SMI Log ================================

接下来,安装CLIP和所需的库。

! pip install git+https://github.com/openai/CLIP.git ftfy regex tqdm matplotlib

步骤2:加载模型

import torch
import clip
import numpy as np# 在这个博客中我们将加载 ViT-L/14@336px 的 CLIP 模型
model, preprocess = clip.load("ViT-L/14@336px")
model.cuda().eval()
# 检查模型架构
print(model)
# 检查预处理器
print(preprocess)

输出:

    CLIP((visual): VisionTransformer((conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(transformer): Transformer((resblocks): Sequential((0): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))(1): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))...(23): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))))(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))(transformer): Transformer((resblocks): Sequential((0): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))(1): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))...(11): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))))(token_embedding): Embedding(49408, 768)(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)Compose(Resize(size=336, interpolation=bicubic, max_size=None, antialias=warn)CenterCrop(size=(336, 336))<function _convert_image_to_rgb at 0x7f8616295630>ToTensor()Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

步骤3:检查图像和文本

我们从 COCO 数据集 中获取 8 张示例图片及其文本描述,并将图片特征和文本特征进行比对,计算相似度。

import os
import matplotlib.pyplot as plt
from PIL import Image# 使用来自 COCO 数据集的图像及其文本描述
image_urls  = ["http://farm1.staticflickr.com/6/8378612_34ab6787ae_z.jpg","http://farm9.staticflickr.com/8456/8033451486_aa38ee006c_z.jpg","http://farm9.staticflickr.com/8344/8221561363_a6042ba9e0_z.jpg","http://farm5.staticflickr.com/4147/5210232105_b22d909ab7_z.jpg","http://farm4.staticflickr.com/3098/2852057907_29f1f35ff7_z.jpg","http://farm4.staticflickr.com/3324/3289158186_155a301760_z.jpg","http://farm4.staticflickr.com/3718/9148767840_a30c2c7dcb_z.jpg","http://farm9.staticflickr.com/8030/7989105762_4ef9e7a03c_z.jpg"
]text_descriptions = ["a cat standing on a wooden floor","an airplane on the runway","a white truck parked next to trees","an elephant standing in a zoo","a laptop on a desk beside a window","a giraffe standing in a dirt field","a bus stopped at a bus stop","two bunches of bananas in the market"
]

显示八张图片及其对应的文本描述。

import requests
from io import BytesIOimages_for_display=[]
images=[]# 创建一个新图形
plt.figure(figsize=(12, 6))
size = (400, 320)
# 依次遍历每个 URL 并在子图中绘制图像
for i, url1 in enumerate(image_urls):# # 从 URL 获取图像response = requests.get(url1)image = Image.open(BytesIO(response.content))image = image.resize(size)# 添加子图 (2 行,4 列,索引为 i+1)plt.subplot(2, 4, i + 1)# 绘制图像plt.imshow(image)plt.axis('off')  # Turn off axes labels# 添加标题(可选)plt.title(f'{text_descriptions[i]}')images_for_display.append(image)images.append(preprocess(image))# 调整布局以防止重叠
plt.tight_layout()# 显示图
plt.show()

第 4 步:生成特征

接下来,我们准备图像和文本输入,并继续执行模型的前向传播。这一步将分别提取图像和文本特征。

image_inputs = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["It is " + text for text in text_descriptions]).cuda()with torch.no_grad():image_features = model.encode_image(image_inputs).float()text_features = model.encode_text(text_tokens).float()

步骤5:计算文本与图像之间的相似度得分

我们对特征进行归一化,并计算每对的点积。

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity_score = text_features.cpu().numpy() @ image_features.cpu().numpy().T

步骤6:可视化文本与图像之间的相似度

def plot_similarity(text_descriptions, similarity_score, images_for_display):count = len(text_descriptions)fig, ax = plt.subplots(figsize=(18, 15))im = ax.imshow(similarity_score, cmap=plt.cm.YlOrRd)plt.colorbar(im, ax=ax)# y轴刻度:文本描述ax.set_yticks(np.arange(count))ax.set_yticklabels(text_descriptions, fontsize=12)ax.set_xticklabels([])ax.xaxis.set_visible(False) for i, image in enumerate(images_for_display):ax.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")for x in range(similarity_score.shape[1]):for y in range(similarity_score.shape[0]):ax.text(x, y, f"{similarity_score[y, x]:.2f}", ha="center", va="center", size=10)ax.spines[["left", "top", "right", "bottom"]].set_visible(False)# 设置x轴和y轴的限制ax.set_xlim([-0.5, count - 0.5])ax.set_ylim([count + 0.5, -2])# 为图表添加标题ax.set_title("Text and Image Similarity Score calculated with CLIP", size=14)plt.show()plot_similarity(text_descriptions, similarity_score, images_for_display)

png

如论文所述,CLIP的目标是在批次内最大化图像和文本对的嵌入相似度,同时最小化错误对的嵌入相似度。在结果中可以观察到,对角线上的单元格在各自的列和行中表现出最高的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/60079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

I/O操作完成事件

本文内容由智谱清言产生。 在计算机编程中&#xff0c;I/O&#xff08;输入/输出&#xff09;操作完成事件是指一个I/O操作&#xff08;如读取文件、写入数据库、网络通信等&#xff09;已经完成的通知。这种事件通常由操作系统或框架生成&#xff0c;以通知应用程序或程序中的…

2024年第四届“网鼎杯”网络安全比赛---朱雀组Crypto- WriteUp

2024年第四届“网鼎杯”网络安全比赛---朱雀组Crypto-WriteUp Crypto&#xff1a;Crypto-2&#xff1a;Crypto-3&#xff1a; 前言&#xff1a;本次比赛已经结束&#xff0c;用于赛后复现&#xff0c;欢迎大家交流学习&#xff01; Crypto&#xff1a; Crypto-2&#xff1a; …

下载mysql的jar,添加至jmeter中,编写jdbc协议脚本1106

下载jar包&#xff1a; 步骤1&#xff1a;进入maven仓库官网https://mvnrepository.com/ 步骤2&#xff1a;搜索实际的数据库 步骤3&#xff1a;点击 Mysql connnector/J 步骤5、查看数据库的版本号&#xff0c;选择具体版本&#xff0c;我的是mysql 8.0.16,下图&#xff0c;…

IntelliJ IDEA的快捷键

IntelliJ IDEA 是一个非常强大的集成开发环境&#xff0c;它提供了大量的快捷键来加速开发者的日常工作。这里为您整理了一份 IntelliJ IDEA 的快捷键大全&#xff0c;包含了编辑、导航、重构、运行等多个方面的快捷键。请注意&#xff0c;这些快捷键是基于 Windows 版本的 Int…

Rust:启动与关闭线程

在 Rust 编程中&#xff0c;启动和关闭线程是并发编程的重要部分。Rust 提供了强大的线程支持&#xff0c;允许你轻松地创建和管理线程。下面将详细解释如何在 Rust 中启动和关闭线程。 启动线程 在 Rust 中&#xff0c;你可以使用标准库中的 std::thread 模块来创建和启动新…

从“点”到“面”,热成像防爆手机如何为安全织就“透视网”?

市场上测温产品让人眼花缭乱&#xff0c;通过调研分析&#xff0c;小编发现测温枪占很高比重。但是&#xff0c;测温枪局限于显示单一数值信息&#xff0c;无法直观地展示物体的整体温度分布情况&#xff0c;而且几乎没有功能拓展能力。以AORO A23为代表的热成像防爆手机改变了…

模型训练中GPU利用率低?

买了块魔改华硕猛禽2080ti&#xff0c;找了下没找到什么测试显存的软件&#xff0c;于是用训练模型来测试魔改后的显存稳定性&#xff0c;因为模型训练器没有资源监测&#xff0c;于是用了Windows任务管理器来查看显卡使用情况&#xff0c;却发现GPU的利用率怎么这么低&#xf…

开源代码管理平台Gitlab如何本地化部署并实现公网环境远程访问私有仓库

文章目录 前言1. 下载Gitlab2. 安装Gitlab3. 启动Gitlab4. 安装cpolar5. 创建隧道配置访问地址6. 固定GitLab访问地址6.1 保留二级子域名6.2 配置二级子域名 7. 测试访问二级子域名 前言 本文主要介绍如何在Linux CentOS8 中搭建GitLab私有仓库并且结合内网穿透工具实现在公网…

在vue3的vite网络请求报错 [vite] http proxy error:

在开发的过程中 代理proxy报错: [vite] http proxy error: /ranking/hostRank?dateType1 Error: connect ETIMEDOUT 43.xxx.xxx.xxx:443 网络请求是http的: // vite.config.ts import { Agent } from node:http;server: {host: 0.0.0.0,port: port,open: true,https: false,…

组合AC c++

题目描述 老师获得了一行字符串&#xff0c;想知道在不改变字符顺序的情况下&#xff0c;从前到后最多能组合出多少个ac? (a和c的位置可以不连续) 比如:字符串为addcadcc&#xff0c;可以找到5个ac&#xff0c;即下标组合为(0&#xff0c;3)、(0&#xff0c;6)、(0&#xff…

云计算 esxi 如何 部署iscsi ,配合windows 2012 iscsi 存储

1 windows 2012 如何创建iscsi 存储服务器&#xff0c;看前面的文章 iscsi 服务上的地址 192.168.10.196 192.168.10.196 2 如何在esxi 创建iscsi 注意地址是192.168.10.196 这是服务器的地址 很明显这是我们esxi 主机上发现的iscsi 磁盘 、

【Python爬虫实战】深入解锁 DrissionPage:ChromiumPage 自动化网页操作指南

&#x1f308;个人主页&#xff1a;易辰君-CSDN博客 &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、ChromiumPage基础操作 &#xff08;一&#xff09;初始化Drission 和 ChromiumPage 对象 &#xff0…

H5播放器EasyPlayer.js 流媒体播放器是否支持npm(yarn) install 安装?

EasyPlayer.js H5播放器是一款功能强大的H5视频播放器&#xff0c;它支持多种流媒体协议播放&#xff0c;包括WebSocket-FLV、HTTP-FLV、HLS&#xff08;m3u8&#xff09;、WebRTC等格式的视频流。它不仅支持H.264和H.265编码格式&#xff0c;还具备实时录像、低延时直播等功能…

pipreqs:快速准确生成当前项目的requirements.txt,还有和freeze的对比

大家好&#xff0c;这里是程序员晚枫。 今天给大家推荐一个快速生成requirements.txt的小工具&#xff1a;pipreqs。 什么是requirements.txt&#xff1f; 我们在开发Python项目的时候&#xff0c;需要用到requirements.txt来管理项目中使用的第三方库。 当我们把项目部署到…

2024年入职_转行网络安全,该如何规划?

前言 前段时间&#xff0c;知名机构麦可思研究院发布了 《2023年中国本科生就业报告》&#xff0c;其中详细列出近五年的本科绿牌专业&#xff0c;其中&#xff0c;信息安全位列第一。 网络安全前景 对于网络安全的发展与就业前景&#xff0c;想必无需我多言&#xff0c;作为…

比较相邻两个元素求最大值

任务描述 本关任务&#xff1a;比较数组相邻两个元素求最大值。 相关知识 比较相邻的元素。如果第一个比第二个大&#xff0c;就交换他们两个&#xff0c;对每一对相邻元素做同样的工作&#xff0c;从开始第一对到结尾的最后一对&#xff0c;最后的元素应该会是最大的数。 如…

ElasticSearch备考 -- 集群配置常见问题

一、集群开启xpack安全配置后无法启动 在配置文件中增加 xpack.security.enabled: true 后无法启动&#xff0c;日志中提示如下 Transport SSL must be enabled if security is enabled. Please set [xpack.security.transport.ssl.enabled] to [true] or disable security b…

nVisual前端配置文件

自定义接口 描述 此配置文件作用是自定义连接后台服务器的地址。 文件位置 dist/config/api.js 字段说明 diagramApiHost&#xff1a;除了报表页面的所有接口host地址。 reportApiHost&#xff1a;报表页面接口host地址。 reportAdapterHost&#xff1a;报表适配器地址。 web…

力扣17-电话号码的数字组合

力扣17-电话号码的数字组合 思路代码 题目链接 思路 原题&#xff1a; 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 输…

gatewayworker 读取laravel框架的配置

我把gatewayworker放到了vendor目录&#xff0c;在laravel配置文件里配置了url。 return [webSorketUrl > env(WEBSOCKET_URL, ws://127.0.0.1:8282),gatewayWebSorketUrl > env(GATEWAY_WEBSORKET_URL, Websocket://127.0.0.1:8282), ];由于在Gatewayworker/application…