Conditional DETR解读---带anchor的DETR

DETR存在的问题

1.收敛速度慢

2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。

论文主要观点

  • 通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。

  • 论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。

  • detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn

  • Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。

  • 提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛

  • 原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦

在这里插入图片描述

变为

在这里插入图片描述

网络结构

在这里插入图片描述

对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量

bbox回归输出

在这里插入图片描述

其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。

classifier分类输出

在这里插入图片描述

f是decoder的输出,输出每个候选框的类别

decoder Pq生成:

提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。

s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq

在这里插入图片描述

在这里插入图片描述

代码spatial query这一部分的实现:

# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2]  从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
for layer_id, layer in enumerate(self.layers):# 图里的s,代表了query的预测大概范围obj_center = reference_points[..., :2].transpose(0, 1)      # [num_queries, batch_size, 2]# For the first decoder layer, we do not apply transformation over p_s## pos_transformation代表图里的T,表示decoder embedding的特征经过ffn后其实得到的是相对于s的偏移量if layer_id == 0:pos_transformation = 1else:pos_transformation = self.query_scale(output)# get sine embedding for the query vectorquery_sine_embed = gen_sineembed_for_position(obj_center)     # apply transformation# 最终的Pq,代表空间特征信息query_sine_embed = query_sine_embed * pos_transformationoutput = layer(output, memory, tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask,pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,is_first=(layer_id == 0))

decoder中cross attention的实现


# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape# k的位置编码
k_pos = self.ca_kpos_proj(pos)# For the first decoder layer, we concatenate the positional embedding predicted from 
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:q_pos = self.ca_qpos_proj(query_pos)q = q_content + q_posk = k_content + k_pos
else:q = q_contentk = k_contentq = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# decoder embedding cat spatial query
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
# encoder embdeding cat position embedding
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)tgt2 = self.cross_attn(query=q,key=k,value=v, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]               
# ========== End of Cross-Attention =============

head的实现

# hs代表decoder embedding,reference代表s(reference point)
hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):# 回归head hs输出相对于 reference的偏移量,得到检测框tmp = self.bbox_embed(hs[lvl])tmp[..., :2] += reference_before_sigmoidoutputs_coord = tmp.sigmoid()outputs_coords.append(outputs_coord)
outputs_coord = torch.stack(outputs_coords)
#分类head,hs输出分类结果
outputs_class = self.class_embed(hs)

总结思考

实际上conditional DETR有点像transfoermer版本的faster-RCNN。将特征信息与空间信息进行了解耦。reference point像anchor的概念,让网络自己为每个query设定一个anchor范围,从而使得二分匹配更加问题,所以加快了网络的收敛

作者论文解读:https://zhuanlan.zhihu.com/p/401916664
公式解释得更加详细

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845223.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为昇腾910B性能及应用场景

华为昇腾910B是一款高性能的人工智能处理器芯片,以下是其性能参数及应用的相关信息: 一、性能参数 制造工艺:昇腾910B采用了先进的7nm工艺制程,确保了其高效能低功耗的特性。核心数量:集成了数千个处理核心&#xff…

【退役之Java面试经历】第一次面试记录和复盘, Action!

一、简历 两段工作经历,四个项目 二、面试 技术面试 总体还行,关于 redis 和 rabbitmq 以及 spring boot,spring cloud 的知识,回答得还可以。但是,还问到了 “单点登录”、“撰写需求分析文档和操作手册”等盲点。…

代码随想录训练营Day 45|力扣1049. 最后一块石头的重量 II、494. 目标和、474.一和零

1.最后一块石头的重量2 视频讲解&#xff1a;动态规划之背包问题&#xff0c;这个背包最多能装多少&#xff1f;LeetCode&#xff1a;1049.最后一块石头的重量II_哔哩哔哩_bilibili 代码随想录 代码&#xff1a; class Solution { public:int lastStoneWeightII(vector<int…

芋道系统,springboot+vue3+mysql实现地址的存储与显示

1.效果图 2.前端实现&#xff1a; <el-form-item label"地址" prop"entrepriseAddress"><el-cascaderv-model"formData.entrepriseAddress"size"large":options"region"/></el-form-item> //导入组件 im…

k8s中pod如何排错?

排除Kubernetes Pod故障通常涉及一系列步骤&#xff0c;以诊断问题并找到解决方案。以下是一些常见的故障排除方法&#xff1a; 检查Pod状态: 使用kubectl get pods查看Pod的状态。如果Pod没有处于Running状态&#xff0c;查看更详细的信息&#xff0c;使用kubectl describe …

【JMeter接口自动化】第7讲 Jmeter三个重要组件

线程组:是JMeter中最基本的元素之一&#xff0c;用于模拟并发用户访问目标系统。线程组定义了测试计划中的用户数量、用户行为和用户请求之间的关系。 添加方法:测试计划->添加->线程(用户)->线程组 在线程组中&#xff0c;您可以设置以下参数&#xff1a; 线程数&a…

