ML-Decoder: Scalable and Versatile Classification Head

1、引言

论文链接:https://openaccess.thecvf.com/content/WACV2023/papers/Ridnik_ML-Decoder_Scalable_and_Versatile_Classification_Head_WACV_2023_paper.pdf

        因为 transformer 解码器分类头[1] 在少类别多标签分类数据集上表现得很好,但由于其查询复杂度为 O(n^2),n 为类别数量,故 transformer 解码器分类头对于多类别数据集是不可行的,且 transformer 解码器分类头只适用于多标签分类任务,故 Tal Ridnik 等引入了一种新的基于多头注意力机制的分类头——ML-Decoder[2]。ML-Decoder 可以用于单标签分类、多标签分类和多标签 ZSL(zero shot learning) 任务,它提供更好的精度-速度 trade-off,可以用于上万类别的数据集,可以作为各种分类头的 drop-in 替代品,结合词查询可以用于 ZSL。

2、方法

        ML-Decoder 流如图 1 右所示,相对于  transformer 解码器分类头,ML-Decoder 有一下改变。

图1  transformer-decoder vs. ML-Decoder

2.1  移除自注意力机制

        通过删除自注意力机制将 ML-Decoder 的查询复杂度由 O(n^2) 降至 O(n),并未影响表示能力。

2.2  组解码

        为了使查询数量与类别数量无关,使用固定的 k 组查询,而不是一个类别对应一个查询。在前馈神经网络后,通过组全连接层在将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。如图 2 所示。

图2  组全连接方案(g=4)

2.3  固定查询        

        查询总是被输入到一个多头注意力层,该注意力层会先对查询应用一个可学习的投影计算。因此,将查询权重设置为可学习的是多余的——可学习的投影可以将任何固定值查询转换为可学习查询获得的任何值。

3、模块介绍

3.1  Cross-Attention

        Cross-Attention 的核心其实就是多头注意力机制,输入的 q 为固定查询,k 和 v 均为图像嵌入。Cross-Attention 和 Feed-Forward 模块构成所谓的 TransformerDecoder(Layer),python 代码如下所示:

class TransformerDecoder(nn.Module):def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1) -> None:super().__init__()self.dropout = nn.Dropout(dropout)self.norm0 = nn.LayerNorm(d_model)self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)# Implementation of Feedforward modelself.feed_forward = nn.Sequential(nn.LayerNorm(d_model),nn.Linear(d_model, dim_feedforward),nn.ReLU(),nn.Dropout(dropout),nn.Linear(dim_feedforward, d_model))self.norm1 = nn.LayerNorm(d_model)def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:tgt = tgt + self.dropout(tgt)tgt = self.norm0(tgt)tgt0 = self.multihead_attn(tgt, memory, memory)[0]tgt = tgt + self.dropout(tgt0)tgt0 = self.feed_forward(tgt)tgt = tgt + self.dropout(tgt0)return self.norm1(tgt)

3.2  Group Fully Connected Pooling  

        Group Fully Connected Pooling的目的是将每个组查询扩展到 g=n/k 个输出的同时池化嵌入维度。即将每组查询结果与对应的可学习的 (hidde_dim, g) 维矩阵相乘,python 代码如下所示:

class GroupFC(object):def __init__(self, groups: int):self.groups = groupsdef __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):"""计算每组类的 logits 值(未加偏置):param h: shape=(b, groups, hidden_dim):param duplicate_pooling: shape=(groups, hidden_dim, duplicate_factor), duplicate_factor 每组的类别数:param out_extrap: shape=(b, groups, duplicate_factor):return:"""for i in range(h.shape[1]):h_i = h[:, i, :]w_i = duplicate_pooling[i, :, :]out_extrap[:, i, :] = torch.matmul(h_i, w_i)

4、总结

        作者开源的 ML-Decoder 的 python 实现代码在:https://github.com/Alibaba-MIIL/ML_Decoder/blob/main/src_files/ml_decoder/ml_decoder.py

        论文[2] 在 paper with code 上的战绩如图 3 所示,表现还是不错的。

图3  来自论文[2] 的结果

        由于当参数 zsl != 0 时 wordvec_proj 的输入 query_embed = None,本人还未学习过 ZSL 领域,且使用该代码时报错(zsl = 0,当然应该是我的原因,但懒得排错了),于是参考作者的代码写了一个 MLDecoder 类(只考虑 zsl = 0),剩下的代码如下所示。

