Anchor DETR:Transformer-Based目标检测的Query设计

写在前面

文中指出之前DETR-like算法存在以下问题:

  • 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。
  • 每个object query 预测的位置没有一个特定的模式(specific mode),即每个object query不会关注特定的区域。

PS:第二点所谓“预测位置没有一个特定模式”这个结论是怎么得出来的呢?作者援引了DETR论文中的一幅图像(如上图所示)进行说明。该图像中每个子图上都有很多点,每个子图代表了一个object query在验证集所有图像上得到的预测框的中心点坐标(经过归一化后的),绿色代表小的预测框,红色代表水平方向比较大的预测框,蓝色代表垂直方向比较大的预测框。通过上图可知,即使同一个object query,在不同图像上得到的预测框其位置和大小都是不固定的,所以说没有特定模式,而这使得object query难以优化。


为解决上述问题:

  • 本文基于anchor point(在CNN-based检测算法中被广泛使用)设计object query,每个object query关注anchor point附近的目标;
  • 本文object query的设计可以预测一个位置的多个目标;
  • 设计了一种注意力变体,减少显存占用。

论文的贡献或方法都可以转化成相应的问题,然后从文中逐一寻找答案,寻找答案的过程也是理解论文的过程,现在我们可以提出以下问题:

  • anchor point怎么来的?
  • 如何基于anchor point设计object query?
  • 为什么本文object query的设计可以预测一个位置的多个目标?
  • 注意力变体是怎么样的,为什么可以减少显存占用?

在阅读论文时带着问题,有目的的阅读,边阅读边思考,通常效果会好很多,也更容易理解作者想表达的意思。

接下来让我们从文中method部分寻找问题的答案。

一、Method

1. Anchor Points

Q1:anchor point怎么来的?

A1:如上图所示,文中采用两种方式获得anchor point。一种是网格均匀采样,anchor point被固定为图像中均匀的网格点;为另一种是可学习的点,这些点根据满足0~1均匀分布随机初始化并作为可学习参数进行更新,其中可学习点初始化的相关代码如下:

# --snap--
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# --snap--
if self.spatial_prior == "learned":nn.init.uniform_(self.position.weight.data, 0, 1)

有了anchor point,就可以把回归头的输出当作对于anchor point的偏移量(参考了Deformable DETR的做法),将预测框中心点坐标(\widehat{cx},\widehat{cy})加到对应的anchor point上。

对Deformable DETR不了解的朋友可以查看我的博客:Deformable DETR:结合多尺度特征、可变形卷积机制的DETR

作者通过对比实验(如下图所示),采用了可学习anchor point的策略(但综合看起来两者好像差别不显著= = 、)。

2. Attention Formulation

在回答第二个问题之前,我们首先需要了解一下论文中的一些符号表示。论文在该部分讲解了DETR-like检测算法中的注意力机制的建模方式(比较容易理解,不过多赘述),其中涉及的一些符号表示对我们理解文章的后续内容是有帮助的。

注意力机制建模方式如下:

Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V

Q=Q_{f}+Q_{p},K=K_{f}+K_{p},V=V_{f}

其中d_{k}表示维度,f表示内容信息,p表示位置信息。

decoder中包含自注意力交叉注意力

自注意力中K_{f}V_{f}Q_{f}是相同的,K_{p}Q_{p}是相同的,Q_{f}\in \mathbb{R}^{N_{q}\times C}表示decoder前一层的输出,对于decoder第一层而言Q^{init}_{f}\in \mathbb{R}^{N_{q}\times C}可以设置为常数向量,也可以设置为可学习的嵌入表示。query位置部分Q_{p}\in \mathbb{R}^{N_{q}\times C}在DETR中通常用一组可学习的嵌入向量表示,其中N_{q}表示query的数量:

Q_{p}=Embedding(N_{q},C)

交叉注意力的讲解略过,不难理解。

接下来我们可以继续寻找下一个问题的答案。

3. Anchor Points to Object Query 

Q2:如何基于anchor point设计object query?