一种改进的形态学滤波算法-以心电信号的基线校正和噪声抑制为例(MATLAB环境)

信号在釆集和传输过程中难免受到噪声源的干扰&#xff0c;反映非线性动力学行为的特征信息有可能被噪声所掩盖。尤其是在混沌振动信号噪声抑制方面&#xff0c;因为混沌信号的高度非线性及宽频特性&#xff0c;噪声和混沌信号往往具有重叠的带宽。传统的时域及频域降噪方法效果…

高通Android 12/Android 13截屏

正常截屏Power音量-键 组合键同时长按&#xff0c;实现截屏逻辑。 PhoneWindowManager #init #interceptScreenshotChord init(Context context, IWindowManager windowManager,WindowManagerFuncs windowManagerFuncs) 2、在PhoneWindowManager中init方法中注册广播 framew…

神经网络-------人工神经网络

一、什么是神经网络和神经元 人工神经网络&#xff08;英语&#xff1a;Artificial Neural Network&#xff0c;ANN&#xff09;&#xff0c;简称 神经网络&#xff08;Neural Network&#xff0c;NN&#xff09;或 类神经网络&#xff0c;是一种模仿生物神经网络&#xff08;…

AI实时免费在线图片工具3:人物换脸、图像编辑

1、FaceAdapter 人物换脸 https://huggingface.co/spaces/FaceAdapter/FaceAdapter 2、InstaDrag https://github.com/magic-research/InstaDrag

cnc编程与nc编程的区别:深入解析两者之间的核心差异

cnc编程与nc编程的区别&#xff1a;深入解析两者之间的核心差异 在制造业中&#xff0c;编程技术是实现高精度、高效率加工的关键环节。CNC编程和NC编程作为两种重要的编程方式&#xff0c;各自具有独特的特点和应用场景。然而&#xff0c;对于初学者或非专业人士来说&#xf…

计算机网络之快重传和快恢复以及TCP连接与释放的握手

快重传和快恢复 快重传可以让发送方尽早得知丢失消息&#xff0c; 当发送消息M1,M2&#xff0c;M3,M4,M5后,假如消息M2丢失&#xff0c;那么按照算法会发送对M2报文前一个报文M1的重复确认&#xff08;M1正常接受到&#xff0c;已经发送了确认),然后之后收到M4,M5,也会发送两…

做项目时,怎么运用 SWOT 分析法进行项目或决策分析?

SWOT分析法是一种常用的战略工具&#xff0c;用于评估项目或决策的优势、劣势、机会和威胁。以下是在项目或决策分析中如何运用SWOT分析法的一般步骤&#xff1a; 步骤1&#xff1a;明确分析的目标 在进行SWOT分析之前&#xff0c;首先要明确分析的目标是什么。你可能想要分析…

element-plus关于表单数据自定义参数校验

element-plus关于表单数据自定义参数校验 核心点&#xff1a; 代码&#xff1a; <el-form-item :prop"tableData[ scope.$index ].score":rules"[{ required: true, message: 得分不能为空, trigger: blur },{ validator: (rule: any, value: any, ca…

【Python】解决Python报错:AttributeError: ‘generator‘ object has no attribute ‘xxx‘

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

网络分层与各层网络协议介绍

一.OSI七层模型 1.OSI&#xff08;Open Systems Interconnection&#xff09;七层模型是由国际标准化组织&#xff08;ISO&#xff09;提出的一种网络通信协议的参考模型&#xff0c;用于标准化网络通信的过程。 OSI模型将网络通信分为七个层次&#xff0c;每个层次负责不同的…

Java集合-List(Collection子接口)及其子类(ArrayList、Vector、LinkedList)

List接口是 Collection接口的子接口。 1、List集合类中数据有序&#xff0c; 即添加顺序和取出顺序有序&#xff0c;而且可以重复。 2、List集合类中每个元素都有其对应的顺序索引&#xff0c;即支持索引。例&#xff0c;list.get(2)&#xff1b;取第三个元素。 3、实现类有很多…

cocosCreator2.x滑动组件全面版一体化(虚拟列表+分页列表+普通列表改良版)

/******************************************* author kL <klk0qq.com>* date 2019/6/6* doc 列表组件.* end******************************************/ /* eslint-disable */ const { ccclass, property, disallowMultiple, menu, executionOrder, requireComponent…

家政预约小程序10公众号集成

目录 1 使用测试号3 工作流配置4 配置关注事件脚本5 注册开放平台6 获取公众号access_token6 实现关注业务逻辑总结 我们本次实战项目构建的相当于一个预约平台&#xff0c;既有家政企业&#xff0c;也有家政服务人员还有用户。不同的人员需要收到不同的消息&#xff0c;比如用…

99.网络游戏逆向分析与漏洞攻防-ui界面的设计-角色信息显示的界面与功能

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…