DEtection TRansformer (DETR) 与 You Only Look Once (YOLO)

曾经想过计算机如何分析图像,识别并定位其中的物体吗?这正是计算机视觉领域的目标检测所完成的任务。DEtection TRansformer(DETR)和You Only Look Once(YOLO)是目标检测的两种重要方法。YOLO已经赢得了作为实时目标检测和跟踪问题的首选模型的声誉。与此同时,DETR是一种由变换器技术驱动的崭露头角的竞争者,有潜力在计算机视觉领域引发革命,类似于它对自然语言处理的影响。在本博客文章中,我将探讨这两种方法,以了解它们的工作原理!

自2012年以来,计算机视觉经历了一次革命性的转变,这是由卷积神经网络(CNN)和深度学习架构的出现推动的。这些架构中值得注意的有AlexNet(2012年),GoogleNet(2014年),VGGNet(2014年)和ResNet(2015年),它们包含了大量的卷积层,以提高图像分类的准确性。而图像分类任务涉及将标签分配给整个图像,例如将一张图片分类为狗或汽车,而目标检测不仅要识别图像中的内容,还要确定每个物体在图像中的位置。

cd8f4862d5953d0c26258261a0d8596b.png

图像上的目标检测和分类示例

YOLO的原始版本(2015年)是实时目标检测的一项突破性工作,当它发布时,它仍然是实际视觉应用中最常用的模型之一。它将检测过程从两到三个阶段(即R-CNN,Fast R-CNN)转变为单阶段的卷积阶段,并在准确性和速度方面超越了所有最先进的目标检测方法。原始论文中的模型架构随着时间的推移发生了变化,添加了不同的手工设计特性以提高模型的准确性。以下是YOLO的前三个版本及其区别的概述。

YOLO v1(2015年)是原始版本,为后续版本奠定了基础。它使用单一的深度卷积神经网络(CNN)来预测边界框和类别概率。YOLO v1将输入图像分成一个网格,并在网格的每个单元格中进行预测。每个单元格负责预测一定数量的边界框及其对应的类别概率。这个版本以令人印象深刻的速度实现了实时目标检测,但在检测小物体和准确定位重叠物体方面存在一些限制。

68bad393ed0dbe68c2f7802c785739ca.png

YOLO v1(2015年)原始框架

YOLO v2(2016年)解决了原始YOLO模型的一些限制。它引入了锚定框(anchor boxes),有助于更好地预测不同尺寸和宽高比的边界框。YOLO v2使用了更强大的骨干网络Darknet-19,并不仅在原始数据集(PASCAL VOC)上进行了训练,还在COCO数据集上进行了训练,大幅增加了可检测类别的数量。锚定框和多尺度训练的结合有助于提高小物体的检测性能。

YOLO v3(2018年)进一步提高了目标检测的性能。这个版本引入了特征金字塔网络的概念,具有多个检测层,允许模型在不同尺度和分辨率下检测物体。YOLO v3使用了一个更大的网络架构,拥有53个卷积层,称为Darknet-53,提高了模型的表示能力。YOLO v3在三个不同的尺度上进行检测:13x13、26x26和52x52的网格。每个尺度在每个格子单元格中预测不同数量的边界框。

75ea6493831083601c7b45903bbb2fb8.png

YOLO v3 框架

我们预测了多少个边界框?在416 x 416的分辨率下,YOLO v1预测了7 x 7 = 49个框。YOLO v2预测了13 x 13 x 5 = 845个框。而YOLO v3在3个不同的尺度上进行了预测:13 x 13 x 3 + 26 x 26 x 3 + 52 x 52 x 3 = 10,647个框。非极大值抑制(NMS)是一种后处理技术,用于过滤掉多余和重叠的边界框预测。在NMS算法中,首先删除置信度低于某个阈值的框。然后,具有与“当前”预测具有一定IoU(交并比)阈值(例如0.5)以上的较低置信度分数的所有其他预测被标记为多余并被抑制。

