Anchor DETR:Transformer-Based目标检测的Query设计

写在前面

文中指出之前DETR-like算法存在以下问题:

  • 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。
  • 每个object query 预测的位置没有一个特定的模式(specific mode),即每个object query不会关注特定的区域。

PS:第二点所谓“预测位置没有一个特定模式”这个结论是怎么得出来的呢?作者援引了DETR论文中的一幅图像(如上图所示)进行说明。该图像中每个子图上都有很多点,每个子图代表了一个object query在验证集所有图像上得到的预测框的中心点坐标(经过归一化后的),绿色代表小的预测框,红色代表水平方向比较大的预测框,蓝色代表垂直方向比较大的预测框。通过上图可知,即使同一个object query,在不同图像上得到的预测框其位置和大小都是不固定的,所以说没有特定模式,而这使得object query难以优化。


为解决上述问题:

  • 本文基于anchor point(在CNN-based检测算法中被广泛使用)设计object query,每个object query关注anchor point附近的目标;
  • 本文object query的设计可以预测一个位置的多个目标;
  • 设计了一种注意力变体,减少显存占用。

论文的贡献或方法都可以转化成相应的问题,然后从文中逐一寻找答案,寻找答案的过程也是理解论文的过程,现在我们可以提出以下问题:

  • anchor point怎么来的?
  • 如何基于anchor point设计object query?
  • 为什么本文object query的设计可以预测一个位置的多个目标?
  • 注意力变体是怎么样的,为什么可以减少显存占用?

在阅读论文时带着问题,有目的的阅读,边阅读边思考,通常效果会好很多,也更容易理解作者想表达的意思。

接下来让我们从文中method部分寻找问题的答案。

一、Method

1. Anchor Points

Q1:anchor point怎么来的?

A1:如上图所示,文中采用两种方式获得anchor point。一种是网格均匀采样,anchor point被固定为图像中均匀的网格点;为另一种是可学习的点,这些点根据满足0~1均匀分布随机初始化并作为可学习参数进行更新,其中可学习点初始化的相关代码如下:

# --snap--
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# --snap--
if self.spatial_prior == "learned":nn.init.uniform_(self.position.weight.data, 0, 1)

有了anchor point,就可以把回归头的输出当作对于anchor point的偏移量(参考了Deformable DETR的做法),将预测框中心点坐标(\widehat{cx},\widehat{cy})加到对应的anchor point上。

对Deformable DETR不了解的朋友可以查看我的博客:Deformable DETR:结合多尺度特征、可变形卷积机制的DETR

作者通过对比实验(如下图所示),采用了可学习anchor point的策略(但综合看起来两者好像差别不显著= = 、)。

2. Attention Formulation

在回答第二个问题之前,我们首先需要了解一下论文中的一些符号表示。论文在该部分讲解了DETR-like检测算法中的注意力机制的建模方式(比较容易理解,不过多赘述),其中涉及的一些符号表示对我们理解文章的后续内容是有帮助的。

注意力机制建模方式如下:

Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V

Q=Q_{f}+Q_{p},K=K_{f}+K_{p},V=V_{f}

其中d_{k}表示维度,f表示内容信息,p表示位置信息。

decoder中包含自注意力交叉注意力

自注意力中K_{f}V_{f}Q_{f}是相同的,K_{p}Q_{p}是相同的,Q_{f}\in \mathbb{R}^{N_{q}\times C}表示decoder前一层的输出,对于decoder第一层而言Q^{init}_{f}\in \mathbb{R}^{N_{q}\times C}可以设置为常数向量,也可以设置为可学习的嵌入表示。query位置部分Q_{p}\in \mathbb{R}^{N_{q}\times C}在DETR中通常用一组可学习的嵌入向量表示,其中N_{q}表示query的数量:

Q_{p}=Embedding(N_{q},C)

交叉注意力的讲解略过,不难理解。

接下来我们可以继续寻找下一个问题的答案。

3. Anchor Points to Object Query 

