【论文阅读】DETR 论文逐段精读

【论文阅读】DETR 论文逐段精读

文章目录

  • 【论文阅读】DETR 论文逐段精读
  • 📖DETR 论文精读【论文精读】
    • 🌐前言
    • 📋摘要
    • 📚引言
    • 🧬相关工作
    • 🔍方法
      • 💡目标函数
      • 📜模型结构
      • ⚙️代码
    • 📌实验

参考跟李沐学AI: 精读DETR

📖DETR 论文精读【论文精读】


🌐前言

目标检测领域:从目标检测开始火到 detr 都很少有端到端的方法,大部分方法最后至少需要后处理操作(NMS, non-maximum suppression 非极大值抑制)。有了 NMS,模型调参就会很复杂,而且即使训练好了一个模型,部署起来也非常困难(NMS 不是所有硬件都支持)。

📋摘要

贡献:把目标检测做成一个端到端的框架,把之前特别依赖人的先验知识的部分删掉了(NMS 部分、anchor)。

DETR提出

  • 新的目标函数,通过二分图匹配的方式,强制模型输出一组独一无二的预测(没有那么多冗余框,每个物体理想状态下就会生成一个框)
  • 使用 encoder-decoder 的架构

两个小贡献:

  1. decoder 还有另外一个输入 learned object query,类似 anchor 的意思
    (给定这些object query之后,detr就可以把learned object query和全局图像信息结合一起,通过不同的做注意力操作,从而让模型直接输出最后的一组预测框)
  2. 想法&&实效性:并行比串行更合适

DETR 的好处:

  1. 简单性:想法上简单,不需要一个特殊的 library,只要硬件支持 transformer 或 CNN,就一定支持 detr
  2. 性能:在 coco 数据集上,detr 和一个训练非常好的 faster RCNN 基线网络取得了差不多的效果,模型内存和速度也和 faster RCNN 差不多
  3. 想法好,解决了目标检测领域很多痛点,写作好
  4. 别的任务:全景分割任务上 detr 效果很好,detr 能够非常简单拓展到其他任务上

📚引言


DETR 流程(训练)

  1. CNN 提特征
  2. 特征拉直,送到 encoder-decoder 中,encoder 作用:进一步学习全局信息,为近下来的 decoder,也就是最后出预测框做铺垫。
  3. decoder 生成框的输出,当你有了图像特征之后,还会有一个 object query(限定了你要出多少框),通过 query 和特征在 decoder 里进行自注意力操作,得到输出的框(文中是100,无论是什么图片都会预测100个框)
  4. loss :二分图匹配,计算100个预测框和2个 GT 框的 matching loss,决定100个预测框哪两个是独一无二对应到红黄色的 GT 框,匹配的框去算目标检测的 loss

推理
1、2、3一致,第四步 loss 不需要,直接在最后的输出上用一个阈值卡一个输出的置信度,置信度比较大(>0.7的)保留,置信度小于0.7的当做背景物体。

🧬相关工作

让 DETR 成功主要原因:transformer

🔍方法

分两块:1、基于集合的目标函数怎么做,作者如何通过二分图匹配把预测的框和 GT 框连接在一起,算得目标函数 2、detr 具体模型架构

💡目标函数

DETR模型最后输出是一个固定集合,无论图片是什么,最后都会输出 n 个(本文 n=100)

问题:detr 每次都会出 100 个输出,但是实际上一个图片的 GT 的 bounding box 可能只有几个,如何匹配?如何计算 loss?怎么知道哪个预测框对应 GT 框?
匈牙利算法是解决该问题的一个知名且高效的算法,能够以较低的复杂度得到唯一的最优解。
在 scipy 库中,已经封装好了匈牙利算法,只需要将成本矩阵(cost matrix)输入进去就能够得到最优的排列。在 DETR 的官方代码中,也是调用的这个函数进行匹配(from scipy.optimize import linear_sum_assignment)。
从N个预测框中,选出与M个GT Box最匹配的预测框,也可以转化为二分图匹配问题,这里需要填入矩阵的“成本”,就是每个预测框和GT Box的损失。对于目标检测问题,损失就是分类损失和边框损失组成。

所以整个步骤就是:

  • 遍历所有的预测框和 GT Box,计算其 loss。
  • 将 loss 构建为 cost matrix,然后用 scipy 的 linear_sum_assignment(匈牙利算法)求出最优解,即找到每个 GT Box 最匹配的那个预测框。
  • 计算最优的预测框和 GT Box 的损失。(分类+回归)

但是在 DETR 中,损失函数有两点小改动:

  • 去掉分类损失中的 log
  • 回归损失为 L1 loss+GIOU

📜模型结构