DETR(DEtection TRansformer)是一种相对新的目标检测算法,由Facebook人工智能研究(FAIR)的研究人员于2020年提出。它基于变换器架构,这是一种用于各种自然语言处理任务的强大的序列到序列模型。传统的目标检测器(例如R-CNN和YOLO)复杂并经历了多次变化,依赖于手工设计的组件(例如NMS)。与此不同,DETR是一个直接的集合预测模型,它使用变换器编码器-解码器架构一次性预测所有物体。这种方法比传统的目标检测器更简单、更高效,并在COCO数据集上实现了可比较的性能。

DETR架构简单,由三个主要组件组成:用于特征提取的CNN骨干(例如ResNet),变换器编码器-解码器和用于最终检测预测的前馈网络(FFN)。骨干处理输入图像并生成激活映射。变换器编码器减少通道维度并应用多头自注意力和前馈网络。变换器解码器使用N个物体嵌入的并行解码,并独立地预测边界框坐标和类别标签。DETR使用成对关系一次性推断所有物体,从整个图像上下文中受益。

625e25dd5475998704da4224880a728d.png

下面的代码(摘自DETR的官方GitHub存储库)定义了这个DETR模型的前向传递,它通过各种层处理输入数据,包括卷积骨干和变换器网络。我在代码中包含了网络每个层的输出形状,以使您了解所有数据的变换过程。

class DETRdemo(nn.Module):def __init__(self, num_classes, hidden_dim=256, nheads=8,num_encoder_layers=6, num_decoder_layers=6):super().__init__()# 2. create ResNet-50 backboneself.backbone = resnet50()del self.backbone.fc# create conversion layerself.conv = nn.Conv2d(2048, hidden_dim, 1)# 3. create a default PyTorch transformerself.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)# 4. prediction heads, one extra class for predicting non-empty slots# note that in baseline DETR linear_bbox layer is 3-layer MLPself.linear_class = nn.Linear(hidden_dim, num_classes + 1)self.linear_bbox = nn.Linear(hidden_dim, 4)# 5. output positional encodings (object queries)self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))# spatial positional encodings# note that in baseline DETR we use sine positional encodingsself.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))def forward(self, inputs): # propagate inputs through ResNet-50 up to avg-pool layer# input: torch.Size([1, 3, 800, 1066])x = self.backbone.conv1(inputs)    # torch.Size([1, 64, 400, 533])x = self.backbone.bn1(x)           # torch.Size([1, 64, 400, 533])x = self.backbone.relu(x)          # torch.Size([1, 64, 400, 533])   x = self.backbone.maxpool(x)       # torch.Size([1, 64, 200, 267])x = self.backbone.layer1(x)        # torch.Size([1, 256, 200, 267])x = self.backbone.layer2(x)        # torch.Size([1, 512, 100, 134])x = self.backbone.layer3(x)        # torch.Size([1, 1024, 50, 67])x = self.backbone.layer4(x)        # torch.Size([1, 2048, 25, 34])# convert from 2048 to 256 feature planes for the transformerh = self.conv(x)                   # torch.Size([1, 256, 25, 34])# construct positional encodingsH, W = h.shape[-2:]pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1) # torch.Size([850, 1, 256])src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)  # torch.Size([850, 1, 256])target = self.query_pos.unsqueeze(1)    # torch.Size([100, 1, 256])# propagate through the transformerh = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1)).transpose(0, 1) # torch.Size([1, 100, 256])linear_cls = self.linear_class(h)        # torch.Size([1, 100, 92])liner_bbx = self.linear_bbox(h).sigmoid()  # torch.Size([1, 100, 4])# finally project transformer outputs to class labels and bounding boxesreturn {'pred_logits': linear_cls,  'pred_boxes': linear_bbx}

以下是代码的逐步解释:

  1. 初始化:__init__方法定义了DETR模块的结构。它以几个超参数作为输入,包括类别数量(num_classes),隐藏维度(hidden_dim),注意力头数(nheads),以及编码器和解码器的层数(num_encoder_layers和num_decoder_layers)等。 

  2. BackBone和卷积层:代码创建了一个ResNet-50(self.backbone),并删除了其全连接(fc)层,因为它不会用于检测。卷积层(self.conv)被添加用于将 BackBone 的输出从2048通道转换为hidden_dim通道。

  3. Transformer:使用nn.Transformer类创建了一个PyTorch变换器(self.transformer)。这个变换器将同时处理模型的编码器和解码器部分。编码器和解码器层数以及其他参数根据提供的超参数进行设置。

  4. 预测头:模型为预测定义了两个线性层:self.linear_class 用于预测类别对数概率。额外添加了一个类别以预测非空槽,因此类别数为num_classes + 1。self.linear_bbox 用于预测边界框的坐标。对其应用了sigmoid()函数以确保边界框坐标在[0, 1]范围内。 

  5. 位置编码:位置编码对于基于变换器的模型至关重要。模型定义了查询位置编码(self.query_pos)和空间位置编码(self.row_embed和self.col_embed)。

