DETR整体模型结构解析

DETR流程

  1. Backbone用卷积神经网络抽特征。最后通过一层1*1卷积转化到d_model维度fm(B,d_model,HW)。

  2. position embedding建立跟fm维度相同的位置编码(B,d_model,HW)。

  3. Transformer Encoder,V为fm,K,Q为fm+position embedding。因为V代表的是图像特征。所以不添加位置编码

  4. Transformer Decoder。生成一个固定大小(query_num)的object query(B,q_num,d_model)比如100个预测框。Decoder输入tgt与object query形状相同。代码中为torch.zero()。第一层selfattention K,V为tgt+query,Q为tgt。第二层Q为上一层输出+query。V为encoder输出,K为encoder输出+position。这里V仍然代表图像特征所以不添加位置编码

  5. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss(L1+IOU))

  6. 二分图匹配与匈牙利算法

    DETR 预测了一组固定大小的 N = 100 个边界框

    将 ground-truth 也扩展成 N = 100 个检测框

    使用一个额外的特殊类标签 ϕ 来表示在未检测到任何对象,或者认为是背景类别。

    这样预测和真实都是两个100 个元素的集合了

    采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

  7. 推理过程不需要二分图匹配,只需要取最大得分框即可

代码详细参考:

transformer 在 CV 中的应用(二) DETR 目标检测网络 -

网络结构

参数说明:B:batchsize大小,C通道数,H,W:CNN输出特征图的高宽。d_model设定的特征维度大小如512。
Q,K,V:自注意力矩阵。l_q:Q矩阵的长度,l_kv:K,V矩阵的长度。KV矩阵的长度必须相同,Q矩阵长度可以跟KV矩阵长度不同
Q矩阵维度:(B,l_q,d_model)
K矩阵维度:(B,l_kv,d_model)
V矩阵维度:(B,l_kv,d_model)
object_query维度(B,q_num,d_model)

Backbone:

img→CNNbackbone→fm特征图(B,C,H,W) → fm特征图输入到transformer中时要再经过一层卷积将通道数转化成d_ model。C→d_model.

position embedding(B,d_model,H*W)。backbone通过CNN提取图像特征,然后通过特征图生成尺度对应的位置编码。

position embedding:

位置编码官方实现了两种,一种是固定位置编码,另一种是自学习位置编码,这里就介绍固定位置编码。

位置编码要考虑 x, y 两个方向,图像中任意一个点 (h, w) 有一个位置,这个位置编码长度为 256 ,前 128 维代表 h 的位置编码, 后 128 维代表 w 的位置编码,把这两个 128 维的向量拼接起来就得到一个 256 维的向量,它代表 (h, w) 的位置编码。位置编码的计算公式如下图所示

在这里插入图片描述

Transformer
DETRtransformer结构图
在这里插入图片描述

接受CNN提取的特征(B,d_model,HW),位置编码(B,d_model,HW),querys(B,query_num,d_model)

encoder:q,k添加位置编码。v代表图像本身特征,不添加位置编码。multi_head_attention跟FFN后都带了两个残差连接。

# post代表norm放在后面
def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)  #q,k增加positionsrc2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)   # 残差src = self.norm1(src)# ffnsrc2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差src = src + self.dropout2(src2)src = self.norm2(src)return src

decoder:

设定一个object queries(num_query,d_model)

有两层multihead self attention

  • 第一层obquery添加到K,Q上作为position embedding

第二层的Q来于decoder,K,V来自于encoder输出。

第二层self attention K添加编码,Q增加object queries。V代表图像特征,不添加任何信息

KV要有相同维度,Q可以跟KV在长度维度上不同,d_model维度相同

softmax(QKt/(d^0.5))V→矩阵乘法Q*Kt:(l_q,d_model)@(d_model_l_kv)→(l_q,l_kv)

再乘以V(l_q,l_kv)@(l_kv,d_model)→(l_q,d_model)

def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):''':param tgt: query_pos rep 2tensor of shape (bs, c, h, w) ->tgt = torch.zeros_like(query_embed):param memory::param tgt_mask::param memory_mask::param tgt_key_padding_mask::param memory_key_padding_mask::param pos::param query_pos::return:'''q = k = self.with_pos_embed(tgt, query_pos)tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgt
  • DETR在计算attention的时候没有使用masked attention,因为将特征图展开成一维以后,所有像素都可能是互相关联的,因此没必要规定mask。

  • object queries的转换过程:object queries是预定义的目标查询的个数,代码中默认为100。它的意义是:根据Encoder编码的特征,Decoder将100个查询转化成100个目标,即最终预测这100个目标的类别和bbox位置。最终预测得到的shape应该为[N, 100, C],N为Batch Num,100个目标,C为预测的100个目标的类别数+1(背景类)以及bbox位置(4个值)

