写在前面
文中指出之前DETR-like算法存在以下问题:
- 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。
- 每个object query 预测的位置没有一个特定的模式(specific mode),即每个object query不会关注特定的区域。
PS:第二点所谓“预测位置没有一个特定模式”这个结论是怎么得出来的呢?作者援引了DETR论文中的一幅图像(如上图所示)进行说明。该图像中每个子图上都有很多点,每个子图代表了一个object query在验证集所有图像上得到的预测框的中心点坐标(经过归一化后的),绿色代表小的预测框,红色代表水平方向比较大的预测框,蓝色代表垂直方向比较大的预测框。通过上图可知,即使同一个object query,在不同图像上得到的预测框其位置和大小都是不固定的,所以说没有特定模式,而这使得object query难以优化。
为解决上述问题:
- 本文基于anchor point(在CNN-based检测算法中被广泛使用)设计object query,每个object query关注anchor point附近的目标;
- 本文object query的设计可以预测一个位置的多个目标;
- 设计了一种注意力变体,减少显存占用。
论文的贡献或方法都可以转化成相应的问题,然后从文中逐一寻找答案,寻找答案的过程也是理解论文的过程,现在我们可以提出以下问题:
- anchor point怎么来的?
- 如何基于anchor point设计object query?
- 为什么本文object query的设计可以预测一个位置的多个目标?
- 注意力变体是怎么样的,为什么可以减少显存占用?
在阅读论文时带着问题,有目的的阅读,边阅读边思考,通常效果会好很多,也更容易理解作者想表达的意思。
接下来让我们从文中method部分寻找问题的答案。
一、Method
1. Anchor Points
Q1:anchor point怎么来的?
A1:如上图所示,文中采用两种方式获得anchor point。一种是网格均匀采样,anchor point被固定为图像中均匀的网格点;为另一种是可学习的点,这些点根据满足0~1均匀分布随机初始化并作为可学习参数进行更新,其中可学习点初始化的相关代码如下:
# --snap--
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# --snap--
if self.spatial_prior == "learned":nn.init.uniform_(self.position.weight.data, 0, 1)
有了anchor point,就可以把回归头的输出当作对于anchor point的偏移量(参考了Deformable DETR的做法),将预测框中心点坐标加到对应的anchor point上。
对Deformable DETR不了解的朋友可以查看我的博客:Deformable DETR:结合多尺度特征、可变形卷积机制的DETR
作者通过对比实验(如下图所示),采用了可学习anchor point的策略(但综合看起来两者好像差别不显著= = 、)。
2. Attention Formulation
在回答第二个问题之前,我们首先需要了解一下论文中的一些符号表示。论文在该部分讲解了DETR-like检测算法中的注意力机制的建模方式(比较容易理解,不过多赘述),其中涉及的一些符号表示对我们理解文章的后续内容是有帮助的。
注意力机制建模方式如下:
其中表示维度,表示内容信息,表示位置信息。
decoder中包含自注意力和交叉注意力。
自注意力中,、和是相同的,和是相同的,表示decoder前一层的输出,对于decoder第一层而言可以设置为常数向量,也可以设置为可学习的嵌入表示。query位置部分在DETR中通常用一组可学习的嵌入向量表示,其中表示query的数量:
交叉注意力的讲解略过,不难理解。
接下来我们可以继续寻找下一个问题的答案。
3. Anchor Points to Object Query
Q2:如何基于anchor point设计object query?
A2:anchor point可表示为,其中表示点的个数。根据anchor point获得object query只需要确定一种编码方式即可,即。一种很自然的想法就是继续使用位置编码函数进行编码,但作者采用了一个额外的MLP网络对位置编码结果进行微调。
PS:为什么要额外添加一个MLP微调位置编码结果?文中没有进行相应的消融实验,原因未知。
4. Multiple Predictions for Each Anchor Point
Q3:既然作者的想法是说,通过anchor point得到object query,使得每个object query能够关注某个特定的区域。那如果一个位置有多个目标,但这个位置只有一个object query关注这里,即只会有一个预测框,那怎么办?
A3:简单来说,就是让这个地方可以有多个预测框。作者重新回顾了decoder第一层query的内容部分,每个object query只有一种模式(pattern),即。为了使得一个anchor point可以预测多个目标,作者将多模式嵌入(multiple pattern embedding)整合到了每个object query中,以适应一个位置存在多个目标的情况。其中多模式嵌入表示为:
其中表示模式的数量(文中)。
PS:如何理解pattern呢?我个人理解这里的pattern主要指的是预测框的位置和大小。通过增加pattern的数量,可以增加在某个位置预测框的数量,进而实现一个位置多个目标的检测。但具体如何将多个pattern整合到一个object query中,文中没有明确说明,我结合代码看了以下,简单来说就是把pattern embeddings和object query通过reshape和repeat统一到相同维度,再进行相加,相关代码如下:
# ---transformer的init方法---
# 初始化模式嵌入,(3,256)
self.pattern = nn.Embedding(self.num_pattern, d_model)
# 初始化anchor point,(300, 2)
if self.spatial_prior == "learned":self.position = nn.Embedding(self.num_position, 2)# ---transformer的forward方法---
# 调整维度(300, 2)-repeat->(bs, 3*300, 2)
if self.spatial_prior == "learned":reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)
# 为每个object query分配3个模式嵌入
# (3, 256)-reshape->(1, 3, 1, 256)-repeat->(bs, 3, 300, 256)-reshape->(bs, 3*300, 256)
tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_position, 1).reshape(bs, self.num_pattern * self.num_position, c)# ---decoder layer的forward方法---
# 对anchor point进行位置编码并使用MLP微调,(bs, 3*300, 2)-positional embed->(bs, 3*300, 256)
query_pos = adapt_pos2d(pos2posemb2d(reference_points))
# 将pattern embeddings和positional embeddings相加
q = k = self.with_pos_embed(tgt, query_pos)
文中还提到,由于平移不变性,所有object query都共享这些模式(个人理解写在下面)。因此进一步可以得到和,即object query的数量。
PS:所谓平移不变性是什么意思呢?举个例子,对于一个检测模型来说,无论目标是在图像中间还是边缘,都应该检测到目标。而图像中每个位置都有可能出现多个目标的情况,所以所有object query应该共享这些模式。
模型预测框可视化结果如下图所示:
每一列表示一个object query在所有图像中预测框的中心点分布情况,其中最后一行的黑点表示anchor point,前三行表示每个pattern对应预测框中心点的分布情况,可以看出预测框都是在anchor point附近。
5. Row-Column Decoupled Attention(RCDA)
作者先说明了现有注意力机制的一些缺点:
- transformer架构计算量较大,会占用较多的显存。
- Deformable DETR虽然能降低显存,但会导致内存的随机访问,对硬件不友好(好吧,不懂硬件,说啥是啥)。
- 其他注意力变体作者实验发现不适用于DETR-like的检测器。
所以,作者提出了一种新的注意力机制变体——行列解耦注意力,以降低显存要求,同时能媲美甚至超越DETR中标准的注意力机制。
大致思路跟深度可分离卷积好像差不多,就是对x和y分别进行计算,最后整合起来。具体的没仔细看,算法复杂度、降低内存啥的这类内容本能的排斥(主要是太菜了看不懂)。主要模型相关的内容已经介绍完了,后续有机会再把这部分内容补上吧。
二、实验结果
文中实验结果都比较好理解,后续补充对实验结果的个人思考。
三、总结
最后做个总结(也是回顾),Anchor DETR的主要贡献是:
- 根据anchor point得到object query,使其具有明确物理意义,每个object query关注特定区域;
- 针对第一点可能面临的“一个区域多个目标”的挑战,进一步将多个pattern整合到了一个object query,可实现一个位置多个目标的检测;
- 提出行列解耦注意力机制,在降低显存使用的同时,性能可媲美甚至超过标准注意力机制。
上述改进使得模型收敛速度提高了10倍,性能也有较为显著提升。