与AMD GPU上的对比语言-图像预训练(CLIP)模型交互

Interacting with Contrastive Language-Image Pre-Training (CLIP) model on AMD GPU — ROCm Blogs

2024年4月16日,由Sean Song撰写.

引言 

对比语言-图像预训练(CLIP)是一种多模态深度学习模型,连接视觉和自然语言。它在OpenAI的论文“通过自然语言监督学习可转移的视觉模型” (2021) 中被介绍,并在大量(4亿)网页抓取的数据图像-字幕对上进行了对比训练(这是最早进行此类训练的模型之一)。

在预训练阶段,CLIP被训练去预测批次中图像和文本之间的语义关联。这包括确定哪些图像-文本对彼此之间最相关或最密切。这一过程涉及图像编码器和文本编码器的同时训练。其目标是最大化批次中图像和文本对嵌入间的余弦相似度,同时最小化错误对嵌入之间的相似度。通过这种方式,该模型学习到一个多模态的嵌入空间。对这些相似度分数使用对称交叉熵损失进行优化。

图片来源: 通过自然语言监督学习可转移的视觉模型.

In the subsequent sections of the blog, we will leverage the PyTorch framework along 在随后的博客部分,我们将利用*PyTorch*框架与*ROCm*一起运行CLIP模型,以计算任意图像和文本输入之间的相似度。

设置

此演示使用以下设置创建。有关全面的支持详情,请参阅ROCm 文档。

  • 硬件和操作系统:

    • AMD Instinct GPU

    • Ubuntu 22.04.3 LTS

  • 软件:

    • ROCm 5.7.0+

    • Pytorch 2.0+

任意图像和文本输入之间的相似度计算

步骤1:入门

首先,确认GPU的可用性。

!rocm-smi --showproductname
========== ROCm System Management Interface =========================
==================== Product Info ===================================GPU[0]      : Card series:      AMD INSTINCT MI250 (MCM) OAM AC MBAGPU[0]      : Card model:       0x0b0cGPU[0]      : Card vendor:      Advanced Micro Devices, Inc. [AMD/ATI]GPU[0]      : Card SKU:         D65209====================================================================================
====================== End of ROCm SMI Log ================================

接下来,安装CLIP和所需的库。

! pip install git+https://github.com/openai/CLIP.git ftfy regex tqdm matplotlib

步骤2:加载模型

import torch
import clip
import numpy as np# 在这个博客中我们将加载 ViT-L/14@336px 的 CLIP 模型
model, preprocess = clip.load("ViT-L/14@336px")
model.cuda().eval()
# 检查模型架构
print(model)
# 检查预处理器
print(preprocess)

输出:

    CLIP((visual): VisionTransformer((conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(transformer): Transformer((resblocks): Sequential((0): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))(1): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))...(23): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True))(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=1024, out_features=4096, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=4096, out_features=1024, bias=True))(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))))(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True))(transformer): Transformer((resblocks): Sequential((0): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))(1): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))...(11): ResidualAttentionBlock((attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True))(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)(mlp): Sequential((c_fc): Linear(in_features=768, out_features=3072, bias=True)(gelu): QuickGELU()(c_proj): Linear(in_features=3072, out_features=768, bias=True))(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True))))(token_embedding): Embedding(49408, 768)(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)Compose(Resize(size=336, interpolation=bicubic, max_size=None, antialias=warn)CenterCrop(size=(336, 336))<function _convert_image_to_rgb at 0x7f8616295630>ToTensor()Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

步骤3:检查图像和文本

我们从 COCO 数据集 中获取 8 张示例图片及其文本描述,并将图片特征和文本特征进行比对,计算相似度。

import os
import matplotlib.pyplot as plt
from PIL import Image# 使用来自 COCO 数据集的图像及其文本描述
image_urls  = ["http://farm1.staticflickr.com/6/8378612_34ab6787ae_z.jpg","http://farm9.staticflickr.com/8456/8033451486_aa38ee006c_z.jpg","http://farm9.staticflickr.com/8344/8221561363_a6042ba9e0_z.jpg","http://farm5.staticflickr.com/4147/5210232105_b22d909ab7_z.jpg","http://farm4.staticflickr.com/3098/2852057907_29f1f35ff7_z.jpg","http://farm4.staticflickr.com/3324/3289158186_155a301760_z.jpg","http://farm4.staticflickr.com/3718/9148767840_a30c2c7dcb_z.jpg","http://farm9.staticflickr.com/8030/7989105762_4ef9e7a03c_z.jpg"
]text_descriptions = ["a cat standing on a wooden floor","an airplane on the runway","a white truck parked next to trees","an elephant standing in a zoo","a laptop on a desk beside a window","a giraffe standing in a dirt field","a bus stopped at a bus stop","two bunches of bananas in the market"
]

显示八张图片及其对应的文本描述。

import requests
from io import BytesIOimages_for_display=[]
images=[]# 创建一个新图形
plt.figure(figsize=(12, 6))
size = (400, 320)
# 依次遍历每个 URL 并在子图中绘制图像
for i, url1 in enumerate(image_urls):# # 从 URL 获取图像response = requests.get(url1)image = Image.open(BytesIO(response.content))image = image.resize(size)# 添加子图 (2 行,4 列,索引为 i+1)plt.subplot(2, 4, i + 1)# 绘制图像plt.imshow(image)plt.axis('off')  # Turn off axes labels# 添加标题(可选)plt.title(f'{text_descriptions[i]}')images_for_display.append(image)images.append(preprocess(image))# 调整布局以防止重叠
plt.tight_layout()# 显示图
plt.show()

第 4 步:生成特征

接下来,我们准备图像和文本输入,并继续执行模型的前向传播。这一步将分别提取图像和文本特征。

image_inputs = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["It is " + text for text in text_descriptions]).cuda()with torch.no_grad():image_features = model.encode_image(image_inputs).float()text_features = model.encode_text(text_tokens).float()