下面参考官网的一个 demo,以输入尺寸3×800×1066为例进行前向过程:

  • CNN 提取特征([800,1066,3]→[25,34,256]
    backbone 为 ResNet-50,最后一个 stage 输出特征图为 25×34×2048(32 倍下采样),然后用 1×1 的卷积将通道数降为 256;
  • Transformer encoder 计算自注意力([25,34,256]→[850,256]
    将上一步的特征拉直为 850×256,并加上同样维度的位置编码(Transformer 本身没有位置信息),然后输入的 Transformer encoder 进行自注意力计算,最终输出维度还是 850×256;
  • Transformer decoder 解码,生成预测框
    decoder 输入除了 encoder 部分最终输出的图像特征,还有前面提到的 learned object query,其维度为 100×256。在解码时,learned object query 和全局图像特征不停地做 across attention,最终输出 100×256 的自注意力结果。
    这里的 object query 即相当于之前的 anchor/proposal,是一个硬性条件,告诉模型最后只得到 100 个输出。然后用这 100 个输出接 FFN 得到分类损失和回归损失。
  • 使用检测头输出预测框
    检测头就是目标检测中常用的全连接层(FFN),输出 100 个预测框( h x c e n t e r , y c e n t e r , w , h h x_{center}, y_{center}, w, h hxcenter,ycenter,w,h )和对应的类别。
  • 使用二分图匹配方式输出最终的预测框,然后计算预测框和真实框的损失,梯度回传,更新网络。

除此之外还有部分细节:

  • Transformer-encode/decoder 都有 6层
  • 除第一层外,每层 Transformer encoder 里都会先计算 object query 的 self-attention,主要是为了移除冗余框。这些 query 交互之后,大概就知道每个 query 会出哪种框,互相之间不会再重复(见实验)。
  • decoder 加了 auxiliary loss,即每层 decoder 输出的 100×256 维的结果,都加了 FFN 得到输出,然后去计算 loss,这样模型收敛更快。(每层 FFN 共享参数)

⚙️代码

import torch
from torch import nn
from torchvision.models import resnet50class DETR(nn.Module):def __init__(self, num_classes, hidden_dim, nheads,num_encoder_layers, num_decoder_layers):super().__init__()# We take only convolutional layers from ResNet-50 modelself.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])self.conv = nn.Conv2d(2048, hidden_dim, 1) # 1×1卷积层将2048维特征降到256维self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # 类别FFNself.linear_bbox = nn.Linear(hidden_dim, 4)                # 回归FFNself.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # object query# 下面两个是位置编码self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs):x = self.backbone(inputs)h = self.conv(x)H, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1) # 位置编码h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))return self.linear_class(h), self.linear_bbox(h).sigmoid()detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

📌实验

  • 最上面一部分是 Detectron 2 实现的 Faster RCNN ,但是本文中作者使用了很多 trick
  • 中间部分是作者使用了 GIoU loss、更强的数据增强策略、更长的训练时间来把上面三个模型重新训练了一次,这样更显公平。重新训练的模型以+表示,参数量等这些是一样的,但是普偏提了两个点
  • 下面部分是 DETR 模型,可以看到参数量、GFLOPS 更小,但是推理更慢。模型比 Faster RCNN 精度高一点,主要是大物体检测提升 6 个点 AP,小物体相比降低了 4个点左右

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/788426.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu-server部署hive-part4-部署hive

参照 https://blog.csdn.net/qq_41946216/article/details/134345137 操作系统版本:ubuntu-server-22.04.3 虚拟机:virtualbox7.0 部署hive 下载上传 下载地址 http://archive.apache.org/dist/hive/ apache-hive-3.1.3-bin.tar.gz 以root用户上传至…

多层PCB内部长啥样?

硬件工程师刚接触多层PCB的时候,很容易看晕。动辄十层八层的,线路像蜘蛛网一样。 画了几张多层PCB电路板内部结构图,用立体图形展示各种叠层结构的PCB图内部架构。 高密度互联板(HDI)的核心 在过孔 多层PCB的线路加工,和单层双…

Transformer - Positional Encoding 位置编码 代码实现

Transformer - Positional Encoding 位置编码 代码实现 flyfish import torch import torch.nn as nn import torch.nn.functional as F import os import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len5000):super(PositionalEnco…

深度学习理论基础(六)注意力机制

目录 深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息&…

http: server gave HTTP response to HTTPS client 分析一下这个问题如何解决中文告诉我详细的解决方案

这个错误信息表明 Docker 客户端在尝试通过 HTTPS 协议连接到 Docker 仓库时,但是服务器却返回了一个 HTTP 响应。这通常意味着 Docker 仓库没有正确配置为使用 HTTPS,或者客户端没有正确配置以信任仓库的 SSL 证书。以下是几种可能的解决方案&#xff1…

半导体制程离子注入注入的是哪些离子

离子注入是一种低温过程 通过该过程将一种元素的离子加速进入固体靶材,从而改变靶材的物理、化学或电学性质。离子注入用于半导体器件制造和金属精加工以及材料科学研究。如果离子停止并保留在目标中,则它们可以改变目标的元素成分(如果离子…

6 个典型的Java 设计模式应用场景题

单例模式(Singleton) 场景: 在一个Web服务中,数据库连接池应当在整个应用生命周期中只创建一次,以减少资源消耗和提升性能。使用单例模式确保数据库连接池的唯一实例。 代码实现: import java.sql.Connection; import java.sql.SQLException;public class DatabaseConne…

上位机图像处理和嵌入式模块部署(qmacviusal边缘宽度测量)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面有一篇文章,我们了解了测量标定是怎么做的。即,我们需要提前知道测量的方向,灰度的方向,实际的…

“省钱有道”的太平鸟,如何真正“高飞”?

衣食住行产业中,服装品类消费弹性较大、可选属性较强,其发展可以显著反映当前的经济温度。 根据国家统计局数据,2023年1-12月,我国限额以上单位服装类商品零售额累计10352.9亿元,同比增长15.4%,增速比2022…

Python框架下的qt设计之JSON格式化转换小程序

JSON转换小程序 代码展示: 主程序代码: from PyQt6.QtWidgets import (QApplication, QDialog, QMessageBox )import sys import jsonclass MyJsonFormatter(jsonui.Ui_jsonFormatter,QDialog): # jsonui是我qt界面py文件名def __init__(self):super()…

【HTML】注册页面制作 案例二

(大家好,今天我们将通过案例实战对之前学习过的HTML标签知识进行复习巩固,大家和我一起来吧,加油!💕) 案例复习 通过综合案例,主要复习: 表格标签,可以让内容…

【Go】十七、进程、线程、协程

文章目录 1、进程、线程2、协程3、主死从随4、启动多个协程5、使用WaitGroup控制协程退出6、多协程操作同一个数据7、互斥锁8、读写锁9、deferrecover优化多协程 1、进程、线程 进程作为资源分配的单位,在内存中会为每个进程分配不同的内存区域 一个进程下面有多个…

集合的学习

为什么要有集合:集合会自动扩容 集合不能存基本数据类型(基本数据类型是存放真实的值,而引用数据类型是存放一个地址,这个地址存放在栈区,地址所指向的内容存放在堆区) 数组和集合的对比: 集…

Flutter 开发学习笔记(3):第三方UI库的引入

文章目录 前言初始化程序Icon导入如何导入 Toast消息提示框引入简单封装简单使用 Charts图表导入新建pages文件夹存放page简单代码实现效果 总结 前言 Flutter已经发布了有10年了,生态也算比较完善了。用于安卓程序开发应该是非常的方便。我们这里就接入一些简单的…

golang语言系列:Web框架+路由 之 Gin

云原生学习路线导航页(持续更新中) 本文是golang语言学习系列,本篇对Gin框架的基本使用方法进行学习 1.Gin框架是什么 Gin 是一个 Go (Golang) 编写的轻量级 http web 框架,运行速度非常快,如果你是性能和高效的追求者…

【JavaEE】_Spring MVC项目上传文件

目录 1. 文件上传具体实现 2. 保存文件 1. 文件上传具体实现 .java文件内容如下: package com.example.demo.controller;import com.example.demo.Person; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.Multip…

day4|gin的中间件和路由分组

中间件其实是一个方法, 在.use就可以调用中间件函数 r : gin.Default()v1 : r.Group("v1")//v1 : r.Group("v1").Use()v1.GET("test", func(c *gin.Context) {fmt.Println("get into the test")c.JSON(200, gin.H{"…

特征融合篇 | YOLOv8改进之将Neck网络更换为GFPN(附2种改进方法)

前言:Hello大家好,我是小哥谈。GFPN(Global Feature Pyramid Network)是一种用于目标检测的神经网络架构,它是在Faster R-CNN的基础上进行改进的,旨在提高目标检测的性能和效果。其核心思想是引入全局特征金字塔,通过多尺度的特征融合来提取更丰富的语义信息。具体来说,…

FPGA + 图像处理 (二) RGB转YUV色域、转灰度图及仿真

前言 具体关于色域的知识就不细说了,简单来讲YUV中Y通道可以理解为就是图像的灰度图,因此,将RGB转化为YUV是求彩色图的灰度直方图、进行二值化操作等的基础。 HDMI时序生成模块 这里先介绍一下仿真时用于生成HDMI时序,用这个时…

自贡市第一人民医院:超融合与 SKS 承载 HIS 等核心业务应用,加速国产化与云原生转型

自贡市第一人民医院始建于 1908 年,现已发展成为集医疗、科研、教学、预防、公共卫生应急处置为一体的三级甲等综合公立医院。医院建有“全国综合医院中医药工作示范单位”等 8 个国家级基地,建成高级卒中中心、胸痛中心等 6 个国家级中心。医院日门诊量…