A2:anchor point可表示为Pos_{q}\in \mathbb{R}^{N_{A}\times 2},其中N_{A}表示点的个数。根据anchor point获得object query只需要确定一种编码方式即可,即Q_{p}=Encode(Pos_{q})。一种很自然的想法就是继续使用位置编码函数进行编码,但作者采用了一个额外的MLP网络对位置编码结果进行微调。

PS:为什么要额外添加一个MLP微调位置编码结果?文中没有进行相应的消融实验,原因未知。

4. Multiple Predictions for Each Anchor Point

Q3:既然作者的想法是说,通过anchor point得到object query,使得每个object query能够关注某个特定的区域。那如果一个位置有多个目标,但这个位置只有一个object query关注这里,即只会有一个预测框,那怎么办?

A3:简单来说,就是让这个地方可以有多个预测框。作者重新回顾了decoder第一层query的内容部分Q^{init}_{f}\in \mathbb{R}^{N_{q}\times C},每个object query只有一种模式(pattern),即Q^{i}_{f}\in \mathbb{R}^{1\times C}。为了使得一个anchor point可以预测多个目标,作者将多模式嵌入(multiple pattern embedding)整合到了每个object query中,以适应一个位置存在多个目标的情况。其中多模式嵌入表示为:

Q^{i}_{j}=Embedding(N_{p},C)

其中N_{p}表示模式的数量(文中N_{p}=3)。

PS:如何理解pattern呢?我个人理解这里的pattern主要指的是预测框的位置和大小。通过增加pattern的数量,可以增加在某个位置预测框的数量,进而实现一个位置多个目标的检测。但具体如何将多个pattern整合到一个object query中,文中没有明确说明,我结合代码看了以下,简单来说就是把pattern embeddings和object query通过reshape和repeat统一到相同维度,再进行相加,相关代码如下:

# ---transformer的init方法---
# 初始化模式嵌入,(3,256)
self.pattern = nn.Embedding(self.num_pattern, d_model)
# 初始化anchor point,(300, 2)
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# ---transformer的forward方法---
# 调整维度(300, 2)-repeat->(bs, 3*300, 2)
if self.spatial_prior == "learned":reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)
# 为每个object query分配3个模式嵌入
# (3, 256)-reshape->(1, 3, 1, 256)-repeat->(bs, 3, 300, 256)-reshape->(bs, 3*300, 256)
tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_position, 1).reshape(bs, self.num_pattern * self.num_position, c)# ---decoder layer的forward方法---
# 对anchor point进行位置编码并使用MLP微调,(bs, 3*300, 2)-positional embed->(bs, 3*300, 256)
query_pos = adapt_pos2d(pos2posemb2d(reference_points))
# 将pattern embeddings和positional embeddings相加
q = k = self.with_pos_embed(tgt, query_pos)

文中还提到,由于平移不变性,所有object query都共享这些模式(个人理解写在下面)。因此进一步可以得到Q^{init}_{f}\in \mathbb{R}^{N_{p}N_{A}\times C}Q_{p}\in \mathbb{R}^{N_{p}N_{A}\times C},即object query的数量N_{q}=N_{p} \times N_{A}

PS:所谓平移不变性是什么意思呢?举个例子,对于一个检测模型来说,无论目标是在图像中间还是边缘,都应该检测到目标。而图像中每个位置都有可能出现多个目标的情况,所以所有object query应该共享这些模式。

模型预测框可视化结果如下图所示:

每一列表示一个object query在所有图像中预测框的中心点分布情况,其中最后一行的黑点表示anchor point,前三行表示每个pattern对应预测框中心点的分布情况,可以看出预测框都是在anchor point附近。

5. Row-Column Decoupled Attention(RCDA)

作者先说明了现有注意力机制的一些缺点:

  • transformer架构计算量较大,会占用较多的显存。
  • Deformable DETR虽然能降低显存,但会导致内存的随机访问,对硬件不友好(好吧,不懂硬件,说啥是啥)。
  • 其他注意力变体作者实验发现不适用于DETR-like的检测器。