步骤5:计算文本与图像之间的相似度得分

我们对特征进行归一化,并计算每对的点积。

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity_score = text_features.cpu().numpy() @ image_features.cpu().numpy().T

步骤6:可视化文本与图像之间的相似度

def plot_similarity(text_descriptions, similarity_score, images_for_display):count = len(text_descriptions)fig, ax = plt.subplots(figsize=(18, 15))im = ax.imshow(similarity_score, cmap=plt.cm.YlOrRd)plt.colorbar(im, ax=ax)# y轴刻度:文本描述ax.set_yticks(np.arange(count))ax.set_yticklabels(text_descriptions, fontsize=12)ax.set_xticklabels([])ax.xaxis.set_visible(False) for i, image in enumerate(images_for_display):ax.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")for x in range(similarity_score.shape[1]):for y in range(similarity_score.shape[0]):ax.text(x, y, f"{similarity_score[y, x]:.2f}", ha="center", va="center", size=10)ax.spines[["left", "top", "right", "bottom"]].set_visible(False)# 设置x轴和y轴的限制ax.set_xlim([-0.5, count - 0.5])ax.set_ylim([count + 0.5, -2])# 为图表添加标题ax.set_title("Text and Image Similarity Score calculated with CLIP", size=14)plt.show()plot_similarity(text_descriptions, similarity_score, images_for_display)

png

如论文所述,CLIP的目标是在批次内最大化图像和文本对的嵌入相似度,同时最小化错误对的嵌入相似度。在结果中可以观察到,对角线上的单元格在各自的列和行中表现出最高的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/60079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年第四届“网鼎杯”网络安全比赛---朱雀组Crypto- WriteUp

2024年第四届“网鼎杯”网络安全比赛---朱雀组Crypto-WriteUp Crypto&#xff1a;Crypto-2&#xff1a;Crypto-3&#xff1a; 前言&#xff1a;本次比赛已经结束&#xff0c;用于赛后复现&#xff0c;欢迎大家交流学习&#xff01; Crypto&#xff1a; Crypto-2&#xff1a; …

下载mysql的jar,添加至jmeter中,编写jdbc协议脚本1106

下载jar包&#xff1a; 步骤1&#xff1a;进入maven仓库官网https://mvnrepository.com/ 步骤2&#xff1a;搜索实际的数据库 步骤3&#xff1a;点击 Mysql connnector/J 步骤5、查看数据库的版本号&#xff0c;选择具体版本&#xff0c;我的是mysql 8.0.16,下图&#xff0c;…

从“点”到“面”,热成像防爆手机如何为安全织就“透视网”?

市场上测温产品让人眼花缭乱&#xff0c;通过调研分析&#xff0c;小编发现测温枪占很高比重。但是&#xff0c;测温枪局限于显示单一数值信息&#xff0c;无法直观地展示物体的整体温度分布情况&#xff0c;而且几乎没有功能拓展能力。以AORO A23为代表的热成像防爆手机改变了…

模型训练中GPU利用率低?

买了块魔改华硕猛禽2080ti&#xff0c;找了下没找到什么测试显存的软件&#xff0c;于是用训练模型来测试魔改后的显存稳定性&#xff0c;因为模型训练器没有资源监测&#xff0c;于是用了Windows任务管理器来查看显卡使用情况&#xff0c;却发现GPU的利用率怎么这么低&#xf…

开源代码管理平台Gitlab如何本地化部署并实现公网环境远程访问私有仓库

文章目录 前言1. 下载Gitlab2. 安装Gitlab3. 启动Gitlab4. 安装cpolar5. 创建隧道配置访问地址6. 固定GitLab访问地址6.1 保留二级子域名6.2 配置二级子域名 7. 测试访问二级子域名 前言 本文主要介绍如何在Linux CentOS8 中搭建GitLab私有仓库并且结合内网穿透工具实现在公网…

在vue3的vite网络请求报错 [vite] http proxy error:

在开发的过程中 代理proxy报错: [vite] http proxy error: /ranking/hostRank?dateType1 Error: connect ETIMEDOUT 43.xxx.xxx.xxx:443 网络请求是http的: // vite.config.ts import { Agent } from node:http;server: {host: 0.0.0.0,port: port,open: true,https: false,…