Q2:如何基于anchor point设计object query?

A2:anchor point可表示为Pos_{q}\in \mathbb{R}^{N_{A}\times 2},其中N_{A}表示点的个数。根据anchor point获得object query只需要确定一种编码方式即可,即Q_{p}=Encode(Pos_{q})。一种很自然的想法就是继续使用位置编码函数进行编码,但作者采用了一个额外的MLP网络对位置编码结果进行微调。

PS:为什么要额外添加一个MLP微调位置编码结果?文中没有进行相应的消融实验,原因未知。

4. Multiple Predictions for Each Anchor Point

Q3:既然作者的想法是说,通过anchor point得到object query,使得每个object query能够关注某个特定的区域。那如果一个位置有多个目标,但这个位置只有一个object query关注这里,即只会有一个预测框,那怎么办?

A3:简单来说,就是让这个地方可以有多个预测框。作者重新回顾了decoder第一层query的内容部分Q^{init}_{f}\in \mathbb{R}^{N_{q}\times C},每个object query只有一种模式(pattern),即Q^{i}_{f}\in \mathbb{R}^{1\times C}。为了使得一个anchor point可以预测多个目标,作者将多模式嵌入(multiple pattern embedding)整合到了每个object query中,以适应一个位置存在多个目标的情况。其中多模式嵌入表示为:

Q^{i}_{j}=Embedding(N_{p},C)

其中N_{p}表示模式的数量(文中N_{p}=3)。

PS:如何理解pattern呢?我个人理解这里的pattern主要指的是预测框的位置和大小。通过增加pattern的数量,可以增加在某个位置预测框的数量,进而实现一个位置多个目标的检测。但具体如何将多个pattern整合到一个object query中,文中没有明确说明,我结合代码看了以下,简单来说就是把pattern embeddings和object query通过reshape和repeat统一到相同维度,再进行相加,相关代码如下:

# ---transformer的init方法---
# 初始化模式嵌入,(3,256)
self.pattern = nn.Embedding(self.num_pattern, d_model)
# 初始化anchor point,(300, 2)
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# ---transformer的forward方法---
# 调整维度(300, 2)-repeat->(bs, 3*300, 2)
if self.spatial_prior == "learned":reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)
# 为每个object query分配3个模式嵌入
# (3, 256)-reshape->(1, 3, 1, 256)-repeat->(bs, 3, 300, 256)-reshape->(bs, 3*300, 256)
tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_position, 1).reshape(bs, self.num_pattern * self.num_position, c)# ---decoder layer的forward方法---
# 对anchor point进行位置编码并使用MLP微调,(bs, 3*300, 2)-positional embed->(bs, 3*300, 256)
query_pos = adapt_pos2d(pos2posemb2d(reference_points))
# 将pattern embeddings和positional embeddings相加
q = k = self.with_pos_embed(tgt, query_pos)

文中还提到,由于平移不变性,所有object query都共享这些模式(个人理解写在下面)。因此进一步可以得到Q^{init}_{f}\in \mathbb{R}^{N_{p}N_{A}\times C}Q_{p}\in \mathbb{R}^{N_{p}N_{A}\times C},即object query的数量N_{q}=N_{p} \times N_{A}

PS:所谓平移不变性是什么意思呢?举个例子,对于一个检测模型来说,无论目标是在图像中间还是边缘,都应该检测到目标。而图像中每个位置都有可能出现多个目标的情况,所以所有object query应该共享这些模式。

模型预测框可视化结果如下图所示:

每一列表示一个object query在所有图像中预测框的中心点分布情况,其中最后一行的黑点表示anchor point,前三行表示每个pattern对应预测框中心点的分布情况,可以看出预测框都是在anchor point附近。

5. Row-Column Decoupled Attention(RCDA)

作者先说明了现有注意力机制的一些缺点:

  • transformer架构计算量较大,会占用较多的显存。
  • Deformable DETR虽然能降低显存,但会导致内存的随机访问,对硬件不友好(好吧,不懂硬件,说啥是啥)。
  • 其他注意力变体作者实验发现不适用于DETR-like的检测器。