这些编码有助于模型理解不同元素之间的空间关系。模型生成100个有效的预测。我们仅保留输出中概率高于特定限制的部分预测,并且舍弃所有其他预测。

示例

在这一部分,我展示了我的GitHub存储库中的一个示例项目,我在该项目中使用了DETR和YOLO模型来处理实时视频流。该项目的目标是研究DETR在实时视频流上的性能,与行业中大多数实时应用的首选模型YOLO进行比较。

import torch
from ultralytics import YOLO
import cv2
from dataclasses import dataclass
import time
from utils.functions import plot_results, rescale_bboxes, transform
from utils.datasets import LoadWebcam, LoadVideo
import logginglogging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
)@dataclass
class Config:source: str = "assets/walking_resized.mp4"view_img: bool = Falsemodel_type: str = "detr_resnet50"device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")skip: int = 1yolo: bool = Trueyolo_type = "yolov8n.pt"class Detector:def __init__(self):self.config = Config()self.device = self.config.deviceif self.config.source == "0":logging.info("Using stream from the webcam")self.dataset = LoadWebcam()else:logging.info("Using stream from the video file: " + self.config.source)self.dataset = LoadVideo(self.config.source)self.start = time.time()self.count = 0def load_model(self):if self.config.yolo:if self.config.yolo_type is None or self.config.yolo_type == "":raise ValueError("YOLO model type is not specified")model = YOLO(self.config.yolo_type)logging.info(f"YOLOv8 Inference using {self.config.yolo_type}")else:if self.config.model_type is None or self.config.model_type == "":raise ValueError("DETR model type is not specified")model = torch.hub.load("facebookresearch/detr", self.config.model_type, pretrained=True).to(self.device)model.eval()logging.info(f"DETR Inference using {self.config.model_type}")return modeldef detect(self):model = self.load_model()for img in self.dataset:self.count += 1if self.count % self.config.skip != 0:continueif not self.config.yolo:im = transform(img).unsqueeze(0).to(self.device)outputs = model(im)# keep only predictions with 0.7+ confidenceprobas = outputs["pred_logits"].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > 0.9bboxes_scaled = rescale_bboxes(outputs["pred_boxes"][0, keep].to("cpu"), img.shape[:2])else:outputs = model(img)logging.info(f"FPS: {self.count / self.config.skip / (time.time() - self.start)}")# print(f"FPS: {self.count / self.skip / (time.time() - self.start)}")if self.config.view_img:if self.config.yolo:annotated_frame = outputs[0].plot()cv2.imshow("YOLOv8 Inference", annotated_frame)if cv2.waitKey(1) & 0xFF == ord("q"):breakelse:plot_results(img, probas[keep], bboxes_scaled)logging.info("************************* Done *****************************")if __name__ == "__main__":detector = Detector()detector.detect()

下面的server.py脚本使用了Ultralytics的YOLO v8模型和torch hub中预训练的DETR模型。server.py脚本负责从诸如网络摄像头、IP摄像头或本地视频文件等源获取数据。可以在server.py配置数据类中修改此源。性能评估结果显示,使用yolov8m.pt模型时,它在Tesla T4 GPU上实现了每秒55帧(FPS)的卓越处理速度。另一方面,使用detr_resnet50模型在Tesla T4 GPU上实现了每秒15帧(FPS)的处理速度。

结论