class MLDecoder(nn.Module):"""Args:groups: 查询/类别组数hidden_dim: Transformer 解码器特征维度in_dim: 输入 tensor 特征维度(CNN 编码器输出为通道数,Transformer 编码器输出为最后一个维度)"""def __init__(self, num_classes, groups, in_dim=2048, hidden_dim=768, mlp_dim=2048, nhead=8, dropout=0.1):super().__init__()self.proj = nn.Linear(in_dim, hidden_dim)# non-learnable queriesself.query_embed = nn.Embedding(groups, hidden_dim)self.query_embed.requires_grad_(False)self.num_classes = num_classesself.decoder = TransformerDecoder(d_model=hidden_dim, nhead=nhead, dim_feedforward=mlp_dim, dropout=dropout)# group fully-connectedself.duplicate_factor = math.ceil(num_classes / groups)  # 每组类别数量,math.ceil: 向上取整self.duplicate_pooling = torch.nn.Parameter(torch.zeros((groups, hidden_dim, self.duplicate_factor)))self.duplicate_pooling_bias = torch.nn.Parameter(torch.zeros(num_classes))torch.nn.init.xavier_normal_(self.duplicate_pooling)self.group_fc = GroupFC(groups)def forward(self, x):# 确保解码器输入 shape 为 [b, h * w, c]if len(x.shape) == 4:x = x.flatten(2).transpose(1, 2)x = F.relu(self.proj(x), True)  # (b, h * w, hidden_dim)# Cross-Attention + Feed-Forwardquery_embed = self.query_embed.weight  # (groups, hidden_dim)# tensor.expend: 增大一个维度至指定大小, 不增大的维度为-1,例如将 shape 由 (b, n, c)->(b, 2n, c), 参数 size=(-1, 2n,-1)tgt = query_embed[None].expand(x.shape[0], -1, -1)  # (b, groups, hidden_dim)h = self.decoder(tgt, x)  # (b, groups, hidden_dim)# Group Fully Connected Poolingout_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)self.group_fc(h, self.duplicate_pooling, out_extrap)h_out = out_extrap.flatten(1)[:, :self.num_classes]  # (b, num_classes)return h_out + self.duplicate_pooling_bias

参考文献

[1] Shilong Liu, Lei Zhang, Xiao Yang, Hang Su, and Jun Zhu. Query2label: A simple transformer way to multi-label classification. arXiv preprint arXiv:2107.10834, 2021.

[2] Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben Baruch, and Asaf Noy. Ml-decoder: Scalable and versatile classification head. In IEEE/CVF Winter Conference on Applications of Computer Vision, WACV 2023, Waikoloa, HI, USA, January 2-7, 2023, pages 32–41. IEEE, 2023.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/782269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP的定时任务框架的taskPHP3.0学习记录2(环境要求、配置Redis、crontab执行时间语法、命令操作以及Screen全屏窗口管理器)

