【AI】目标检测算法DETR源码解析及推理测试

说到目标检测,自然而然我们会想到YOLO这个框架,YOLO框架已经发展到V8版本了,各种应用也比较成熟,不过我最近在研究Transformer,今天的主角是Transformer在目标检测的开山之作:DETR:End-to-End Object Detection with Transformers,这是由Facebook AI团队出品的,其源码地址在:https://github.com/facebookresearch/detr。

0.DETR简介

DETR的特别之一在于将transformer应用于目标检测领域;而他不同于之前的算法的地方在于它不像YOLO这种使用anchor,也不想faster-rcnn使用各种proposal方法,同时它还去除了NMS,这让它在当时一众目标检测算法中显得比较特别。

论文提出了一种将目标检测视为直接集预测问题的新方法。DETR简化了检测流程,有效地消除了对许多人工设计组件的需求,如NMS或anchor生成。新框架的主要组成部分,称为DEtection TRansformer或DETR,是一种基于集合的全局损失,通过二分匹配强制进行一对一预测,以及一种transformer encoder-decoder架构。
在这里插入图片描述
基本思想:
(1)先来个CNN得到各Patch作为输入,再套transformer做编码和解码;
(2)编码的路子跟VIT基本一样,重点在解码,直接预测N个坐标框(原文是100个);

在这里插入图片描述
(3)编码(Encoder)的主要任务是得到各个目标的注意力结果,准备好特征,等解码器来进行匹配
(4)解码(Decoder)过程的核心目标是让object queries学会从原始特征中找到物体的位置;
(5)object queries采用随机初始化的方式(0+位置编码);
(6)输出层就是N个object queries的预测;
(7)输出的匹配采用匈牙利匹配方式,按照loss最小的组合,匹配上的作为目标输出,没有匹配上的都作为背景。

DETR不仅可以用在检测领域,在分割领域也同样可用。性能上倒是还不错,就是训练太慢了,训练模型的机器配置要求也比较高。

1.源码解析

源码的启动程序是main.py程序,在main函数中,主要关注build_model和build_dataset,这两个分别是构建模型和构建数据集的。

1.1数据集处理

模型的数据集采用的coco2017,其数据集处理也是采用的coco数据集格式,对于需要使用自己数据集的,需要将图像和标注文件都做成coco数据集的形式,否则需要自己实现数据集加载的程序。运行模型训练需要指定数据集的路径,即coco_path参数。

由build_dataset函数可知,构建数据集的操作位于coco.py文件中,这里直接继承了torchvision中的方法torchvision.datasets.CocoDetection

class CocoDetection(torchvision.datasets.CocoDetection):def __init__(self, img_folder, ann_file, transforms, return_masks):super(CocoDetection, self).__init__(img_folder, ann_file)self._transforms = transformsself.prepare = ConvertCocoPolysToMask(return_masks)def __getitem__(self, idx):img, target = super(CocoDetection, self).__getitem__(idx)image_id = self.ids[idx]target = {'image_id': image_id, 'annotations': target}img, target = self.prepare(img, target)if self._transforms is not None:img, target = self._transforms(img, target)return img, target

这里数据准备用到了ConvertCocoPolysToMask函数,可以跳入进去,这里面主要对标注进行了处理,将xywh形式的标注转为x1y1x2y2的标注格式,x1y1是标注框左上角的点,x2y2是标注框右下角的点。

只保留iscrowd == 0,就是单个目标没有重叠的

anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 

x y w h转换成了 x1y1 x2y2

boxes = [obj["bbox"] for obj in anno] # x y w h# guard against no boxes via resizingboxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)boxes[:, 2:] += boxes[:, :2]boxes[:, 0::2].clamp_(min=0, max=w)boxes[:, 1::2].clamp_(min=0, max=h)

过滤掉左上角坐标小于右下角坐标的标注(注意自己标数据的时候可能出现的),这种情况是在标注数据的时候画框从右下角拉到左上角,这样容易导致错误,所以需要过滤掉。

keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 

1.2模型解析

模型代码总的代码位于detr.py文件,我们可以跟随DETR类的forward函数来看模型数据处理的方式:

backbone

首先经过backbone,backbone对数据进行了两个操作,获取特征图和位置编码:

features, pos = self.backbone(samples)

跟入backbone中查看:

