论文学习16:Learning Transferable Visual Models From Natural Language Supervision

代码来源

Learning Transferable Visual Models From Natural Language Supervisionhttps://arxiv.org/pdf/2103.00020

模块作用

当前最先进的计算机视觉系统被训练用于预测一组固定的、预先定义的目标类别。这种受限的监督方式限制了它们的通用性和可用性,因为要识别其他视觉概念时,仍然需要额外的标注数据。直接从关于图像的原始文本中学习是一种更具潜力的替代方案,它可以利用更广泛的监督信息。

模块结构

  • 图像编码器
    • 支持多种架构,包括ResNet(如ResNet-50、ResNet-101、RN50x4、RN50x16、RN50x64)和Vision Transformer(ViT,如ViT-B/32、ViT-B/16、ViT-L/14)。
    • ResNet版本包括ResNet-D改进、抗混叠rect-2模糊池化和注意力池化(多头QKV注意力)。
    • ViT版本在Transformer前增加层归一化,训练模型包括ViT-B/32、ViT-B/16和ViT-L/14,其中ViT-L/14@336px(在336像素分辨率下预训练额外一轮)表现最佳。
  • 文本编码器
    • 使用63M参数的12层Transformer,宽度512,8个注意力头,处理小写字节对编码(BPE),词汇表大小49,152,最大序列长度76,用[SOS]和[EOS]标记括住,使用屏蔽自注意力。

代码