所以,作者提出了一种新的注意力机制变体——行列解耦注意力,以降低显存要求,同时能媲美甚至超越DETR中标准的注意力机制。

大致思路跟深度可分离卷积好像差不多,就是对x和y分别进行计算,最后整合起来。具体的没仔细看,算法复杂度、降低内存啥的这类内容本能的排斥(主要是太菜了看不懂)。主要模型相关的内容已经介绍完了,后续有机会再把这部分内容补上吧。

二、实验结果

文中实验结果都比较好理解,后续补充对实验结果的个人思考。

三、总结

最后做个总结(也是回顾),Anchor DETR的主要贡献是:

  • 根据anchor point得到object query,使其具有明确物理意义,每个object query关注特定区域;
  • 针对第一点可能面临的“一个区域多个目标”的挑战,进一步将多个pattern整合到了一个object query,可实现一个位置多个目标的检测;
  • 提出行列解耦注意力机制,在降低显存使用的同时,性能可媲美甚至超过标准注意力机制。

上述改进使得模型收敛速度提高了10倍,性能也有较为显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/58081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探索现代软件开发中的持续集成与持续交付(CI/CD)实践

探索现代软件开发中的持续集成与持续交付(CI/CD)实践 随着软件开发的飞速进步,现代开发团队已经从传统的开发模式向更加自动化和灵活的开发流程转变。持续集成(CI) 与 持续交付(CD) 成为当下主…

【SSM-Day5】SpringMVC

【SSM-Day5】SpringMVC Web->Servlet->Servlet容器MVC档案SpringMVC档案SpringMVC核心操作📢建立连接RequestMapping:实现路由映射Controller/ResponseBody:表示Spring某个类是否可以接收HTTP请求 📢接收请求1. 直接接收一个…

【skywalking 】选择Elasticsearch存储

介绍 skywalking支持 Elasticsearch 和 OpenSearch 作为存储。 OpenSearch 是 ElasticSearch 7.11 的一个分支,但在 Apache 2.0 中获得许可。 OpenSearch 存储与 ElasticSearch 共享相同的配置。为了激活 OpenSearch 作为存储,请将存储提供程序设置为e…

【QT】Qt窗口(上)

个人主页~ Qt窗口 一、菜单栏二、工具栏三、状态栏四、浮动窗口 Qt窗口是通过QMainWindow类来实现的,我们之前的学习是通过QWidget类实现的 QMainWindow包含一个菜单栏Menu Bar②,多个工具栏Tool Bars③,多个浮动窗口Dock Widgets&#xff0c…

关闭windows更新方法

在windows更新里选择暂停windows更新 然后按下winr,输入regedit 在注册表里找到 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings\PauseUpdatesExpiryTime 修改时间即可

【数据分享】中国汽车市场年鉴(2013-2023)

数据介绍 在这十年里,中国自主品牌汽车迅速崛起。吉利、长城、比亚迪等品牌不断推出具有竞争力的车型,在国内市场乃至全球市场都占据了一席之地。同时,新能源汽车的发展更是如日中天。随着环保意识的提高和政策的大力支持,电动汽车…

第十八届联合国世界旅游组织/亚太旅游协会旅游趋势与展望大会在广西桂林开幕

10月19日,第十八届联合国世界旅游组织/亚太旅游协会旅游趋势与展望大会(以下简称“大会”)在广西桂林开幕,来自美国、英国、德国、俄罗斯、柬埔寨等25个国家约120名政府官员、专家学者和旅游业界精英齐聚一堂,围绕“亚洲及太平洋地区旅游业&a…

Git 创建SSH秘钥

1、命令行输入 ssh-keygen -t rsa -b 4096 2、系统提示你“Enter a file in which to save the key”,直接按回车键 3、再提示你输入密码的时候直接按回车键,创建没有密码的SSH密钥 4、密钥对创建后,可以在自己电脑对应的 ~/.ssh 目录下找到…