得到预测结果以后,将object predictions和ground truth box之间通过匈牙利算法进行二分匹配:假如有K个目标,那么100个object predictions中就会有K个能够匹配到这K个ground truth,其他的都会和“no object”匹配成功,使其在理论上每个object query都有唯一匹配的目标,不会存在重叠,所以DETR不需要nms进行后处理。

匹配

匈牙利匹配算法

匈牙利匹配算法,二分图匹配算法

scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
#cost_matrix 二分图开销矩阵

https://blog.csdn.net/CV_Autobot/article/details/129096035

https://blog.csdn.net/lemonxiaoxiao/article/details/108672039

query与gt匹配

transformer通过query输出n_q数量的bbox与对应分类置信度

真实框[gt1,gt2,…gtn]

每个bbox与gt之间有一个距离度量。

距离度量由三部分组成:真实类别的置信度得分+边界框的L1loss+边界框的IOUloss

通过匈牙利算法找出距离最小的query_bbox为gt对应的prebbox

loss训练

整体流程

pred输出→100(num_queries)class,100(num_queries)boxes

gt(tagert)→100class,100boxes(包含背景类)

pred,gt→计算相互loss,得到二分图成本矩阵,然后计算匈牙利匹配算法→return匹配上的classes与boxes

匹配成功的框→计算真正的class损失(),box回归损失(GLOUloss)。。

预测框与真实框的差异来自于两方面:1.二分图匹配时带来的差异。2.预测框与真实框之间的差异。

  • 分类损失:交叉熵损失,针对所有predictions。没有匹配到的querybbox应该分类为背景

  • 回归损失:bbox loss采用了L1 loss和giou loss,针对匹配成功的querybbox

  • cardinality 损失,对应函数是 loss_cardinality ; cardinality 损失是计算预测有物体的个数的绝对损失,值是为了记录,不参与反向传播

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/18805.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

非量表题如何进行信效度分析

效度是指设计的题确实在测量某个东西,一般问卷中使用到。如果是量表类的数据,其一般是用因子分析这种方法去验证效度水平,其可通过因子分析探究各测量量表的内部结构情况,分析因子分析得到的内部结构与自己预期的内部结构进行对比…

自学之路Flutter使用Provider进行状态管理

使用前的准备 首先在pubspec.yaml中配置,然后pub get,等待安装完成 我们首先创建两个比较简单的控制器,测试页面跳转之间的数据传递。 import package:flutter/material.dart;void main() {runApp(const MyApp()); }class MyApp extends StatelessWid…

python接口自动化之会话保持

🍦 会话保持-token 有的网站登录需要token鉴权,是啥意思呢,现在有两个接口,一个接口是登录,一个接口是提交订单,那你怎么保证,提交登录这个用户是登录状态呢。登录成功的接接口会在response里面…

大模型预训练结果到底是什么?

近日参加一个线下 AI 交流会议,会上有个非本行业的老师提问:“大家说的训练好的大模型到底是什么?是像 Word 软件一样可以直接使用的程序吗?” 这个问题看似简单,却一下把我问住了。的确,我们这些身处 AI 领…

Kafka原生API使用Java代码-生产者-发送消息

文章目录 1、生产者发送消息1.1、使用EFAK创建主题my_topic31.2、根据kafka官网文档写代码1.3、pom.xml1.4、KafkaProducer1.java1.5、使用EFAK查看主题1.6、再次运行KafkaProducer1.java1.7、再次使用EFAK查看主题 1、生产者发送消息 1.1、使用EFAK创建主题my_topic3 1.2、根…

STM32 OTA需要注意问题

一、OTA设计思路(问题) 1、根据stm32f405 flash分布,最初将flash划分为四个区域,分别是Bootloader、APP1、APP2、参数区,设备上电后,进入Bootloader程序,判断OTA参数,根据参数来确定…

APP逆向之调试的开启

很基础的一个功能设置,大佬轻喷。 背景 在开始进行对APP逆向分析的时候,需要对APP打开调试模式。 打开调试的模式有多种方式可以通过直接改包方式也可以通过借助第三方工具进行打开调试模式。 下面就整理下这个打开调试模式的一些方式。 改包修改模…

Java面试题分享-敏感词替换 java 版本

入职啦最近更新了一些后端笔试、面试题目,大家看看能快速实现吗? 关注 入职啦 微信公众号,每日更新有用的知识,Python,Java,Golang,Rust,javascript 等语言都有 不要再用replaceAll做…

npm获取yarn在安装依赖时 git://github.com/user/xx.git 无法访问解决方法 -- 使用 insteadOf设置git命令别名