云计算 esxi 如何 部署iscsi ,配合windows 2012 iscsi 存储

1 windows 2012 如何创建iscsi 存储服务器&#xff0c;看前面的文章 iscsi 服务上的地址 192.168.10.196 192.168.10.196 2 如何在esxi 创建iscsi 注意地址是192.168.10.196 这是服务器的地址 很明显这是我们esxi 主机上发现的iscsi 磁盘 、

【Python爬虫实战】深入解锁 DrissionPage:ChromiumPage 自动化网页操作指南

&#x1f308;个人主页&#xff1a;易辰君-CSDN博客 &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、ChromiumPage基础操作 &#xff08;一&#xff09;初始化Drission 和 ChromiumPage 对象 &#xff0…

H5播放器EasyPlayer.js 流媒体播放器是否支持npm(yarn) install 安装?

EasyPlayer.js H5播放器是一款功能强大的H5视频播放器&#xff0c;它支持多种流媒体协议播放&#xff0c;包括WebSocket-FLV、HTTP-FLV、HLS&#xff08;m3u8&#xff09;、WebRTC等格式的视频流。它不仅支持H.264和H.265编码格式&#xff0c;还具备实时录像、低延时直播等功能…

2024年入职_转行网络安全,该如何规划?

前言 前段时间&#xff0c;知名机构麦可思研究院发布了 《2023年中国本科生就业报告》&#xff0c;其中详细列出近五年的本科绿牌专业&#xff0c;其中&#xff0c;信息安全位列第一。 网络安全前景 对于网络安全的发展与就业前景&#xff0c;想必无需我多言&#xff0c;作为…

ElasticSearch备考 -- 集群配置常见问题

一、集群开启xpack安全配置后无法启动 在配置文件中增加 xpack.security.enabled: true 后无法启动&#xff0c;日志中提示如下 Transport SSL must be enabled if security is enabled. Please set [xpack.security.transport.ssl.enabled] to [true] or disable security b…

力扣17-电话号码的数字组合

力扣17-电话号码的数字组合 思路代码 题目链接 思路 原题&#xff1a; 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 输…

vs code使用git管理代码

1.vs code连接远程服务 ①安装Remote - SSH插件。 安装好远程连接插件后&#xff0c;按照步骤点击远程连接、选择加号、按照指定格式输入ssh ip连接远程服务器。 2.远程推送、对比代码 ①查看你当前所在的分支号&#xff0c;任意点开一个文件下都有对应的分支号。 ②点开右小…

2024 网鼎杯 - 青龙组 Web WP

2024 网鼎杯 - 青龙组 WEB - 02 打开容器一个登录界面&#xff0c;随便输入账号密码可以进到漏洞界面 这里有一个发送给boss的功能&#xff0c;一眼xss 有三个接口&#xff1a;/flag 、/update 、/submit /flag &#xff1a;要求boss才能访问&#xff0c;/update &#xf…

验证码-滑动验证码和点选验证码

1.csdn登录 存在多个内部框架&#xff0c;学习使用driver.switch_to.default_content() from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.common.action_chains import ActionChains import timedriver webdriver.Chrom…

停车场微信小程序的设计与实现(lw+演示+源码+运行)

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了停车场微信小程序的开发全过程。通过分析停车场微信小程序管理的不足&#xff0c;创建了一个计算机管理停车场微信小程序的方案。文章介绍了停车场微信小程序的…

如何保证kafka生产者数据可靠性

ack参数的设置&#xff1a; 0&#xff1a;生产者发送过来的数据&#xff0c;不需要等数据落盘应答 假如发送了Hello 和 World两个信息&#xff0c;Leader直接挂掉&#xff0c;数据就会丢失 生产者 ---> Kafka集群 一放进去就跑 数据可靠性分析&#xff1a;丢数 1&#…

实习作假:阿里健康实习做了RABC中台,还优化了短信发送流程

最近有二本同学说&#xff1a;“大拿老师&#xff0c;能帮忙看下简历吗&#xff1f;” 如果是从面试官的角度来看&#xff0c;这个同学的实习简历是很虚假的。 但是我们一直强调的是&#xff1a;校招的实习简历是不能出现明显的虚假。 首先&#xff0c;你去公司做事情&#…

路过宝安乌石岩庙记

​每周带娃从上屋地铁去罗租大道的七彩城堡儿童乐园玩&#xff0c;路上都会经过乌石岩庙附近。听说香火很繁盛&#xff0c;娃说也想去看看&#xff0c;于是来到了乌石岩庙。 石岩乌石岩庙 广东省深圳市宝安区老街一区94号 ​从百度知悉&#xff1a;乌石岩庙&#xff0c;又称“…

测度论原创(三)

Morden Prob 文章目录 Morden ProbWeek3多维扩展和随机向量定理3.1推论&#xff1a;random variable的变换定理3.2 连续函数的可测性定理3.3 可测函数的线性组合关于拓展实数集的延伸定理3.4 可测函数的极限依旧为可测性随机变量的概率律&#xff08;Law of X X X&#xff09;…