DETR 论文精读【论文精读】
这一次朱毅博士给大家精读的论文是 DETR
,是目标检测领域里程碑式的一个工作,文章收录于 ECCV20
。DETR
是 Detection Transformer
的缩写,作者使用 Transformer
简化了目标检测流程,不再需要进行 NMS(非极大值抑制)
操作,而是直接将目标预测看作是集合预测问题
。
这篇论文在2020年5月第一次挂到 arxiv
上的时候,就引起了计算机视觉圈的强烈反响,网上很多人都觉得这篇论文应该是 ECCV20
的最佳论文。从20年5月到22年5月两年的时间 DETR
的引用次数就已经超过了2000,而且它的官方代码也有将近 9000个star。如果跟去年大火的 CLIP
相比,CLIP
的官方代码到现在也只有8000个star,可见 DETR
的受欢迎程度。
DETR
论文链接为:https://export.arxiv.org/pdf/2005.12872.pdf
DETR
代码链接为:https://github.com/facebookresearch/detr
0. 标题、作者、摘要
首先是论文标题
,论文标题意思是:使用 Transformers
进行端到端
的目标检测。
第一个关键词是 End-to-End
即端到端的目标检测,另外一个就是使用了目前大火的 Transformer
来做目标检测。
-
对目标检测了解不多的同学来说,可能并不能体会到这个端到端的意义究竟有多大?事实上呢,从深度学习开始火一直到
DETR
,目标检测领域都很少有端到端的方法。大部分方法至少最后还需要一个后处理的操作
,也就是nms
这个操作。不论是proposal based
的方法还是anchor based
的方法,还是non anchor based
的方法,最后都会生成很多的预测框,如何去除这些冗余的框,就是nms
要做的事情,而因为有了 nms 的存在,这个模型在调参上就比较复杂。而且即使训练好了一个模型,部署起来也非常困难,因为nms
这个操作不是所有硬件都支持的。所以说一个简单的端到端的目标检测系统,是大家一直以来梦寐以求的,而DETR
就解决了以上说的这些痛点。 -
它既不需要
proposal
也不需要anchor
,直接利用Transformer
这种能全局建模的能力,把目标检测看成了一个集合预测的问题
,而且也因为有了这种全局建模的能力
,所以DETR
不会输出那么多冗余的框,它最后出什么结果就是什么结果。而不需要再用nms
去做这个后处理了,一下就让模型的训练和部署都简单了不少。作者其实在他们的官方代码里也写到,他们的目的就是不想让大家一直觉得目标检测是一个比图像分类难很多的任务。它们其实都可以用一种简单的优雅的框架去做出来,而不像之前那些目标检测框架需要很多的人工干预,很多的先验知识,而且还需要很多复杂的库,或者普通的硬件不支持的一些算子。
文章作者
全部来自 Facebook AI
,一作是 NYU
的博士生,这应该是在 Facebook
实习的时候做出来的工作。二作 是 PyTorch
最重要的维护者之一,五作是第一个提出全景分割这个任务的,就是 panoptic segmentation
,在分割领域非常有经验,因此本文在最后做了一个全景分割的拓展实验。
下面是论文摘要
,摘要总共有9句话。
- 第1句话交代了这篇文章干了什么事情,这篇文章把
目标检测任务直接看成是一个集合预测的问题
。因为本来的任务就是给定一个图片然后去预测一堆框,每个框不光要知道它的坐标,还要知道这个框里所包含的物体的类别。但是这些框其实是一个集合,对于不同的图片来说,它里面包含的框也是不一样的,也就是说每个图片对应的那个集合也是不一样的,而我们的任务就是说给定一个图片我们要去把这个集合预测出来。听起来这么直接的一个设置之前很少有人把它做work
。 - 第2句话是这篇文章的贡献,把目标检测做成了一个端到端的框架。把之前特别依赖于人的先验知识的部分都给删掉了。比如说最突出的
非极大值抑制
的部分,还有生成anchor
的部分。一旦把这两个部分拿掉之后,就不用费尽心思去设计这种anchor
,而且最后也不会出这么多框,也不会用到nms
,也不会有那么多的超参需要去调,整个网络就变得非常的简单。 - 第3句话具体介绍了
DETR
,本文提出了两个东西,一个就是新的目标函数
,通过这种二分图匹配
的方式能够强制模型去输出一组独一无二的这个预测,意思就说没有那么多冗余的框了,每个物体理想状态下他就会生成一个框。另外一个就是使用了这种tranformer encoder-decoder
的架构。 - 第4句话介绍在这个
transformer encoder-decoder
架构中,在transformers
解码器的时候,另外还有一个输入learned object query
。有点类似于anchor
的意思。在给定了这些object query
之后,DETR
就可以把这个learned object query
和全局的图像信息结合在一起,通过不停的去做注意力操作,从而能够让模型直接输出最后的一组预测框,而且作者这里还强调了一下是in parallel
,就是一起出框。为什么要强调是并行出框呢,其实有两个原因:一个原因是在Tranformer
原始的2017年那篇论文里,decoder
是用在自然语言处理的任务上,所以还有一个掩码机制,采用的是一种自回归的方式就一点一点把这个文本生成出来的。而目标检测任务我们是一股脑就把这些目标检测全都输出出来。而且第二点不论是从想法上,还是从实效性上来说并行都比串行要更合适,因为首先这些框,它是没有一个顺序的,并不是说想检测一个大物体,就要先依赖于检测小物体,或者说检测图片右边的物体就要依赖于图检测图片左边的物体;而没有一个先后的顺序,所以说没法做这种自回归的预测。第二对于视觉,尤其对于检测任务来说,我们肯定是希望它越快越好,越实时性越好。所以并行一起出框,肯定是要比顺序的这种一个一个出框要实效性高很多的。 - 第5句话介绍了
DETR
有什么好处,作者这里把简单性
排到了第一位。这个新的模型从想法上来看非常简单,而且从实践上来看也不需要一个特殊的deep learning library
,只要库或者硬件支持cnn
和transformer
,那就一定能支持DETR
。 - 第6句话作者提到了
DETR
的性能,在COCO
数据集上DETR
和一个训练的非常好的Faster RCNN
的基线网络取得了差不多的这个结果。而且模型内存、速度也和Faster RCNN
差不多。其实光从目标检测的这个性能上来说,DETR
在当时不是最强的,它和当时最好的方法差了将近10个点。 - 第7-8句话作者又试了一下别的任务,在全景分割这个任务上
DETR
的效果非常不错。最后作者又列了一下这个优点,就是DETR
能够非常简单的就能拓展到其他的任务上,比如说全景分割,就是用DETR
在后面加一个专用的分割头就可以了。这个性质其实非常厉害,尤其是对于一个新的方法来说。因为它本来就是挖了一个坑,那这个坑肯定是挖的越大越好,这样接下来才会有更多人来填坑。从这个角度来说性能稍微差一点其实是好事,因为接下来大家才会在这个上面去继续去做。 - 最后作者说他们的这个代码全都在
facebook research
下面这个DETR
的repo
里,感兴趣的同学可以去玩一玩。facebook research
的代码库写的非常好,尤其是本来就是做开源的人写出来的代码库,看起来都是一种享受。
1. 引言
引言第一段
介绍了本文研究的动机。
- 作者上来先说目标检测任务就是对于每一个这个感兴趣的物体去预测一些框和这个物体的类别,所以说白了就是一个
集合预测的问题
。但是现在大多数好用的目标检测器都是用一种间接的方式
去处理这个集合预测的问题。比如说用这种proposal
的方式(RCNN
系列的工作Faster R-CNN
、Mask R-CNN
、Cascade R-CNN
),或者anchor
方式(YOLO
、Focal loss
),还有就是最近的一些non anchor base
的方法,比如说用物体的中心点(centernet
、FCOS
)。这些方法都没有直接的去做几何预测的任务,而是设计了一个代理任务,要么是回归要么是分类,然后去解决目标检测问题。 - 但是所有这些提到的方法,它们的性能很大程度上受限于后处理操作,也就是
nms
操作。因为这些方法都会生成大量的这种冗余的框
,也就是这里说的near duplicate predictions
。接近重复的预测,对于同一个物体会大大小小左左右右的出来很多框,所以最后就得用nms
去把这些框都抑制掉。但是也因为用了anchor
,因为用了nms
,所以导致这些检测器都非常的复杂,而且非常的难以优化,非常的难调参。 - 为了简化这个流程本文提出了一个直接的方式去解决集合预测的问题,从而巧妙的绕过了之前所有的这些代理任务,也就绕过了人工设计的部分,比如说生成
anchor
,比如说使用nms
。 - 作者说这种
端到端
的思想其实已经在很多别的任务里大范围的使用了。而且能使这些任务都变得更加的简单,更加的好用。但是在目标检测领域还没有人这么做过,之前也有一些类似的尝试(比如learnable nms
或者soft nms
)。它们一定程度上简化了目标检测的流程,但是要么就是融入了更多的先验知识,要么就是在这些比较难的benchmark
数据集上取得不了很好的成绩。 - 最后作者总结第一段,作者的目的就是要把这个鸿沟弥补上:即我们不需要使用更多的先验知识,我们就是端到端的一个网络同时还能在这些比较难的数据集上取得更好的结果。
图一是 DETR
整个流程,大致分为以下几步:
- 首先用卷积神经网络抽取特征,然后拿到这个特征之后把它拉直,送给一个
Transformer
的encoder-decoder
。 - 在这里
Transformer encoder
的作用就是去进一步的学习全局的信息
,为接下来的decoder
也是为最后的出预测框做铺垫。本文中用了很多的实验和图来说明为什么要用encoder
,但最直观最简单的理解就是如果使用了encoder
。那每一个点或者说每一个特征就跟这个图片里其他的特征都会有交互了,这样大概就知道哪块是哪个物体,哪块又是另外一个物体,对于同一个物体来说就只应该出一个框而不是出好多个框。这种全局的特征非常有利于去移除这种冗余的框。 - 做完第二步全局建模之后,第三步就是用
Transformer decoder
生成框的输出。这里作者其实没有画object query
,但其实这个object query
挺重要的。当有了图像特征之后,这里还会有一个object query
,这个query
其实就是限定了要出多少个框。通过这个query
和特征不停的去做交互,在这个decoder
里去做自注意力操作从而得到最后的输出的这个框。在论文里作者选择的这个框数是100,是一个固定的值。意味着不论是什么图片最后都会预测出来100个框,那现在问题来了出的这100个框怎么去和Ground Truth
做匹配? 然后去算 loss ? - 第4步是文章最重要的一个贡献,文章把这个问题看成是一个集合预测的问题。最后用
二分图匹配
的方法去算这个loss
,比如这里Ground Truth
其实就只有两个框,在训练的时候通过算这100个预测的框和这两个Ground Truth
框之间的这种matching loss
,从而决定出在这100个预测中哪两个框是独一无二的。一旦决定好这个匹配关系之后,就像普通的目标检测一样去算一个分类的loss
,再算一个bounding box
的loss
。至于那些没有匹配到Ground Truth
的框,也就是剩下的98个框,就会被标记成没有物体也就是所谓的背景类。
总结一下第一步用卷积神经网络抽特征、第二步用 Transformer encoder
去学全局特征帮助后面做检测、第三步
就是用 Transformer decoder
去生成很多的预测框、第四步就是把预测框和 Ground Truth
的框做一个匹配,然后最后在匹配上的这些框里面去算目标检测的 loss
。
这样整个模型就能训练起来了,那推理的时候 DETR
是怎么做的呢?同样的道理,前三步都是一致的,只有第四步不一样。因为训练的时候,需要去算这个二分图匹配的 loss
,但是在做推理的时候 loss
是不需要的,直接在你在最后的输出上用一个阈值去卡一下输出的置信度。文章中置信度大于0.7的预测就会被保留下来,也就是所谓的前景物体。那剩下所有的那些置信度小于0.7的就被当成背景物体。可以看到 DETR
不论是在训练的时候还是在做推理的时候,都没有 anchor
生成的这一步而且也都没有用到 nms
。
介绍完了模型架构,作者在引言最后说了下结果,而且还顺带提了几个细节:
- 第一个结果呢就是在
COCO
这个数据集上,DETR
取得了跟之前一个Faster RCNN
的这个基线网络匹配的结果。就是不论从AP
性能上还是从模型大小和速度上都差不多。 - 而且尤其值得一提的是虽然这个
AP
结果差不多,但是大物体
的AP
和小物体的AP
差的还是非常远的,比如说DETR
,就对大物体表现特别的好。这个应该归功于使用了Transformer
,能做这种全局的建模。不论物体有多大,应该都能检测出来。而不像原来一样,如果使用Anchor
的话就会受限于Anchor
的大小,但是同时DETR
也有缺陷,比如在小物体上效果就不怎么样。但作者觉得这个还好,毕竟DETR
是一个新的框架,也需要时间去进化。像之前的目标检测器也都是通过了多年的进化才达到现在这个水平,比如说YOLO-v1, v2, v3, v3, v4, v5
,还有RCNN,Fast RCNN,Faster RCNN,Mask RCNN,Fpn RCNN,Cascade RCNN
。
作者说接下来肯定是会有后续工作来提高这一点的,就像几年前 FPN
对Faster rcnn
做的一样,通过使用多尺度 特征,能够提升这个小物体的检测。事实上确实如此,只不过这次不是像之前一样,用了一年半的时间才从 Fast RCNN
到 FPN
,这次只用不到半年的时间 Deformable DETR
就出来了。不仅很好的通过多尺度的特征解决了小物体检测问题,同时也解决了作者接下来提到的问题,就是DETR训练太慢
。
DETR
确实训练特别慢,想要达到很好的效果作者训练了500个 epoch
。对于 COCO
来说,一般就训练几十个 epoch
。所以500个 epoch
相当于是十倍于之前的训练时长。这里也是很值得我们学习的一点,就是当你改变了训练 setting
,使你这个方法跟之前的方法都没法去做公平对比的时候,怎样做才能让审稿人相信你的说法,才能让审稿人放过你。我们可以看一看 DETR
是怎么解决的。
最后作者强调了 DETR
不光是做检测的一个方法,它其实是一个框架可以拓展到很多别的任务上。作者说 DETR
的设计理念是为了能够适用于更多的复杂的任务,从而让这些复杂任务都变得简单,而且甚至有可能用一个统一的框架去解决所有的问题。事实上作者也确实做到了,不光是在这篇论文里能验证出 DETR
对全景分割有用,而接下来在很多别的后续工作里也验证了 DETR
的有效性。很快就把 DETR
拓展到了目标追踪、视频里的姿态预测、视频里的语义分割等等各种任务。所以 DETR
真的是遍地开花,也难怪那么多人觉得 DETR
有拿最佳论文的潜质。
2. 相关工作
在相关工作
这一节作者讲了三个部分:
- 作者先讲了下
集合预测
这个问题,以及之前大家一般都是用什么方法去解决这种集合预测的问题。因为视觉用集合预测去解决问题的不多,所以作者觉得有必要来科普铺垫一下。 - 然后第二部分作者就介绍了下
Transformer
,以及parallel decoding
。就是如何不像之前的那些Transformer decoder
那样去做自回归的预测,而是一股脑一口气把预测全都给你返回回来。 - 然后第三部分就是介绍了下
目标检测
之前的一些相关工作,这里就着重讲一下第三部分。
2.3节作者说现在的大部分的目标检测器都是根据已有的一些初始的猜测
然后去做预测。比如说对于 two-stage
的目标检测方法来说初始猜测就是中间的 proposal
;对于 single-stage
的目标检测方法来说初始猜测就是 anchor
或者说是物体的中心点。最近的一篇论文做了详细的比较,发现之前方法的性能跟刚开始的初始猜测非常相关,就是怎么去做后处理,得到最后的预测其实对最后的性能的影响是至关重要的。
作者这里就从两个方面来阐述了这件事情:
- 第一个就是用集合的思想来做,就是
set-based loss
,之前的可学习的NMS
方法或者说关系型网络都可以利用类似于自注意力的方法去处理物体之间的这种联系。从而导致最后能得出独一无二的预测,这样就不需要任何的后处理步骤了。但是这些方法的性能往往都比较低,那为了跟当时那些比较好的方法性能对齐,这些方法往往又用了一些人工干预。比如这里说的用手工设计的场景特征,去帮助模型进行学习但是DETR
的目标是想让目标检测做的尽可能的简单。所以不希望这个过程特别复杂,也不希望用到过多的人工先验的知识,这就是DETR
和之前的方法的区别。 - 第二个作者想讲的就是
Recurrent detector
,就是之前有没有类似的工作用encoder-decoder
的形式去做目标检测。也是有的不光是有做目标检测,而且要做实例分割的工作还是蛮多的。但是这些工作全都是15、16年的时候的工作,那个时候大家去做recurrent detector
全都用的是RNN
,因为用了RNN
所以说肯定是自回归的模型。因此这个时效性就会比较差,而且性能也会比较差。
而 DETR
我们可以看到,不光是利用了 Transformers Encoder
以后能得到更全局的信息,从而目标检测更容易做了。而且用了 Transformers
不带掩码机制的 decoder
, DETR
最后能够一起把目标框全都输出来,从而达到 Parallel Decoding
,所以实效性大大增强了。
但其实当我们看了相关工作之后,我们发现其实让 DETR
工作的最主要的原因还是因为使用了 Transformers
。因为之前大家也试过基于集合的这种目标函数,也利用匈牙利算法去做过这个二分图匹配,但是因为 backbone
出来的特征不够强,所以最后的性能不好。还得依赖一些先验知识或者人工的干预,从而达到更好的性能。那第二个方向也是一样,之前也有人已经用 RNN
去做过这种 encoder-decoder
方式,只不过因为没有用 Transformers
,所以说结果不够好,比不过当时的别的基线模型。 所以说白了,DETR 的成功还是 Transformers 的成功。
3. DETR
接下来我们就来看一下文章的主体部分,DETR
其实主要分了两节:第一节是3.1节,主要说了一下基于集合的目标函数到底是怎么做的,作者是如何通过一个二分图匹配把预测的框和 Ground Truth
的框连接在一起,从而计算目标函数;然后3.2节主要就是讲了 DETR
的具体模型结构,就是图一里说过的四步。
那之所以把目标函数放在前面,而把结构放在后面其实是因为这个结构相对而言还是比较标准的。用一个 CNN
去抽特征,然后再用一个 Transformer encoder-decoder
去做特征强化,最后出这个预测框相对而言还是比较好理解,而且比较标准的,没什么太多可以讲的。就算是 object query
是一个比较新的东西,但就是非常小的一个东西,几句话就解释清楚了。反而是集合目标函数是比较新的东西,正是因为有了这个目标函数所以才能达到一对一的出框方式,所以才能不需要 nms
,所以作者这里就介绍了目标函数。
作者这里说 DETR
最后的输出是一个固定的集合
,不论输入图片是什么,最后都会有 N
个输出,在这篇论文里 N
就是100。也就是说任何一张图片进来,最后都会扔出来100个框,一般而言 N=100
,应该是比一张图片中所包含的物体个数要多很多的。普通的一张图片可能里面就含有几个或者十几个物体,尤其是对于 COCO
数据集而言里面一张图片里包含的最大个体数也没有超过100,所以作者这里把 N
设成100就足够用了。
但是现在问题就来了,DETR
每次都会出100个输出,但是实际上一个图片的 Ground Truth
的 Bounding box
可能只有几个,那如何去做这种匹配,如何去算 loss
,怎么知道哪个预测框对应哪个 Ground Truth
框呢?所以这里作者就把这个问题转化成了一个二分图匹配的问题。
那接下来就来先说一下二分图匹配到底是个什么问题?网上大部分的讲解都是给了这么一个例子:如何分配一些工人去干一些活,从而能让最后的支出最小。就比如说现在有三个工人abc
,然后要去完成三个工作,分别是xyz
,然后因为每个工人各自有各自的长处和短处,所以说他们完成这些工作需要的时间或者需要的工钱也不一样。所以最后就会有这么一个矩阵,每个矩阵里就有他们完成这些任务所需要的时间或者说完成任务所需要的钱。然后这个矩阵就叫 cost matrix
,也就叫损失矩阵。最优二分图匹配的意思就是最后能找到一个唯一解,能够给每个人都去分配他对应最擅长的那项工作,然后完成这三个工作,最后的价钱最低。那其实说的更直白一点,就是用遍历的算法也可以把这个做出来,把所有的排列组合都跑一遍。最后看哪个花的钱最少就选哪个就好了。
但是这样做复杂度肯定就非常高了,因此就有很多的算法来解决这个问题。匈牙利算法
就是其中一个比较有名而且比较高效的算法。其实一般遇到这个问题大家都有成熟的解决方案,大家一般就会用 Scipy
包里提供的一个函数,叫做 linear sum assignment
去完成这个。这个函数的输入其实就是 cost matrix
,只要输入 cost matrix
就能算出来一个最优的排列。在 DETR
这篇论文的代码里也用的是这个函数。
仔细一想,目标检测其实也是一个二分图匹配的问题,可以把 abc
看成是预测的框,然后把 xyz
看成是 Ground Truth
的框, cost matrix
不一定是个正方形,也可以是长方形。总之都是可以扔到函数里去得到一个最优匹配的。
现在的问题对于目标检测来说就是 cost matrix
里面的值应该放什么?从 cost matrix
就可以看出来它里面应该放的是 cost
,也就是损失。那损失是什么,其实就是 loss
,对目标检测来说 loss
包含两个部分:
- 一个就是分类对不对,分类的
loss
; - 另一个就是出框的准确度,框预测的准不准,也就是遍历所有的预测框,然后拿预测的框去和
ground Truth
的框去算这两个 loss。
然后把这个 loss
放到 cost matrix
就可以了,一旦有了 loss
就可以用 Scipy
的 linear sum assignment
,匈牙利算法就得到最后的最优解。作者这里其实也给了一些注释:找最优的匹配跟原来利用人的先验知识去把预测和之前的 proposal
或者 Anchor
做匹配的方式是差不多的,都是一个意思。只不过约束更强就是一定要得到一个一对一的这个匹配关系。而不是像原来一样是一对多的,也就说现在只会有一个框跟 Ground Truth
框是对应的。这样后面才不需要去做后处理 nms
,一旦做完了这个匹配这一步骤,就可以算一个真正的目标函数,然后用这个 loss
去做梯度回传更新模型的参数了。
准确的说最后的目标函数其实还是两个 loss
,一个是分类的 loss
,一个是出框的 loss
。DETR
这里作者做了两个跟普通的目标检测不太一样的地方:
- 第一个就是一般大家在做分类的
loss
的时候都是用log
函数去算的,但是作者为了让loss
和后面loss
大概在同样的取值空间,作者把log
给去掉了,作者发现这样的结果最后会稍微好一些。 - 那第二个改动就是在
bounding box
这块, 之前的工作一般都是用一个L1 loss
就可以了,但是对于DETR
来说用L1 loss
可能就有点问题。因为L1 loss
跟出框的大小有关系,框越大最后算出来loss
就容易越大。之前也说过用了Transformer
这种全局的特征,所以对大物体很友好经常会出大框。那出大框loss
就会大,所以就不利于优化。所以作者在这里不光是使用了一个L1 loss
,还用了一个generalized iou loss
,generalized iou loss
就是一个跟框大小无关的目标函数。作者这里就用了L1 loss
和generalized iou loss
的一个合体。
总之最后的这个目标函数跟普通的目标检测来说也是差不多的,只不过是先去算了一个最优匹配,然后在最优匹配的上面再去算这个 loss
,这个就是 DETR
的基于集合预测的目标函数,从而能做到出的框和 ground truth
能够做到一一匹配而不需要 nms
。
主体方法第二部分就是 DETR
整体网络框架。论文里其实有很多的细节,我们只用看图二就已经基本知道 DETR
在干什么了。图二也就是图一的一个升级版,图一画的很简单就是让大家看清楚 DETR
的 flow
是什么样。图二就把每个步骤画的更清晰了一些。
- 这里输入图片的大小是 3x800x1066,3是
rgb channel
,在检测分割中一般输入图片都会大一些。第一步就是通过卷积网络去得到一些特征,卷积网络的最后一层conv5
的时候会得到一个特征:2048x25x34,25和34就分别是之前的800和1066的1/32,2048就是对应的通道数。然后接下来因为要把特征扔给一个Transformer
,所以作者这里做了一次1x1的降维操作,就把2048变成了256,从卷积神经网络出来的特征维度就是 256x25x34。那接下来因为Transformers
是没有位置信息,所以要给他加入位置编码。在这里面位置编码其实是一个固定的位置编码,维度大小也是 256x25x34,目的很简单就是因为这两个东西是要加在一起的,所以说维度必须得一致。现在这两个东西加到一起之后,其实就是Transformer
的输入了。这时候只需要把这个向量拉直,拉直的意思就是说把 h 和 w 拉直变成一个数值,也就是变成了 850x256。 - 850就是序列长度,256就是
Transformer
的head dimension
。那接下来就跟一个普通的Transformer encoder
或者跟我们之前讲过的VisionTransformer
是一样的,输入是 850x256,不论经过多少个Transformer encoder block
最后的输出还是 850x256。在DET
里作者使用了6个Encoder
,也就说会有6个Transformer Encoder
,这样叠起来那第二步就走完了。 - 第三步就是进入一个
Decoder
,然后去做框的输出。这里面有一个新东西,也就是我们之前反复提高过的object query
,这个object query
其实是一个learnable embedding
。它其实是可以学习的,而且准确的说它是一个learnable
的positional embedding
。它的维度100x256,256是为了和之前的256对应。然后100的意思也就是说告诉模型最后要得到100个输出。其实也可以把它理解成是一种Anchor
的机制或者是一种condition
。在这个Transformer Decoder
里头其实做的是一个cross Attention
,比如说输入object query
是100x256, 然后我还有另外一个输入是从这个图像端拿来的全局特征:850x256, 这时候拿 850x256 和 100x256去反复的做自注意力操作,最后就得到了一个100x256的特征。同样的道理DETR
里也用了6层Decoder
,也就说这里是有6个Transformer Decoder
叠加起来的,每层的输入和输出的维度也都是不变的,始终是100x256进、100x256出。 - 拿到 100x256 的最终特征之后,最后就是要做预测了。就是要在上面加一个检测头,检测头其实是比较标准的。就是加一个
feed forward network
,准确的说就是把特征给这些全连接层,然后全连接层就会做两个预测:一个是物体类别的预测,一个是出框的预测。一旦有了这100个预测框,就可以去和Ground Truth
去做最优匹配,然后用匈牙利算法去算得最后的这个目标函数,梯度反向回传然后来更新这个模型,这样一个端到端的可以学习的DETR
模型就实现了。
模型这里还有很多的细节:
- 在补充材料里作者就强调说在
Transformer Decoder
里在每一个Decoder
都会先做一次object query
的自注意力操作。但其实在第一层Decoder
这里可以不做,是可以省掉。但是后面那些层都不能省掉Object Query
做自注意力的操作。主要就是为了移除冗余框,因为他们之间互相通信之后就知道每个Query
可能得到什么样的一个框,尽量不要去做重复的框。 - 另外一个细节就是在最后算
loss
的时候,作者发现为了让模型收敛的更快或者训练的更稳定,在每个Decoder
的后面加了auxiliary loss
,就是加了很多额外的目标函数。这是一个很常见的trick
,在检测、分割里,尤其是分割里是用的非常常见的。就是说不光在最后一层去算这个loss
,在之前的Decoder
里也算loss
。因为其实每个Decoder
的输出都是 100x256,你都可以在这个输出 100x256上去做ffn
然后得到输出。作者这里就是在每一个Decoder
的后面,这里是在6个Decoder
的后面都加了ffn
,而去得到了目标检测的输出。然后就算了loss
,当然这里面这些ffn
全都是共享参数。
4. 实验
接下来我们就来看一下实验,主要就是看几个表格和图。看一看 DETR
跟 Faster RCNN
比到底是什么水平;以及为什么要用 Transformer encoder
,到底好在哪里。还有就是通过可视化看看 object query
到底学到了什么。
首先来看表一,表一里对比了一下 DETR
和之前的 Faster RCNN
性能上到底如何。表里的上半部分是 detectron2
里的 Faster RCNN
的实现,然而因为在这篇论文里 DETR
还用了一些别的技巧:比如说使用 glou loss
、使用了更多的 Data Augmentation
、训练了更长的时间。作者又在中间把这些同样的模型又重新训练了一遍,就是用这些更好的训练策略去把上面这三个模型又训练了一遍。所以我们可以看到模型都是一个模型,只不过在后面用加号来表示是提升版,所以这就是为什么 GFLOPS 、FPS
还有模型参数量都是一样的,只不过是训练的策略不一样。
但是我们也可以看出来训练的策略非常重要,训练策略改一改性能就提升蛮大的,基本稳健地都能提升两个点。作者这里接下来做对比其实就是表中中间和下面的对比,因为它们的训练设置差不多,所以相对而言比较公平。但毕竟一个是用 cnn
,一个是用 Transformers
,所以说也很难强求他们什么都保持一致。在目前这种情况下其实他们的训练设置已经差不多了。
那如果我们来对比一下,就是 Faster RCNN-DC5+
模型和 DETR-DC5
模型。
- 可以看到虽然
fps
差不多,当然Faster RCNN-DC5+
跑的更快一些,但是DETR-DC5
的flops
会小,187只有320的差不多一半。所以这也就牵扯到另外一个问题,好像gflops
越小,模型就越小跑起来就越快,其实不尽然,gflops
跟这个没有什么关系。 - 那如果从
AP
角度上来看,Faster RCNN-DC5+
只有41,但DETR
有43.3,高了两个多点,还是相当不错的。同样的道理我们再把BackBone
换成Resnet101
,再把fpn
也加上,然后去跟DETR-DC5-R101
去比。会发现他们的这个参数量一样都是60兆,然而gflops
也都是250左右,但是DETR
还是慢了一倍。虽然DETR
跑的慢,但是效果还是高一些的,AP
能达到44.9。 - 作者更希望大家意识到的是这个细分类的比较,就说到底对小物体和大物体而言,
DETR
和Faster RCNN
差在哪里。我们可以看到对于小物体而言Faster RCNN
表现的非常好,比DETR
要高4个AP
。但是对于大物体,DETR
反而比Faster RCNN
高了6个AP
,这个结果就非常有意思。作者这里觉得是因为使用了Transformers
全局建模而且没有用Anchor
,所以说DETR
想检测多大的物体就可以检测多大的物体,对大物体检测比较友好。
但是因为DETR
实在是太简单了,没有使用多尺度的特征,也没有用fpn
,也没有用一个非常复杂的目标检测头,所以说在小物体的处理上就不是很好。
一起来看下作者这里放的第一个可视化图,图3里作者其实是把 Transformer encoder
的自注意力给可视化出来了。比如说在这个图里有几头牛,然后如果我们在这些牛身上点一些点,以这些点作为基准点。然后我们用这个基准点去算这个点和图像里所有其他点的自注意力,然后我们看看自注意力是怎么分布的。会惊奇的发现自注意力已经做的非常非常好了,在这头牛上做自注意力,其实基本已经把这头牛的形状给恢复出来了,甚至已经有一点实例分割出来的那个 mask
的形状了。而且即使对于遮挡这么严重的情况(比如这头小牛在这头大牛身子下面)还是能把两头牛分得相当清楚,所以这也就是使用 Transformer Encoder
的威力。使用 Transformer Encoder
之后,其实图片里的物体就已经分的很开了,这个时候再在上面去做 decoder
,去做目标检测或者分割,相对而言任务就简单很多。
看完了 Transformer
编码器,那我们来看一看 Transformer
解码器。因为它也是一个 Transformer block
,所以也有自注意力可以去可视化。在经过了 Transformer encoder
之后我们为什么还需要一个 Decoder
呢?然后这个 Decoder
到底在学什么呢?下面这个图其实非常有意思,而且把 DETR
的好处展现的是淋漓尽致。
这里把每个物体的自注意力用不同的颜色表示出来了,比如说对于这两头大象小的这头象用的是黄色,大的这头象用的是蓝色。我们可以发现即使在遮挡这么严重的情况下,比如说后面那头大象的蹄子还是可以显示出蓝色,还有这么细的尾巴也是蓝色。而对于小象来说,其实大象的皮肤都是非常接近的,但是还是能把轮廓分的非常清楚。对于右图这个斑马,其实斑马本身背上的花纹就已经很吸引人眼球了,是一个很强的一个特征。而且遮挡又这么严重,但即使是在这种情况下 DETR
还是能准确的分出蓝色斑马的边界还有它的蹄子、绿色斑马的边界和蹄子、还有黄色斑马的轮廓。
所以作者这里说 Transformer Encoder
和 Decoder
一个都不能少。Encoder
在学什么呢,Encoder
在学一个全局的特征。它是尽可能的让物体和物体之间的分得开,但是光分开还是不够的,对于这些头、尾巴极值点、最外围的这些点该怎么办呢? 把这个就交给 Decoder
去做了,因为 Encoder
已经把物体都分好了,所以 Decoder
接下来就可以把所有的注意力都分到去学边缘了,去怎么更好的区分物体以及解决这种遮挡问题。
其实这有点像之前做分割的时候用的 U-net
结构,encoder
去抽一个更有语义的特征,然后 Decoder
一边一点一点把图片大小恢复出来,另外一边把更多的细节加进去,从而能达成最后的这个分割效果。总之 DETR
用这种 encoder-decoder
的方式,目的其实跟之前 cnn
里面用 encoder-decoder
的目的差不多是一致的。最后达成的这个效果也是差不多一致的,只不过换成 Transformer
之后效果就更拔群了。
另外一个比较有意思的图就是 object Query
的可视化,作者这里把 coco validation set
里所有的图片得到最后的所有输出框全都可视化了出来。虽然 n 是100但这里只画了20个正方形,每一个正方形就代表一个 object query
,作者其实就想看看 object query
到底学了什么。从图7可以看出来,图里这些绿色的点代表小的 bounding box
,红色的点代表大的横向的 bounding box
,蓝色这些点代表竖向的大的 bonnding box
。
看完这些分布之后就会发现,原来 object query
其实跟 Anchor
还是有一些像的,anchor
只不过是提前先去定义好的一些 bounding box
,最后把预测跟这些提前定好的 bounding box
去做对比。而 object query
其实是一个可以学习的东西,比如说第一个 object query
,他学到最后就相当于每次给他一个图片,第一个 object query
都会去问图片的左下角有没有看到一些小的物体。
总之这100个 Query
就相当是100个不停的问问题的人一样,每个人都有自己问问题的方式,每当一个图片来的时候,他们就按照各自自己的方式去问图片各种各样的问题。如果找到了合适的答案,就把答案返回给你,而这个答案其实就是对应的那个 bounding box
。如果没找到就返回给你什么也没有,然后作者这里还发现所有的这些 Object query
中间都有红色的竖线,意思就是说每个 object query
都会去问这个图片中间有没有一个大的横向物体。之所以学出来了这个模式是因为 COCO
这个数据集的问题,因为 COCO
这个数据集往往在中心都有一个大物体,就跟 ImageNet
一样经常有一个大物体是占据整个空间的。从这个角度讲其实 DETR
又是一个非常深度学习的模型,就是说很简单什么参数也不用调,一切都是端到端。在训练之前你也不知道 object Query
是个什么东西,反正我给你数据你就学去吧,结果最后发现学的还挺好,object query
学的还有模有样直接就把生成 Anchor
这一步呢给取代了。
5. 总结
最后来总结一下这篇论文:
- 作者说他们提出了
DETR
,一个全新的目标检测框架。主要是利用Transformers
而且用了二分图匹配
,最后导致这个框架是一个端到端可以学习的网络。 - 然后在
COCO
数据集上DETR
跟之前的Faster RCNN
基线模型打成了平手,然后在另外一个任务全景分割上反而取得了更好的结果,而且因为它的这种简单应用性还有很大的潜力去应用到别的任务上。 - 最后因为在
COCO
数据集上只有44的AP
,实在是有点低,所以作者不去强调这件事情,反而去强调了另外一件事。就是DETR
在什么上面做的最好,DETR
是在大物体上做的最好。所以作者这里就强调了一下DETR
检测大物体的效果特别好,原因就是因为用了自注意力所带来的全局信息。 - 因为这篇文章是一个比较新的想法,也是一个比较新的结构,肯定不可能是完美的。所以说与其让大家去挑刺,作者这里说还不如我自己就把缺点展示一下。缺点就是训练时间太长了,而且因为用
Transformers
,所以可能不好优化,然后性能也不是很好,尤其是在小物体上的性能。现在的这些特别好的检测器全都是用了好几年的提升才把这些问题一一解决的,那我们觉得因为我们这个框架特别好,特别容易用,肯定会有很多后续工作立马就DETR
的这些问题成功的解决了。
6. DETR预训练demo
这里使用作者提供的 notebook
演示 DETR
的效果,地址为:https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb。
导入库文件:
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False);
定义 DETR
模型:
class DETRdemo(nn.Module):"""Demo DETR implementation.Demo implementation of DETR in minimal number of lines, with thefollowing differences wrt DETR in the paper:* learned positional encoding (instead of sine)* positional encoding is passed at input (instead of attention)* fc bbox predictor (instead of MLP)The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.Only batch size 1 supported."""def __init__(self, num_classes, hidden_dim=256, nheads=8,num_encoder_layers=6, num_decoder_layers=6):super().__init__()# create ResNet-50 backboneself.backbone = resnet50()del self.backbone.fc# create conversion layerself.conv = nn.Conv2d(2048, hidden_dim, 1)# create a default PyTorch transformerself.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)# prediction heads, one extra class for predicting non-empty slots# note that in baseline DETR linear_bbox layer is 3-layer MLPself.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)# output positional encodings (object queries)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))# spatial positional encodings# note that in baseline DETR we use sine positional encodingsself.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):# propagate inputs through ResNet-50 up to avg-pool layerx = self.backbone.conv1(inputs)x = self.backbone.bn1(x)x = self.backbone.relu(x)x = self.backbone.maxpool(x)x = self.backbone.layer1(x)x = self.backbone.layer2(x)x = self.backbone.layer3(x)x = self.backbone.layer4(x)# convert from 2048 to 256 feature planes for the transformerh = self.conv(x)# construct positional encodingsH, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)# propagate through the transformerh = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1)).transpose(0, 1)# finally project transformer outputs to class labels and bounding boxesreturn {'pred_logits': self.linear_class(h), 'pred_boxes': self.linear_bbox(h).sigmoid()}
加载预训练模型:
detr = DETRdemo(num_classes=91)
state_dict = torch.hub.load_state_dict_from_url(url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
detr.eval();
COCO
数据集类别映射:
# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
一些辅助函数:
# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return b
构建检测函数:
def detect(im, model, transform):# mean-std normalize the input image (batch-size: 1)img = transform(im).unsqueeze(0)# demo model only support by default images with aspect ratio between 0.5 and 2# if you want to use images with an aspect ratio outside this range# rescale your image so that the maximum size is at most 1333 for best resultsassert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'# propagate through the modeloutputs = model(img)# keep only predictions with 0.7+ confidenceprobas = outputs['pred_logits'].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > 0.7# convert boxes from [0; 1] to image scalesbboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)return probas[keep], bboxes_scaled
检测图片:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)scores, boxes = detect(im, detr, transform)
可视化预测:
def plot_results(pil_img, prob, boxes):plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}: {p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.show()plot_results(im, scores, boxes)