class Joiner(nn.Sequential):def __init__(self, backbone, position_embedding):super().__init__(backbone, position_embedding)def forward(self, tensor_list: NestedTensor):print(tensor_list.tensors.shape)xs = self[0](tensor_list)out: List[NestedTensor] = []pos = []for name, x in xs.items():print(x.tensors.shape)out.append(x)# position encodingpos.append(self[1](x).to(x.tensors.dtype))return out, pos

这里的xs = self[0](tensor_list)采用resnet获取特征图,而self[1](x).to(x.tensors.dtype)则是为了获取位置编码,获取特征图没有特别支出,就是走了resnet模型,获取位置编码是采用的正余弦的方式进行的,其具体操作代码位于position_encoding.py文件中的PositionEmbeddingSine类:
这里的mask是将非数据的特征值去掉,位置编码采用的是在二维矩阵的行方向和列方向求cumsum累加的操作:

        mask = tensor_list.maskassert mask is not Nonenot_mask = ~masky_embed = not_mask.cumsum(1, dtype=torch.float32) #行方向累加x_embed = not_mask.cumsum(2, dtype=torch.float32) #列方向累加

然后执行正余弦编码操作:

        # 映射成角度if self.normalize:eps = 1e-6y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale#奇数和偶数变换dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)#计算正余弦pos_x = x_embed[:, :, :, None] / dim_tpos_y = y_embed[:, :, :, None] / dim_tpos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

transformer

走完backbone之后我们回到detr的forward函数中,下一步就是走transformer函数了,transformer的定义位于transformer.py文件中,可以跟随代码跳入到forward函数中:

Encoder

首先进行encoder操作:

        #首先是对特征图、位置编码和mask编码的变换操作,同时生成了query_embed用于后续的decoder操作bs, c, h, w = src.shapesrc = src.flatten(2).permute(2, 0, 1)pos_embed = pos_embed.flatten(2).permute(2, 0, 1)query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)mask = mask.flatten(1)tgt = torch.zeros_like(query_embed)memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

Encoder操作主要是走transformer操作,生成QKV然后计算attention,主要的操作位于TransformerEncoderLayer类中,这里计算QKV的时候只对QK进行了操作,没有对V进行操作,同时计算attention的操作直接使用了torch提供的attention计算方式:

    def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):#只有K和Q 加入了位置编码;并没有对V做q = k = self.with_pos_embed(src, pos) #两个返回值:自注意力层的输出,自注意力权重;只需要第一个src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0] # 执行transformer连接之类的操作src = src + self.dropout1(src2)src = self.norm1(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))src = src + self.dropout2(src2)src = self.norm2(src)return src
Decoder

然后进行decoder操作:
跟着源码跳入TransformerDecoderLayer类的forward函数:
首先将query添加位置编码,进行自身注意力机制计算:

#这里的tgt初始值为0,融入位置编码后输入到自注意力的multihead attention的操作中
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]

然后经过一些连接操作后进入到跟Encoder中生成的K和V的注意力计算机制,这个操作是全篇核心思想的实现,这里的attention操作也是直接使用的torch中提供的多头注意力机制计算方法。

tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]

然后进入一些连接等其他操作,然后输出query得到最终结果。
这里Encoder和decoder都是执行多层的。

1.3损失函数

计算损失主要有三个:分类损失、回归损失和giou。这一部分位于match.py文件中

    def forward(self, outputs, targets):""" Performs the matchingParams:outputs: This is a dict that contains at least these entries:"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinatestargets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truthobjects in the target) containing the class labels"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinatesReturns:A list of size batch_size, containing tuples of (index_i, index_j) where:- index_i is the indices of the selected predictions (in order)- index_j is the indices of the corresponding selected targets (in order)For each batch element, it holds:len(index_i) = len(index_j) = min(num_queries, num_target_boxes)"""bs, num_queries = outputs["pred_logits"].shape[:2]# We flatten to compute the cost matrices in a batchout_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]# Also concat the target labels and boxestgt_ids = torch.cat([v["labels"] for v in targets])tgt_bbox = torch.cat([v["boxes"] for v in targets])# Compute the classification cost. Contrary to the loss, we don't use the NLL,# but approximate it in 1 - proba[target class].# The 1 is a constant that doesn't change the matching, it can be ommitted.cost_class = -out_prob[:, tgt_ids]# Compute the L1 cost between boxescost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)# Compute the giou cost betwen boxescost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))# Final cost matrixC = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giouC = C.view(bs, num_queries, -1).cpu()sizes = [len(v["boxes"]) for v in targets]indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