所以,作者提出了一种新的注意力机制变体——行列解耦注意力,以降低显存要求,同时能媲美甚至超越DETR中标准的注意力机制。

大致思路跟深度可分离卷积好像差不多,就是对x和y分别进行计算,最后整合起来。具体的没仔细看,算法复杂度、降低内存啥的这类内容本能的排斥(主要是太菜了看不懂)。主要模型相关的内容已经介绍完了,后续有机会再把这部分内容补上吧。

二、实验结果

文中实验结果都比较好理解,后续补充对实验结果的个人思考。

三、总结

最后做个总结(也是回顾),Anchor DETR的主要贡献是:

  • 根据anchor point得到object query,使其具有明确物理意义,每个object query关注特定区域;
  • 针对第一点可能面临的“一个区域多个目标”的挑战,进一步将多个pattern整合到了一个object query,可实现一个位置多个目标的检测;
  • 提出行列解耦注意力机制,在降低显存使用的同时,性能可媲美甚至超过标准注意力机制。

上述改进使得模型收敛速度提高了10倍,性能也有较为显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/58081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探索现代软件开发中的持续集成与持续交付(CI/CD)实践

探索现代软件开发中的持续集成与持续交付(CI/CD)实践 随着软件开发的飞速进步,现代开发团队已经从传统的开发模式向更加自动化和灵活的开发流程转变。持续集成(CI) 与 持续交付(CD) 成为当下主…

【SSM-Day5】SpringMVC

【SSM-Day5】SpringMVC Web->Servlet->Servlet容器MVC档案SpringMVC档案SpringMVC核心操作📢建立连接RequestMapping:实现路由映射Controller/ResponseBody:表示Spring某个类是否可以接收HTTP请求 📢接收请求1. 直接接收一个…

【skywalking 】选择Elasticsearch存储

介绍 skywalking支持 Elasticsearch 和 OpenSearch 作为存储。 OpenSearch 是 ElasticSearch 7.11 的一个分支,但在 Apache 2.0 中获得许可。 OpenSearch 存储与 ElasticSearch 共享相同的配置。为了激活 OpenSearch 作为存储,请将存储提供程序设置为e…

MySQL中的Redo Log、Undo Log和Binlog:深入解析

引言 在数据库管理系统中,日志是保障数据一致性和完整性的关键机制。MySQL作为一种广泛使用的关系型数据库管理系统,提供了多种日志类型来满足不同的需求。本文将详细介绍MySQL中的Redo Log、Undo Log和Binlog,从背景、业务场景、功能、底层…

【QT】Qt窗口(上)

个人主页~ Qt窗口 一、菜单栏二、工具栏三、状态栏四、浮动窗口 Qt窗口是通过QMainWindow类来实现的,我们之前的学习是通过QWidget类实现的 QMainWindow包含一个菜单栏Menu Bar②,多个工具栏Tool Bars③,多个浮动窗口Dock Widgets&#xff0c…

CISC(Complex Instruction Set Computer)和RISC(Reduced Instruction Set Computer)

CISC(Complex Instruction Set Computer)和RISC(Reduced Instruction Set Computer)是两种不同类型的指令集架构(ISA),它们在设计理念、指令复杂性、寻址方式、实现方式以及应用场景上存在显著差…

关闭windows更新方法

在windows更新里选择暂停windows更新 然后按下winr,输入regedit 在注册表里找到 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings\PauseUpdatesExpiryTime 修改时间即可

什么是事件冒泡?如何阻止事件冒泡和浏览器默认事件?

事件冒泡是浏览器事件处理模型中的一个重要概念。当一个事件发生在某个元素上时,它会首先在该元素上触发,然后逐层向上冒泡到其父元素,直到根元素(通常是 document)为止。这意味着如果在一个嵌套的元素上触发了事件&am…

【数据分享】中国汽车市场年鉴(2013-2023)