环境要求 php版本> 5.5开启socket扩展开启pdo扩展开启shmop扩展 echo <pre>; echo --; $requiredVersion 5.6.0; $currentVersion phpversion(); if (version_compare($currentVersion, $requiredVersion, >)) {echo "1.PHP版本满足要求&#xff0c;当前版…

c语言:vs2022写一个一元二次方程(包含虚根)

求一元二次方程 的根&#xff0c;通过键盘输入a、b、c&#xff0c;根据△的值输出对应x1和x2的值(保留一位小数)(用if语句完成)。 //一元二次方程的实现 #include <stdio.h> #include <math.h> #include <stdlib.h> int main() {double a, b, c, delta, x1…

数据结构 - 算法效率|时间复杂度|空间复杂度

目录 1.算法效率 2.时间复杂度 2.1定义 2.2大O渐近表示法 2.3常见时间复杂度计算举例 3.空间复杂度 3.1定义 3.2常见空间复杂度计算举例 1.算法效率 算法的效率常用算法复杂度来衡量&#xff0c;算法复杂度描述了算法在输入数据规模变化时&#xff0c;其运行时间和空间…

opejdk11 java 启动流程 java main方法怎么被jvm执行

java启动过程 java main方法怎么被jvm执行 java main方法是怎么被jvm调用的 1、jvm main入口 2、执行JLI_Launch方法 3、执行JVMInit方法 4、执行ContinueInNewThread方法 5、执行CallJavaMainInNewThread方法 6、创建线程执行ThreadJavaMain方法 7、执行ThreadJavaMain方法…

Last-Modified:HTTP缓存控制机制解析

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

docker配置github仓库ghcr国内镜像加速

文章目录 说明ghcr.io简介配置镜像命令地址命令行方式1panel面板方式方式一&#xff1a;配置镜像加速&#xff0c;命令行拉取方式二&#xff1a;配置镜像仓库&#xff0c;可视化拉取 说明 由于使用的容器需要从github下载镜像&#xff0c;服务器在国外下载速度很慢&#xff0c…

26. UE5 RPG同步面板属性(二)

在上一篇&#xff0c;我们解析了UI属性面板的实现步骤&#xff1a; 首先我们需要通过c去实现创建GameplayTag&#xff0c;这样可以在c和UE里同时获取到Tag创建一个DataAsset类&#xff0c;用于设置tag对应的属性和显示内容创建AttributeMenuWidgetController实现对应逻辑 并且…

理解游戏服务器架构-部署架构

目录 前言 我所理解的服务器架构 什么是否部署架构 部署架构的职责 进程业务职责 网络链接及通讯方式 与客户端的连接方式 服务器之间连接关系 数据落地以及一致性 数据库的选择 数据访问三级缓存 数据分片 读写分离 分布式数据处理 负载均衡 热更新 配置更新 …

html第二次作业

骨架 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width, initi…

vscode初始化node项目

首先需要安装node环境&#xff0c;推荐直接使用nvm 安装node&#xff0c;方便切换node版本 1.npm init 初始化node项目 在命令行输入npm init指令 根据指令创建完成后会在当前目录下生成一个package.json文件&#xff0c;记住运行npm init执行的目录必须是一个空目录 2.创建…

金三银四面试题(八):JVM常见面试题(2)

今天我们继续探讨常见的JVM面试题。这些问题不比之前的问题庞大&#xff0c;多用于面试中​JVM部分的热身运动&#xff0c;开胃菜&#xff0c;但是大家已经要认真准备。 JRE、JDK、JVM 及JIT 之间有什么不同&#xff1f; JRE 代表Java 运行时&#xff08;Java run-time&#…

Kafka入门到实战-第四弹

Kafka入门到实战 Kafka集群搭建官网地址Kafka概述使用Kraft搭建Kafka集群更新计划 Kafka集群搭建 官网地址 声明: 由于操作系统, 版本更新等原因, 文章所列内容不一定100%复现, 还要以官方信息为准 https://kafka.apache.org/Kafka概述 Apache Kafka 是一个开源的分布式事件…

计算机视觉之三维重建(5)---双目立体视觉

文章目录 一、平行视图1.1 示意图1.2 平行视图的基础矩阵1.3 平行视图的极几何1.4 平行视图的三角测量 二、图像校正三、对应点问题3.1 相关匹配法3.2 归一化相关匹配法3.3 窗口问题3.4 相关法存在的问题3.5 约束问题 一、平行视图 1.1 示意图 如下图即是一个平行视图。特点&a…

2核2G服务器优惠价格轻量61元一年,CVM价格313元15个月

腾讯云2核2G服务器多少钱一年&#xff1f;轻量服务器61元一年&#xff0c;CVM 2核2G S5服务器313.2元15个月&#xff0c;轻量2核2G3M带宽、40系统盘&#xff0c;云服务器CVM S5实例是2核2G、50G系统盘。腾讯云2核2G服务器优惠活动 txybk.com/go/txy 链接打开如下图&#xff1a;…

最小覆盖子串-java

最小覆盖子串-java 题目描述 : 给你一个字符串 s 、一个字符串 t 。返回 s 中涵盖 t 所有字符的最小子串。如果 s 中不存在涵盖 t 所有字符的子串&#xff0c;则返回空字符串 "" 。 注意&#xff1a; 对于 t 中重复字符&#xff0c;我们寻找的子字符串中该字符数量必…

Mac上Matlab_R2023b ARM 版 启动闪退(意外退出)解决方法

安装好后&#xff0c;使用 "libmwlmgrimpl.dylib" 文件替换掉"/Applications/Matlab_R2023b.app/bin/maca64/matlab_startup_plugins/lmgrimpl"文件夹下的同名文件 在终端下执行如下命令&#xff1a; codesign --verbose --force --deep -s - /Applicat…

本地项目上传到GitHub

本文档因使用实际项目提交做为案例&#xff0c;故使用xxx等字符进行脱敏&#xff0c;同时隐藏了部分输出&#xff0c;已实际项目和命令行输出为准 0、 Git 安装与GitHub注册 1&#xff09; 在下述地址下载Git&#xff0c;安装一路默认下一步即可。安装完成后&#xff0c;随便…

乡村数字化转型:科技赋能打造智慧农村新生态

随着信息技术的迅猛发展&#xff0c;数字化转型已成为推动社会进步的重要引擎。在乡村振兴的大背景下&#xff0c;乡村数字化转型不仅是提升乡村治理能力和治理水平现代化的关键&#xff0c;更是推动农业现代化、农村繁荣和农民增收的重要途径。本文旨在探讨乡村数字化转型的内…

解决MySQL幻读?可重复读隔离级别背后的工作原理

什么是当前读和快照读 当前读&#xff1a;又称为 "锁定读"&#xff0c;它会读取记录的最新版本&#xff08;也就是最新的提交结果&#xff09;&#xff0c;并对读取到的数据加锁&#xff0c;其它事务不能修改这些数据&#xff0c;直到当前事务提交或回滚。"sele…

Linux课程____shell脚本应用

:一、认识shell 常用解释器 Bash , ksh , csh 登陆后默认使用shell&#xff0c;一般为/bin/bash&#xff0c;不同的指令&#xff0c;运行的环境也不同 二、 编写简单脚本并使用 # vim /frist.sh //编写脚本文件&#xff0c;简单内容 #&#xff01;/bin/bash …