总之,YOLO是需要实时检测和速度的应用的绝佳选择,适用于视频分析和实时对象跟踪等应用。另一方面,DETR在需要提高准确性并处理物体之间复杂交互的任务中表现出色,这在医学影像、细粒度目标检测和检测质量高于实时处理速度的场景中可能特别重要。然而,值得注意的是,DETR的新版本——即实时DETR或RT-DETR——于2023年发布,声称在速度和准确性方面均优于所有相似规模的YOLO检测器。尽管这个创新没有在本博客中涵盖,但强调了这个领域的动态性,以及根据特定应用需求进一步优化YOLO和DETR之间的选择的潜力。

·  END  ·

HAPPY LIFE

4029158c8737dfd1a78fa47a114d5de6.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/64856.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RISC-V】RISC-V寄存器

一、通用寄存器 32位RISC-V体系结构提供32个32位的整型通用寄存器寄存器别名全称说明X0zero零寄存器可做源寄存器(rs)或目标寄存器(rd)X1ra链接寄存器保存函数返回地址X2sp栈指针寄存器指向栈的地址X3gp全局寄存器用于链接器松弛优化X4tp线程寄存器常用于在OS中保存指向进程控…

回归预测 | MATLAB实现IBES-ELM改进的秃鹰搜索优化算法优化极限学习机多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现IBES-ELM改进的秃鹰搜索优化算法优化极限学习机多输入单输出回归预测(多指标,多图) 目录 回归预测 | MATLAB实现IBES-ELM改进的秃鹰搜索优化算法优化极限学习机多输入单输出回归预测(多指标,多图…

手撕二叉平衡树

今天给大家带来的是平衡树的代码实现&#xff0c;如下&#xff1a; #pragma once #include <iostream> #include <map> #include <set> #include <assert.h> #include <math.h> using namespace std; namespace cc {template<class K, clas…

CXL寄存器介绍(2)- CXL DVSEC

&#x1f525;点击查看精选 CXL 系列文章&#x1f525; &#x1f525;点击进入【芯片设计验证】社区&#xff0c;查看更多精彩内容&#x1f525; &#x1f4e2; 声明&#xff1a; &#x1f96d; 作者主页&#xff1a;【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

TiDB 一栈式综合交易查询解决方案获“金鼎奖”优秀金融科技解决方案奖

日前&#xff0c;2023“金鼎奖”评选结果揭晓&#xff0c; 平凯星辰&#xff08;北京&#xff09;科技有限公司研发的 TiDB 一栈式综合交易查询解决方案获“金鼎奖”优秀金融科技解决方案奖 &#xff0c; 该方案已成功运用于 多家国有大行、城商行和头部保险企业 。 此次获奖再…

【AI】《动手学-深度学习-PyTorch版》笔记(二十一):目标检测

AI学习目录汇总 1、简述 通过前面的学习,已经了解了图像分类模型的原理及实现。图像分类是假定图像中只有一个目标,算法上是对整个图像做的分类。 下面我们来学习“目标检测”,即从一张图像中找出需要的目标,并标记出位置。 2、边界框 边界框:bounding box,就是一个方…

我想开通期权?如何开通期权账户?

场内期权的合约由交易所统一标准化定制&#xff0c;大家面对的同一个合约对应的价格都是一致的&#xff0c;比较公开透明&#xff0c;期权开户当天不能交易的&#xff0c;期权开户需要满足20日日均50万及半年交易经验即可操作&#xff0c;下文科普我想开通期权&#xff1f;如何…

Java设计模式:四、行为型模式-10:访问者模式

一、定义&#xff1a;访问者模式 访问者模式&#xff1a;核心在于同一个事物不同视角下的访问信息不同。 在一个稳定的数据结构下&#xff0c;例如用户信息、雇员信息等&#xff0c;增加易变的业务访问逻辑。为了增强扩展性&#xff0c;将两部分的业务解耦的一种设计模式。 二…

详解 SpringMVC 中获取请求参数

文章目录 1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、[RequestParam ](/RequestParam )4、[RequestHeader ](/RequestHeader )5、[CookieValue ](/CookieValue )6、通过POJO获取请求参数7、解决获取请求参数的乱码问题总结 在Spring MVC中&#xff0c;获取请…

自建音乐播放器之一

这里写自定义目录标题 1.1 官方网站 2. Navidrome 简介2.1 简介2.2 特性 3. 准备工作4. 视频教程5. 界面演示5.1 初始化页5.2 专辑页 前言 之前给大家介绍过 Koel 音频流服务&#xff0c;就是为了解决大家的这个问题&#xff1a;下载下来的音乐&#xff0c;只能在本机欣赏&…

【pyqt5界面化工具开发-12】QtDesigner图形化界面设计

目录 0x00 前言 一、启动程序 二、基础的使用 三、保存布局文件 四、加载UI文件 0x00 前言 关于QtDesigner工具的配置等步骤&#xff08;网上链接也比较多&#xff09; 下列链接非本人的&#xff08;如果使用pip 在命令行安装过pyqt5以及tools&#xff0c;那么就可以跳过…

springboot整合SpringSecurity

先写了一个配置类 给这个访问路径&#xff0c;加上角色权限 package com.qf.config;import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; impo…

【网络编程上】

目录 一.什么是互联网 1.计算机网络的定义与分类&#xff08;了解&#xff09; &#xff08;1&#xff09;计算机网络的定义 &#xff08;2&#xff09;计算机网络的分类 ① 按照网络的作用范围进行分类 ②按照网络的使用者进行分类 2.网络的网络 &#xff08;理解&#xf…

苹果Mac系统如何优化流畅的运行?提高运行速度

Mac系统的稳定性和流畅性一直备受大家称赞&#xff0c;这也是大多数人选择Mac的原因&#xff0c;尽管如此&#xff0c;我们仍不时地对Mac进行优化、调整&#xff0c;以使其比以前更快、更流畅地运行。以下是小编分享给各位的Mac优化方法&#xff0c;记得保存哦~ 一、释放被过度…

【笔记】常用 js 函数

数组去重 Array.from(new Set()) 对象合并 Object.assign . 这里有个细节&#xff1a;当两个对象中含有key相同value不同时&#xff0c;会以 后面对象的key&#xff1a;value为准 保留小数点后几位 toFixed 注意&#xff1a; Number型&#xff0c;用该方法处理完&#xff0c;会…

Markdown Preview Plus Chrome插件使用

Markdown Preview Plus Chrome插件使用 1.插件说明2.插件下载3.插件配置4.文档样式4.1 网页显示4.2 导出PDF 系统&#xff1a;Win10 Chrome&#xff1a;113.0.5672.127 Markdown Preview Plus&#xff1a;0.7.3 1.插件说明 一般 markdown 工具自带的预览功能比较简单&#xff…

RTPEngine 通过 HTTP 获取指标的方式

文章目录 1.背景介绍2.RTPEngine 支持的 HTTP 请求3.通过 HTTP 请求获取指标的方法3.1 脚本配置3.2 请求方式 1.背景介绍 RTPEngine 是常用的媒体代理服务器&#xff0c;通常被集成到 SIP 代理服务器中以减小代理服务器媒体传输的压力&#xff0c;其架构如下图所示。这种使用方…

人工智能论文通用创新点(一)——ACMIX 卷积与注意力融合、GCnet(全局特征融合)、Coordinate_attention、SPD(可替换下采样)

1.ACMIX 卷积与注意力融合 论文地址:https://arxiv.org/pdf/2111.14556.pdf 为了实现卷积与注意力的融合,我们让特征图经过两个路径,一个路径经过卷积,另外一个路径经过Transformer,但是,现在有一个问题,卷积路径比较快,Transformer比较慢。因此,我们让Q,K,V通过1*1的…

jmeter 线程组

在jmeter中&#xff0c;通过指定并发数量、启动延迟时间和持续时间&#xff0c;并组织示例&#xff08;Samplers&#xff09;在多个线程之间的执行方式&#xff0c;实现模拟并发用户的行为。 添加线程组&#xff1a; 在测试计划中&#xff0c;右键点击“添加” -> “Thread…

Android 1.1 背景相关与系统架构分析

目录 1.1 背景相关与系统架构分析 分类 Android 基础入门教程 1.Android背景与当前的状况 2.Android系统特性与平台架构 系统特性&#xff1a; 平台架构图&#xff1a; 架构的简单理解&#xff1a; 3.本节小结&#xff1a; 1.1 背景相关与系统架构分析 分类 Android 基础…