数据介绍 在这十年里,中国自主品牌汽车迅速崛起。吉利、长城、比亚迪等品牌不断推出具有竞争力的车型,在国内市场乃至全球市场都占据了一席之地。同时,新能源汽车的发展更是如日中天。随着环保意识的提高和政策的大力支持,电动汽车…

PCL库中的算法封装详解

摘要 Point Cloud Library(PCL)是一个广泛应用于三维点云处理的开源库,涵盖了从基础数据结构到高级算法的丰富功能。PCL通过面向对象的设计和模块化的架构,将各种算法封装成独立的类,使得用户能够方便地调用和组合这些…

第十八届联合国世界旅游组织/亚太旅游协会旅游趋势与展望大会在广西桂林开幕

10月19日,第十八届联合国世界旅游组织/亚太旅游协会旅游趋势与展望大会(以下简称“大会”)在广西桂林开幕,来自美国、英国、德国、俄罗斯、柬埔寨等25个国家约120名政府官员、专家学者和旅游业界精英齐聚一堂,围绕“亚洲及太平洋地区旅游业&a…

Git 创建SSH秘钥

1、命令行输入 ssh-keygen -t rsa -b 4096 2、系统提示你“Enter a file in which to save the key”,直接按回车键 3、再提示你输入密码的时候直接按回车键,创建没有密码的SSH密钥 4、密钥对创建后,可以在自己电脑对应的 ~/.ssh 目录下找到…

java的String方法

lastIndexOf() 源码: public int lastIndexOf(String str) {return lastIndexOf(str, length());} lastIndexOf(String str):用于在一个字符串中查找指定字符最后一次出现的位置 subString() 源码: public String substring(int beginIn…

数据库如何保证主键唯一性

数据库保证主键(Primary Key)的唯一性主要通过以下机制实现: 1. **主键约束(PRIMARY KEY Constraint)**: 这是保证主键唯一性的核心机制。在数据库表中,通过定义主键约束,可以确…

MySQL关于DAYOFWEEK和WEEKDAY说明

⭕️前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家(点击跳转到网站)⭕️ 一、MySQL中关于DAYOFWEEK和WEEKDAY DAYOFWEEK和WEEKDAY均是MySQL中的日期实际函数,用于获取当前日期是星期几&#x…

Android13、14特殊权限-应用安装权限适配

Android13、14特殊权限-应用安装权限适配 文章目录 Android13、14特殊权限-应用安装权限适配一、前言二、权限适配三、其他1、特殊权限-应用安装权限适配小结2、dumpsys package查看获取到了应用安装权限3、Android权限系统:应用操作管理类AppOpsManager&#xff08…

达梦与mssql的order by的区别

在单表简单查询时,mssql和dm8都可以通过查询字段名或别名进行order by mssql和dm8,使用字段名进行order by select emp_ID,emp_Name from Employee order by emp_Name mssql和dm8,使用字段别名进行order by select emp_ID,emp_Name as …

【隐私计算篇】全同态加密应用场景案例(隐私云计算中的大模型推理、生物识别等)

1.题外话 最近因为奖项答辩,一直在忙材料准备,过程非常耗费时间和精力,很难有时间来分享。不过这段时间虽然很忙碌,但这期间有很多新的收获,特别是通过与领域内专家的深入交流和评审过程,对密码学和隐私计算…

今日头条APP移动手机端留痕脚本

这两个的脚本目的是什么呢? 很简单,就是批量访问指定用户的首页,在他人访客记录里面留下你的账户信息,可以让对方访问你的头条,概率下会关注你的头条,目的嘛,这个自己细想! 第1个是…

Python实现Android设备录屏功能及停止录屏功能

1、功能概述? 提供源码下载 之前通过ADB命令实现了实时的录屏功能。但是很遗憾,虽然通过adb命令录屏非常方便,但由于权限限制,无法在安卓系统较高的设备上使用。现选择使用另一开源工具来解决这一问题,并记录使用详细…