class CLIP(nn.Module):def __init__(self,embed_dim: int,# visionimage_resolution: int,vision_layers: Union[Tuple[int, int, int, int], int],vision_width: int,vision_patch_size: int,# textcontext_length: int,vocab_size: int,transformer_width: int,transformer_heads: int,transformer_layers: int):super().__init__()self.context_length = context_lengthif isinstance(vision_layers, (tuple, list)):vision_heads = vision_width * 32 // 64self.visual = ModifiedResNet(layers=vision_layers,output_dim=embed_dim,heads=vision_heads,input_resolution=image_resolution,width=vision_width)else:vision_heads = vision_width // 64self.visual = VisionTransformer(input_resolution=image_resolution,patch_size=vision_patch_size,width=vision_width,layers=vision_layers,heads=vision_heads,output_dim=embed_dim)self.transformer = Transformer(width=transformer_width,layers=transformer_layers,heads=transformer_heads,attn_mask=self.build_attention_mask())self.vocab_size = vocab_sizeself.token_embedding = nn.Embedding(vocab_size, transformer_width)self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))self.ln_final = LayerNorm(transformer_width)self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))self.initialize_parameters()def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)def build_attention_mask(self):# lazily create causal attention mask, with full attention between the vision tokens# pytorch uses additive attention mask; fill with -infmask = torch.empty(self.context_length, self.context_length)mask.fill_(float("-inf"))mask.triu_(1)  # zero out the lower diagonalreturn mask@propertydef dtype(self):return self.visual.conv1.weight.dtypedef encode_image(self, image):return self.visual(image.type(self.dtype))def encode_text(self, text):x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]x = x + self.positional_embedding.type(self.dtype)x = x.permute(1, 0, 2)  # NLD -> LNDx = self.transformer(x)x = x.permute(1, 0, 2)  # LND -> NLDx = self.ln_final(x).type(self.dtype)# x.shape = [batch_size, n_ctx, transformer.width]# take features from the eot embedding (eot_token is the highest number in each sequence)x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projectionreturn xdef forward(self, image, text):image_features = self.encode_image(image)text_features = self.encode_text(text)# normalized featuresimage_features = image_features / image_features.norm(dim=1, keepdim=True)text_features = text_features / text_features.norm(dim=1, keepdim=True)# cosine similarity as logitslogit_scale = self.logit_scale.exp()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()# shape = [global_batch_size, global_batch_size]return logits_per_image, logits_per_text

总结

本文研究了在自然语言处理(NLP)领域取得成功的、与具体任务无关的大规模网络预训练方法,是否可以迁移到另一个领域。研究表明,采用这一方法后,在计算机视觉领域会出现类似的行为,我们也探讨了这一研究方向的社会影响。为了优化训练目标,CLIP 模型在预训练过程中学习执行多种不同的任务。这种任务学习可以通过自然语言提示(prompting)加以利用,从而实现对许多现有数据集的零样本(zero-shot)迁移。在足够大的规模下,这种方法的性能可以与特定任务的监督学习模型相竞争,尽管仍有很大的改进空间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74722.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[MySQL初阶]MySQL(9)事务机制

标题:[MySQL初阶]MySQL(9)事物机制 水墨不写bug 文章目录 一、认识事务1、多线程访问数据库出现的问题2、对CURD的限制是通过事务机制实现的3、事务的四个属性4、哪些引擎支持事务 二、事务的提交与autocommit设置三、事务的隔离性和隔离级别…

spring-cloud-alibaba-nacos-config使用说明

一、核心功能与定位 Spring Cloud Alibaba Nacos Config 是 Spring Cloud Alibaba 生态中的核心组件之一,专为微服务架构提供动态配置管理能力。它通过整合 Nacos 的配置中心功能,替代传统的 Spring Cloud Config,提供更高效的配置集中化管理…

SonarQube数据库配置

SonarQube部署完成后,在浏览器地址栏输入http://IP:9000可以进入登录页面,以本机运行为例,地址为http://127.0.0.1:9000/,默认登录名:admin,登录密码也是admin。登录后会要求设置密码: 按要求设…

医药档案区块链系统

1. 医生用户模块​​ ​​目标用户​​:医护人员 ​​核心功能​​: ​​检索档案​​:通过关键词或筛选条件快速定位患者健康档案。​​请求授权​​:向个人用户发起档案访问权限申请,需经对方确认。​​查看档案​…

CSS3学习教程,从入门到精通, 化妆品网站 HTML5 + CSS3 完整项目(26)

化妆品网站 HTML5 CSS3 完整项目 下面是一个完整的化妆品网站项目,包含主页、登录页面和注册页面。我将按照您的要求提供详细的代码和注释。 1. 网站规划与需求分析 需求分析 展示化妆品产品信息提供用户注册和登录功能响应式设计,适配不同设备美观…

ROS2 多机时间同步(Chrony配置简明指南)

适用场景: 主机运行 ROS2 Humble(发布 /scan 等),板子运行 ROS2 Foxy(发布 /tf 等),两边通过 ROS_DOMAIN_ID 跨平台通讯。需要保证系统时间对齐,避免 TF 插值失败、建图抖动等问题。…

Nginx配置伪静态,URL重写

Nginx配置伪静态,URL重写 [ Nginx ] 在Nginx低版本中,是不支持PATHINFO的,但是可以通过在Nginx.conf中配置转发规则实现: location / { // …..省略部分代码if (!-e $request_filename) {rewrite ^(.*)$ /index.php?s/$1 l…

电路笔记(元器件):ADC LTC系列模数转换器的输出范围+满量程和偏移调整

LTC1740(LTC1740官方文档)是Analog Devices(原Linear Technology)公司生产的一款高性能、低功耗的14位模数转换器(ADC)。它通常用于需要高精度和快速采样率的应用中,如通信系统、数据采集设备等。同类产品 LTC1746:一款14位、40Ms…

续-算法-数学知识

3、欧拉函数 1、定义: 1~n 中与 n 互质的数的个数 例如:6 的有 1 2 3 4 5 6 其中,与 n 互质 的 数的个数为 2个分别是:1、5 2、计算: $ N p_1^{a1} p_2^{a2} p_3^{a3} … p_k^{ak} $(例如&#x…

C/C++测试框架googletest使用示例

文章目录 文档编译安装示例参考文章 文档 https://github.com/google/googletest https://google.github.io/googletest/ 编译安装 googletest是cmake项目,可以用cmake指令编译 cmake -B build && cmake --build build将编译产物lib和include 两个文件夹…

LintCode第974题-求矩阵各节点的最短路径(以0为标准)

描述 给定一个由0和1组成的矩阵,求每个单元格最近的0的距离。 两个相邻细胞之间的距离是1。 给定矩阵的元素数不超过10,000。 在给定的矩阵中至少有一个0。 单元格在四个方向上相邻:上,下,左和右。 样例 例1: 输入: [[0,0,0],[0,0,0],[0…

Redis核心机制-缓存、分布式锁

目录 缓存 缓存更新策略 定期生成 实时生成 缓存问题 缓存预热(Cache preheating) 缓存穿透(Cache penetration) 缓存雪崩(Cache avalanche) 缓存击穿(Cache breakdown) 分…

CF每日5题(1300-1500)

最近急速补练蓝桥杯中,疏于cf练习。 感觉自己过题还是太慢了。 今日水题,我水水水水。 1- 1979C lcm 水 1400 第 i i i局赢了,1个硬币顶 k [ i ] k[i] k[i]个贡献,所以每局分硬币 x i 1 k [ i ] x_i{1\over k[i]} xi​k[i]1​个…

从代码学习深度学习 - LSTM PyTorch版

文章目录 前言一、数据加载与预处理1.1 代码实现1.2 功能解析二、LSTM介绍2.1 LSTM原理2.2 模型定义代码解析三、训练与预测3.1 训练逻辑代码解析3.2 可视化工具功能解析功能结果总结前言 深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文…

easy-poi 一对多导出

1. 需求: 某一列上下两行单元格A,B值一样且这两个单元格, 前面所有列对应单元格值一样的话, 就对A,B 两个单元格进行纵向合并单元格 1. 核心思路: 先对数据集的国家,省份,城市...... id 身份证进行排序…

AI比人脑更强,因为被植入思维模型【42】思维投影思维模型

giszz的理解:本质和外在。我们的行为举止,都是我们的内心的表现。从外边可以看内心,从内心可以判断外在。曾国藩有7个识人的方法,大部分的人在他的面前如同没穿衣服一样。对于我们自身的启迪,我认为有四点&…

Spring Boot 打印日志

1.通过slf4j包中的logger对象打印日志 Spring Boot内置了日志框架slf4j,在程序中调用slf4j来输出日志 通过创建logger对象打印日志,Logger 对象是属于 org.slf4j 包下的不要导错包。 2.日志级别 日志级别从高到低依次为: FATAL:致命信息,表…

【IOS webview】源代码映射错误,页面卡住不动

报错场景 safari页面报源代码映射错误,页面卡住不动。 机型:IOS13 技术栈:react 其他IOS也会报错,但不影响页面显示。 debug webpack配置不要GENERATE_SOURCEMAP。 解决方法: GENERATE_SOURCEMAPfalse react-app…

ES中经纬度查询geo_point

0. ES版本 6.x版本 1. 创建索引 PUT /location {"settings": {"number_of_shards": 1,"number_of_replicas": 0},"mappings": {"location": {"properties": {"id": {"type": "keywor…

OpenCV界面编程

《OpenCV计算机视觉开发实践:基于Python(人工智能技术丛书)》(朱文伟,李建英)【摘要 书评 试读】- 京东图书 OpenCV的Python开发环境搭建(Windows)-CSDN博客 OpenCV也支持有限的界面编程,主要是针对窗口、控件和鼠标…