【隐私计算篇】全同态加密应用场景案例(隐私云计算中的大模型推理、生物识别等)

1.题外话 最近因为奖项答辩,一直在忙材料准备,过程非常耗费时间和精力,很难有时间来分享。不过这段时间虽然很忙碌,但这期间有很多新的收获,特别是通过与领域内专家的深入交流和评审过程,对密码学和隐私计算…

今日头条APP移动手机端留痕脚本

这两个的脚本目的是什么呢? 很简单,就是批量访问指定用户的首页,在他人访客记录里面留下你的账户信息,可以让对方访问你的头条,概率下会关注你的头条,目的嘛,这个自己细想! 第1个是…

Python实现Android设备录屏功能及停止录屏功能

1、功能概述? 提供源码下载 之前通过ADB命令实现了实时的录屏功能。但是很遗憾,虽然通过adb命令录屏非常方便,但由于权限限制,无法在安卓系统较高的设备上使用。现选择使用另一开源工具来解决这一问题,并记录使用详细…

java jsoup爬虫如何快速获取到html页面的选择器元素

java jsoup爬虫如何快速获取到html页面的选择器元素 一、打开元素选择器二、选定元素三、定位元素位置四、右键 -> copy ->copySelector五、代码中获取 一、打开元素选择器 在java采用jsoup爬虫中,返回的是html页面而不是json字段,就需要使用jsou…

[C++11] 右值引⽤与移动语义

文章目录 左值和右值左值(Lvalue)右值(Rvalue)区别 左值引⽤和右值引⽤左值引用(Lvalue Reference)右值引用(Rvalue Reference)右值引用的特点 右值引用延长生命周期右值引⽤和移动语…

传输层UDP

再谈端口号 端口号:标识了主机上进行通信的不同的应用程序 在TCP/IP 协议中我们用“源IP”"源端口号" “目的IP”“目的端口号” “协议号”五元组来标识一个通信 用netstat -n 查看 查看网络信息,我们有两种命令查看网络通信1.用netsta…

Linux-练习3

题目: 操作过程: 1.建立用户组 shengcan,其id 为 2000 2.建立用户组 caiwu,其id 为 2001 3.建立用户组 jishu,其 id 为 2002 4.建立用户 lee,指定其主组 id 为 shengchan,附加组为 jishu 和…

多GPU训练大语言模型,DDP, ZeRO 和 FSDP

在某些时候,我们可能需要将模型训练扩展到单个 GPU 之外。当模型变得太大无法适应单个 GPU 时,需要使用多 GPU 计算策略。但即使模型适合单个 GPU,使用多个 GPU 来加速训练也是有好处的。即使您正在处理一个小模型,了解如何在多个…

在浏览器中运行 Puppeteer:解锁新能力

Puppeteer,这个强大的浏览器自动化工具,通常在Node.js环境中运行。但你有没有想过,在浏览器本身中运行Puppeteer会是什么样子?这不仅能让我们利用Puppeteer的功能完成更多任务,还能避开Node.js特定的限制。 支持的功…

【Canvas与桌面】文山甲密铺桌面壁纸 1920*1080

【成图】 不加蒙版的部分截图&#xff1a; 加上蒙版的桌面壁纸图&#xff1a; 不加蒙版的桌面壁纸图&#xff1a; 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8&qu…

ts:对象数组的简单使用

ts中对象数组的简单使用 一、主要内容说明二、例子1、源码12、源码1运行效果 三、结语四、定位日期 一、主要内容说明 平常ts创建数组的格式如下&#xff1a; let array:string[]["元素1","元素2","元素3","元素3","元素4"…

Java语言-异常

目录 1.异常的概念与体系结构 1.1 异常的概念 1.2 异常的体系结构 1.3 异常的分类 1.3.1 编译时异常(受查异常) 1.3.2 运行时异常(非受查异常) 2.异常的处理 2.1 防御式编程 2.1.1 LBYL 2.1.2 EAFP 2.2 异常的抛出 2.3 异常的捕获 2.3.1 异常声明throws 2.3.2 …