今天在使用一个node项目时突然遇到 一个github的拉取异常&#xff0c;一看协议居然是git://xxx 貌似github早就不用这种格式了&#xff0c; 而是使用的gitgithub.com:xxx 这种或者https协议&#xff0c;解决方法&#xff1a; 使用insteadof设置git别名 url.<base>.inste…

DNF手游攻略:开荒必备攻略!

DNF手游马上就要开服了&#xff0c;今天给大家带来最完整的DNF手游入门教程。这篇攻略主要讲述了 DNF手游开服第一天要注意的事项&#xff0c;这是一个新手必备的技能书&#xff0c;可以让你在开服的时候&#xff0c;少走一些弯路&#xff0c;让你更快完成任务&#xff01;废话…

蓝牙Mesh模块多跳大数据量高带宽传输数据方法

随着物联网技术的飞速发展&#xff0c;越来越多的设备需要实现互联互通。蓝牙Mesh网络作为一种低功耗、高覆盖、易于部署的无线通信技术&#xff0c;已经成为物联网领域中的关键技术之一。在蓝牙Mesh网络中&#xff0c;节点之间可以通过多个跳数进行通信&#xff0c;从而实现大…

mysql-日志管理-error.log

日志管理 默认的数据库日志 vim /etc/my.cnf //错误日志 log-error/usr/local/mysql/mysql.log查看数据库日志 tail -f /usr/local/mysql/mysql.log1 错误日志 &#xff1a;启动&#xff0c;停止&#xff0c;关闭失败报错。rpm安装日志位置 /var/log/mysqld.log #默认开启 2 …

【OrangePi AIpro】香橙派 AIpro 为AI而生

产品简介 OrangePi AIpro(8T)&#xff1a;定义边缘智能新纪元的全能开发板 在当今人工智能与物联网技术融合发展的浪潮中&#xff0c;OrangePi AIpro(8T)凭借其强大的硬件配置与全面的接口设计&#xff0c;正逐步成为开发者手中的创新利器。这款开发板不仅代表了香橙派与华为…

最新淘宝死店全自动采集私信筛选脚本,号称日赚500+【采集软件+使用教程】

原理&#xff1a; 利用脚本自动采集长时间未登录店铺&#xff0c;然后脚本自动私信对应的店铺&#xff0c;看看商家是不是不回消息来判断是否是死店&#xff0c;再下单购买死店的产品&#xff0c;超过48小时不发货就可以联系客服获得赔付&#xff0c;一单利润百分之5%-30%&…

配置阿里yum源

配置阿里yum源&#xff08;这个很重要&#xff09;&#xff1a;https://developer.aliyun.com/article/1480470 1.备份系统自带yum源配置文件 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup2.下载ailiyun的yum源配置文件 2.1 CentOS7 wge…

SRS、ZLMediakit音视频流媒体服务器

SRS、ZLMediakit都是做为webrtc的SFU&#xff08;selective forward unit&#xff09; WebRTC 开发实践&#xff1a;为什么你需要 SFU 服务器 https://mp.weixin.qq.com/s?__bizMzAxNTc1MjM0Mw&mid2652213442&idx1&sn33f0393a2dbc2b6a39c613bb238ec145&chksm…

Ansible03-Ansible Playbook剧本详解

目录 写在前面5. Ansible Playbook 剧本5.1 YAML语法5.1.1 语法规定5.1.2 示例5.1.3 YAML数据类型 5.2 Playbook组件5.3 Playbook 案例5.3.1 Playbook语句5.3.2 Playbook1 分发hosts文件5.3.3 Playbook2 分发软件包&#xff0c;安装软件包&#xff0c;启动服务5.3.3.1 任务拆解…

DHCP原理和配置服务

一、DHCP工作原理 DHCP(Dynamic Host Configuration Protocol&#xff0c;动态主机配置协议)由Internet工作任务小组设计开发专门用于为TCP/IP网络中的计算机自动分配TCP/IP参数的协议 使用DHCP的好处 减少管理员的工作量 避免输入错误的可能 避免IP地址冲突 当更改IP地址…

VUE3 学习笔记(9):VUE 插槽的概念、基本应用、传值

在调用子组件时&#xff0c;我们希望把父组件的HTML传给子组件&#xff0c;那么在引用子组件内部进行定义&#xff0c;然后子组件通过slot标签进行接收 基本示例 父 app.vue <!--内容控制--> <template><test><div><p>{{name}}</p><p…

Hikyuu性能实测:A股全市场1915万日K Bar,HDF5首次加载计算6.5秒

因为网友对“百万数据跑两秒"有疑问&#xff0c;经过一番交流后&#xff0c;才发现原来是我没有注明是首次数据加载过程中进行的计算&#xff0c;否则百万数据2秒反而是显的慢了&#xff0c;对此重新更新了相关描述&#xff1a;“AMD 7950x 实测&#xff1a;A股全市场&…