2.预测

模型训练好之后可以使用模型进行预测,这一部分根据官方文档来即可,我这边简单贴一下:
准备工作:

import mathfrom PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'import ipywidgets as widgets
from IPython.display import display, clear_outputimport torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False);# COCO classes
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A','stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack','umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis','snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich','orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A','N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard','cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A','book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier','toothbrush'
]# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]# standard PyTorch mean-std input image normalization
transform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h),(x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return bdef plot_results(pil_img, prob, boxes):plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()colors = COLORS * 100for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}: {p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.show()

首先加载预训练模型,如果本地没有训练出来,可以使用官方已经训练好的,如下:

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval();

加载图片:

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

模型预测:

# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)# propagate through the model
outputs = model(img)# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

结果展示:

plot_results(im, probas[keep], bboxes_scaled)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/593558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FairGuard游戏加固产品常见问题解答

针对日常对接中,各位用户对FairGuard游戏加固方案在安全性、稳定性、易用性、接入流程等方面的关注,我们梳理了相关问题与解答,希望可以让您对产品有一个初步的认知与认可。 Q1:FairGuard游戏加固产品都有哪些功能? A:FairGuar…

OpenFeign

OpenFeign 一、基本使用 1、引入依赖 <groupId>org.springframework.cloud</groupId> <artifactId>spring-cloud-starter-openfeign</artifactId><groupId>org.springframework.cloud</groupId> <artifactId>spring-cloud-start…

k8s的网络

k8s的网络 k8s中的通信模式&#xff1a; 1、pod内部之间容器与容器之间的通信 在同一个pod中的容器共享资源和网络&#xff0c;使用同一个网络命名空间&#xff0c;可以直接通信的 2、同一个node节点之内&#xff0c;不同pod之间的通信 每个pod都有一个全局的真实的ip地址…

BMS均衡技术

一、电池的不一致性&#xff1f; 每个电池都有自己的“个性”&#xff0c;要说均衡&#xff0c;得先从电池谈起。即使是同一厂家同一批次生产的电池&#xff0c;也都有自己的生命周期、自己的“个性”——每个电池的容量不可能完全一致。例如以下的两个原因都会造成电池不一致…

docker部署mysql

1.查找mysql镜像 [rootVM-4-5-centos ~]# docker search mysql NAME DESCRIPTION STARS OFFICIAL AUTOMATED mysql MySQL is a widely used, open-sourc…

AnyText:多语言视觉文字生成与编辑——最详细傻瓜式安装教程

先看图,下面都是AnyText生成的,可以说效果效果确实是很震撼了。 附上地址: GitHub - tyxsspa/AnyTextContribute to tyxsspa/AnyText development by creating an account on GitHub.https://github.com/tyxsspa/AnyText接下来开始详细讲解安装过程: 1. 下载项目 (1)下…

探讨JVM垃圾回收机制与内存泄漏

目录 1. 垃圾回收机制的基本原理 2. 内存泄漏的定义与表现 3. 垃圾回收机制的局限性 4. Finalizer导致的延迟 5. 不当使用静态集合 6. JNI资源未释放 7. 解决内存泄漏的方法 8. 结语 在Java虚拟机&#xff08;JVM&#xff09;的世界中&#xff0c;垃圾回收机制被设计用…

电磁波的信号加载说明

电磁波的信号加载电磁波(Electromagnetic wave)是由同相振荡 且互相垂直的电场与磁场在空间中衍生发射的振荡粒子波&#xff0c;是以波动的形式传播的电磁场&#xff0c;具有波粒二象性&#xff0c;其粒子形态称为光子&#xff0c;电磁波与光子不是非黑即白的关系&#xff0c;而…

外显和呼叫系统的关系

经常接到推销电销&#xff0c;对于不同号码显示&#xff0c;我们选择接听电话和挂断电话的概率也是不一样的。 我们接到号码有显示运营商和归属地名称。 例如&#xff1a;北京 移动&#xff0c;广东深圳 电信&#xff0c;广电&#xff0c;广东广州 虚拟运营商等&#xff1b; 有…

快速打通 Vue 3(二):响应式对象基础

很激动进入了 Vue 3 的学习&#xff0c;作为一个已经上线了三年多的框架&#xff0c;很多项目都开始使用 Vue 3 来编写了 这一组文章主要聚焦于 Vue 3 的新技术和新特性 如果想要学习基础的 Vue 语法可以看我专栏中的其他博客 Vue&#xff08;一&#xff09;&#xff1a;Vue 入…

Flink-【时间语义、窗口、水位线】

1. 时间语义 1.1 事件时间&#xff1a;数据产生的事件&#xff08;机器时间&#xff09;&#xff1b; 1.2 处理时间&#xff1a;数据处理的时间&#xff08;系统时间&#xff09;。 &#x1f330;&#xff1a;可乐 可乐的生产日期 事件时间&#xff08;可乐产生的时间&…

算法导论复习——CHP24 单源最短路

单源最短路径问题&#xff1a; 给定一个图G (V,E)&#xff0c;找出从给定的源点s∈V到其它每个结点v∈V的最短路径。 这样最短路径具有最优子结构性&#xff1a;两个结点之间的最短路径的任何子路径都是最短的。 基本概念 负权边&#xff1a;权重为负值的边称为负权重的边。 如…

Vue3+TS+ElementPlus的安装和使用教程【详细讲解】

前言 本文简单的介绍一下vue3框架的搭建和有关vue3技术栈的使用。通过本文学习我们可以自己独立搭建一个简单项目和vue3的实战。 随着前端的日月更新&#xff0c;技术的不断迭代提高&#xff0c;如今新vue项目首选用vue3 typescript vite pinia……模式。以前我们通常使用…

webpack知识点总结(高级应用篇)

除开公共基础配置之外&#xff0c;我们意识到两点: 1. 开发环境(modedevelopment),追求强大的开发功能和效率&#xff0c;配置各种方便开 发的功能;2. 生产环境(modeproduction),追求更小更轻量的bundle(即打包产物); 而所谓高级应用&#xff0c;实际上就是进行 Webpack 优化…

计算机组成原理-期末复习

目录 第一章——计算机系统概述 一、数字计算机的主要组成结构 二、指令的形式 三、控制器的基本任务 四、指令流和数据流 五、适配器与输入/输出设备 七、计算机的系统软件 八、C 语言的转换层次图 九、计算机系统的层次结构图 第二章——运算方法和运算器 一、 数据格式…

javascript之跳转页面的几种方法?

文章目录 前言代码演示及解释使用location.href属性使用location.assign()方法使用location.replace()方法使用window.open()方法使用document.URL方法 总结 前言 本章学习的是JavaScript中的跳转页面的几种方法 代码演示及解释 使用location.href属性 可以直接将一个新的URL…

企业如何做好客户管理?有哪些关键因素?

客户管理是建立和维护客户关系的重要组成部分&#xff0c;对于企业的发展至关重要。下面就让我们来看看在做好客户管理时有哪些关键因素吧。 第一个关键因素是提供优质的客户服务。无论是线上还是线下&#xff0c;当客户需要帮助时&#xff0c;他们希望能够得到有效且及时的支持…

sqlserver根据分组的内容分别查询出匹配的一条信息

需求场景&#xff1a; 我写了条分组语句&#xff0c; select name from car_machine_command group by name 然后该表有很多条相关的数据&#xff0c;我只想拿各个分组的一条数据看看即可 解决&#xff1a;可以使用窗口函数&#xff08;Window Function&#xff09;和 ROW_NU…

Dora-rs 机器人框架学习教程(1)—— Dora-rs安装

1、dora简介 Dora-rs[1] 是一个基于 Rust 实现的化机器人框架&#xff0c;其具有极高的实时性能。Dora-rs使用Rust语言做数据流的传输和调度管理&#xff0c;可以大大减少了数据的重复拷贝和传输。它提供了Rust语言和Python语言之间的无缝集成&#xff0c;减少了跨语言的性能代…

阿里云服务器Valheim端口2456、2457和2458放行设置

使用阿里云服务器搭建Valheim英灵神殿需要开启2456-2458端口&#xff0c;阿里云服务器默认只开放了22核3389端口&#xff0c;开通2456端口是在安全组中配置的&#xff0c;阿里云服务器网aliyunfuwuqi.com来详细说下阿里云服务器安全组开通端口流程&#xff1a; 阿里云服务器安…