经典目标检测YOLO系列(二)YOLOV2的复现(2)正样本的匹配、损失函数的实现及模型训练

经典目标检测YOLO系列(二)YOLOV2的复现(2)正样本的匹配、损失函数的实现及模型训练

我们在之前实现YOLOv1的基础上,加入了先验框机制,快速的实现了YOLOv2的网络架构,并且实现了前向推理过程。

经典目标检测YOLO系列(二)YOLOV2的复现(1)总体网络架构及前向推理过程

如前所述,我们使用基于先验框的正样本匹配策略。

1 正样本匹配策略

1.1 基于先验框的正样本匹配策略

  • 由于每个网格只输出一个边界框,因此在YOLOv1中的正样本匹配策略很简单,目标框的中心点落在哪个网格,这个网格(左上角点)就是正样本。
  • 但是,我们现在引入了先验框机制,每个网格会输出5个预测框。那么目标框的中心点所在的每一个网格,我们都需要确定这5个预测框中,哪些是正样本,哪些是负样本。
  • 既然我们已经有了具有边界框尺寸信息的先验框,那么我们可以基于先验框来筛选正样本。

假设一个含有目标框中心的网格上的5个先验框分别为A、B、C、D、E,那么需要计算这5个先验框与目标框O的IoU值,分别为:IoU_A、IoU_B、IoU_C、IoU_D、IoU_E,然后设定一个阈值iou_thresh:

  • 第1种情况:如果IoU_A、IoU_B、IoU_C、IoU_D、IoU_E都小于iou_thresh,为了不丢失这个训练样本,我们选择选择IoU值最大的先验框P_A。将P_A对应的预测框B_A,标记为正样本,即先验框决定哪些预测框会参与到何种损失的计算中去
  • 第2种情况:仅有一个IoU值大于iou_thresh,那么这个先验框所对应的预测框会被标记为正样本,会参与到置信度、类别及位置损失的计算。
  • 第3种情况:有多个IoU值大于iou_thresh,那么这些先验框所对应的预测框都会被标记为正样本,即一个目标会被匹配上多个正样本

这种正样本匹配策略,似乎保证了每个目标都会至少匹配上一个正样本,但其实存在漏洞。假如,有2个目标的中心点都落到了同一个目标框,可能会导致原本属于目标A的先验框后来又分配给目标B

  • 在YOLOv1中,2个目标的中心点都落到了同一个目标框,网络就只能学习一个。
  • 在YOLOv2中,虽然每个网格会输出多个预测框,但是在制作正样本时候,也会存在刚才说的语义歧义现象,会使得某些目标匹配不到正样本,其信息也就不会被网络学习到,不过我们现在不做处理。

在这里插入图片描述

1.2 代码实现

1.2.1 正样本匹配

pytorch读取VOC数据集:

  • 一批图像数据的维度是 [B, 3, H, W] ,分别是batch size,色彩通道数,图像的高和图像的宽。

  • 标签数据是一个包含 B 个图像的标注数据的python的list变量(如下所示),其中,每个图像的标注数据的list变量又包含了 M 个目标的信息(类别和边界框)。

  • 获得了这一批数据后,图片是可以直接喂到网络里去训练的,但是标签不可以,需要再进行处理一下。

[{'boxes':     tensor([[ 29., 230., 148., 321.]]),  # bbox的坐标(xmin, ymin, xmax, ymax)'labels':    tensor([18.]),                       # 标签'orig_size': [281, 500]                           # 图片的原始大小}, {'boxes':      tensor([[  0.,  79., 416., 362.]]), 'labels':     tensor([1.]),'orig_size': [375, 500]}
]

标签处理主要包括3个部分,

  • 一是将真实框中心所在网格对应正样本位置(anchor_idx)的置信度置为1,其他默认为0
  • 二是将真实框中心所在网格对应正样本位置(anchor_idx)的标签类别为1(one-hot格式),其他类别设置为0
  • 三是将真实框中心所在网格对应正样本位置(anchor_idx)的bbox信息设置为真实框的bbox信息。
# 处理好的shape如下:
# gt_objectness  
torch.Size([2, 845, 1])  # 845=13×13×5
# gt_classes
torch.Size([2, 845, 20])
# gt_bboxes
torch.Size([2, 845, 4])

1.2.2 具体代码实现

# RT-ODLab/models/detectors/yolov2/matcher.pyimport torch
import numpy as npclass Yolov2Matcher(object):def __init__(self, iou_thresh, num_classes, anchor_size):self.num_classes = num_classesself.iou_thresh = iou_thresh# anchor boxself.num_anchors = len(anchor_size)self.anchor_size = anchor_sizeself.anchor_boxes = np.array([ [0., 0., anchor[0], anchor[1]] for anchor in anchor_size])  # [KA, 4]def compute_iou(self, anchor_boxes, gt_box):"""函数功能: 计算目标框和5个先验框的IoU值anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).gt_box : ndarray -> [1, 4] (cx, cy, bw, bh).返回值: iou变量,类型为ndarray类型,shape为[5,], iou[i]就表示该目标框和第i个先验框的IoU值"""# 1、计算5个anchor_box的面积# anchors: [KA, 4]anchors = np.zeros_like(anchor_boxes)anchors[..., :2] = anchor_boxes[..., :2] - anchor_boxes[..., 2:] * 0.5  # x1y1anchors[..., 2:] = anchor_boxes[..., :2] + anchor_boxes[..., 2:] * 0.5  # x2y2anchors_area = anchor_boxes[..., 2] * anchor_boxes[..., 3]# 2、gt_box复制5份,计算5个相同gt_box的面积# gt_box: [1, 4] -> [KA, 4]gt_box = np.array(gt_box).reshape(-1, 4)gt_box = np.repeat(gt_box, anchors.shape[0], axis=0)gt_box_ = np.zeros_like(gt_box)gt_box_[..., :2] = gt_box[..., :2] - gt_box[..., 2:] * 0.5  # x1y1gt_box_[..., 2:] = gt_box[..., :2] + gt_box[..., 2:] * 0.5  # x2y2gt_box_area = np.prod(gt_box[..., 2:] - gt_box[..., :2], axis=1)# 3、计算计算目标框和5个先验框的IoU值# intersection  交集inter_w = np.minimum(anchors[:, 2], gt_box_[:, 2]) - \np.maximum(anchors[:, 0], gt_box_[:, 0])inter_h = np.minimum(anchors[:, 3], gt_box_[:, 3]) - \np.maximum(anchors[:, 1], gt_box_[:, 1])inter_area = inter_w * inter_h# unionunion_area = anchors_area + gt_box_area - inter_area# iouiou = inter_area / union_areaiou = np.clip(iou, a_min=1e-10, a_max=1.0)return iou@torch.no_grad()def __call__(self, fmp_size, stride, targets):"""img_size: (Int) input image sizestride: (Int) -> stride of YOLOv1 output.targets: (Dict) dict{'boxes': [...], 'labels': [...], 'orig_size': ...}"""# preparebs = len(targets)fmp_h, fmp_w = fmp_sizegt_objectness = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, 1]) gt_classes = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, self.num_classes]) gt_bboxes = np.zeros([bs, fmp_h, fmp_w, self.num_anchors, 4])# 第一层for循环遍历每一张图像的标签for batch_index in range(bs):# targets_per_image是python的Dict类型targets_per_image = targets[batch_index]# [N,] N表示一个图像中有N个目标对象tgt_cls = targets_per_image["labels"].numpy()# [N, 4]tgt_box = targets_per_image['boxes'].numpy()# 第二层for循环遍历这张图像标签的每一个目标数据for gt_box, gt_label in zip(tgt_box, tgt_cls):x1, y1, x2, y2 = gt_box# xyxy -> cxcywhxc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5bw, bh = x2 - x1, y2 - y1gt_box = [0, 0, bw, bh]# checkif bw < 1. or bh < 1.:continue    # 1、计算该目标框和5个先验框的IoU值iou = self.compute_iou(self.anchor_boxes, gt_box)iou_mask = (iou > self.iou_thresh)# 2、基于先验框的标签分配策略label_assignment_results = []# 第一种情况:所有的IoU值均低于阈值,选择IoU最大的先验框if iou_mask.sum() == 0:# We assign the anchor box with highest IoU score.iou_ind = np.argmax(iou)anchor_idx = iou_ind# compute the grid cellxc_s = xc / strideyc_s = yc / stridegrid_x = int(xc_s)grid_y = int(yc_s)label_assignment_results.append([grid_x, grid_y, anchor_idx])else:# 第二种和第三种情况:至少有一个IoU值大于阈值for iou_ind, iou_m in enumerate(iou_mask):if iou_m:anchor_idx = iou_ind# compute the gride cellxc_s = xc / strideyc_s = yc / stridegrid_x = int(xc_s)grid_y = int(yc_s)label_assignment_results.append([grid_x, grid_y, anchor_idx])# label assignment# 获取到被标记为正样本的先验框,我们就可以为这次先验框对应的预测框制作学习标签for result in label_assignment_results:grid_x, grid_y, anchor_idx = resultif grid_x < fmp_w and grid_y < fmp_h:# objectness标签,采用0,1离散值gt_objectness[batch_index, grid_y, grid_x, anchor_idx] = 1.0# classification标签,采用one-hot格式cls_ont_hot = np.zeros(self.num_classes)cls_ont_hot[int(gt_label)] = 1.0gt_classes[batch_index, grid_y, grid_x, anchor_idx] = cls_ont_hot# box标签,采用目标框的坐标值gt_bboxes[batch_index, grid_y, grid_x, anchor_idx] = np.array([x1, y1, x2, y2])# [B, H, W, A, C] -> [B, HWA, C]gt_objectness = gt_objectness.reshape(bs, -1, 1)gt_classes = gt_classes.reshape(bs, -1, self.num_classes)gt_bboxes = gt_bboxes.reshape(bs, -1, 4)# to tensorgt_objectness = torch.from_numpy(gt_objectness).float()gt_classes = torch.from_numpy(gt_classes).float()gt_bboxes = torch.from_numpy(gt_bboxes).float()return gt_objectness, gt_classes, gt_bboxesif __name__ == '__main__':anchor_size  = [[17, 25], [55, 75], [92, 206], [202, 21], [289, 311]]matcher = Yolov2Matcher(iou_thresh=0.5, num_classes=20, anchor_size=anchor_size)targets = [{'boxes':     torch.tensor([[ 29., 230., 148., 321.]]),  # bbox的坐标(xmin, ymin, xmax, ymax)'labels':    torch.tensor([18.]),                       # 标签'orig_size': [281, 500]                                 # 图片的原始大小},{'boxes':      torch.tensor([[  0.,  79., 416., 362.]]),'labels':     torch.tensor([1.]),'orig_size': [375, 500]}
]gt_objectness, gt_classes, gt_bboxes = matcher(fmp_size=(13, 13),stride=32, targets=targets )print(gt_objectness.shape)print(gt_classes.shape)print(gt_bboxes.shape)
  • 最终这段代码返回了gt_objectness, gt_classes, gt_bboxes三个Tensor类型的变量:
    • gt_objectness包含一系列的0和1,标记了哪些预测框是正样本,哪些预测框是负样本
    • gt_classes包含一系列的one-hot格式的类别标签
    • gt_bboxes包含的是正样本要学习的边界框的位置参数
  • 在上述代码实现中,在计算IoU时候,我们将目标框的中心点坐标和先验框的中心点坐标都设置为0,这是因为一个目标框在做匹配时候,仅仅考虑到目标框中心点所在的网格中的5个先验框,周围的网格都不进行考虑
  • 在SSD以及Faster R-CNN中,每一个目标框都是和全局的先验框去计算IoU,这些算法都会考虑目标框的中心点坐标和先验框的中心点坐标。因此,其每一个目标框匹配上的先验框不仅来自中心点所在的网格,也会来自周围的网格。这是YOLO和其他工作一个重要差别所在,YOLO这种只考虑中心点的做法,处理起来更加简便、更易学习。

2 损失函数的计算、YOLOv2的训练

2.1 损失函数的计算

  • YOLOv2损失函数计算(RT-ODLab/models/detectors/yolov2/loss.py)和之前实现的YOLOv1基本一致,不再赘述
  • 我们实现的YOLOv2和之前实现的YOLOv1相比,仅仅多了先验框以及由此带来的正样本匹配上的一些细节上的差别。

2.2 YOLOv2的训练

  • 完成了YOLOv2的网络搭建,标签匹配以及损失函数的计算,就可以进行训练了

  • 数据读取、数据预处理及数据增强操作,和之前实现的YOLOv1一致,不再赘述

  • YOLOv1和YOLOv2都在同一个项目代码中,数据代码、训练代码及测试代码均一致,我们只需要修改训练脚本即可

    nohup python -u train.py --cuda \-d voc                 \-m yolov2              \-bs 16                 \-size 640              \--wp_epoch 3           \--max_epoch 150        \--eval_epoch 10        \--no_aug_epoch 10      \--ema                  \--fp16                 \--multi_scale          \--num_workers 8 1>./logs/yolo_v2_train_log.txt 2>./logs/yolo_v2_warning_log.txt &
    

相关参数讲解可以参考YOLOv1:

经典目标检测YOLO系列(一)复现YOLOV1(5)模型的训练及验证

2.3 可视化检测结果、计算mAP指标

  • 训练结束后,模型默认保存在weights/voc/yolov2/文件夹下,名为yolov2_voc_best.pth,保存了训练阶段在测试集上mAP指标最高的模型。

  • 运行项目中所提供的eval.py文件可以验证模型的性能,具体命令如下行所示

  • 可以给定不同的图像尺寸来测试实现的YOLOv1在不同输入尺寸下的性能

    python eval.py \
    --cuda -d voc \
    --root path/to/voc -m yolov2 \
    --weight path/to/yolov2_voc_best.pth \
    -size 416
    
  • 也可以可视化训练好的模型

    python test.py \
    --cuda -d voc \
    --root path/to/voc -m yolov2 
    --weight path/to/yolov2_voc_best.pth \
    -size 416 -vt 0.3 \
    --show# -size表示输入图像的最大边尺寸
    # -vt是可视化的置信度阈值,只有高于此值的才会被可视化出来
    # --show表示展示检测结果的可视化图片
    

2.4 训练结果

《YOLO目标检测》作者训练好的模型,在VOC2007测试集测试指标如下:

从表格中可以看到,实现的YOLOv2达到了官方YOLOv2的性能。

模型输入尺寸mAP(%)
YOLOv2*(官方)41676.8
YOLOv2*(官方)48077.8
YOLOv2*(官方)54478.6
YOLOv241676.8
YOLOv248078.4
YOLOv254479.6
YOLOv264079.8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/633030.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Go面试向】rune和byte类型的认识与使用

【Go】rune和byte类型的认识与使用 大家好 我是寸铁&#x1f44a; 总结了一篇rune和byte类型的认识与使用的文章✨ 喜欢的小伙伴可以点点关注 &#x1f49d; byte和rune类型定义 byte,占用1个字节&#xff0c;共8个比特位&#xff0c;所以它实际上和uint8没什么本质区别,它表示…

Joern环境的安装(Windows版)

Joern环境的安装(Windows版) 网上很少有关于Windows下安装Joern的教程&#xff0c;而我最初使用也是装在Ubuntu虚拟机中&#xff0c;这样使用很占内存&#xff0c;影响体验感。在Windows下使用源码安装Joern也是非常简单的过程&#xff1a; 提前需要的本地环境&#xff1a; …

基于Java+SSM框架的办公用品管理系统详细设计和实现【附源码】

基于JavaSSM框架的办公用品管理系统详细设计和实现【附源码】 &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获取联系方式 承接各种定…

GO 中如何防止 goroutine 泄露

文章目录 概述如何监控泄露一个简单的例子泄露情况分类chanel 引起的泄露发送不接收接收不发送nil channel真实的场景 传统同步机制MutexWaitGroup 总结参考资料 今天来简单谈谈&#xff0c;Go 如何防止 goroutine 泄露。 概述 Go 的并发模型与其他语言不同&#xff0c;虽说它…

蓝天采集器,功能逆天的网站数据抓取神器,轻松助你成为采集达人,附带搭建配置文档

源码介绍 蓝天采集器是一款专为web服务器打造的数据采集神器。与市面上常见的桌面端采集工具&#xff08;如火车头等&#xff09;相比&#xff0c;蓝天采集器在易用性、上手成本和灵活性方面更胜一筹。它部署简便&#xff0c;无需复杂的设置&#xff0c;即可迅速融入您的web服…

详解IP安全:IPSec协议簇 | AH协议 | ESP协议 | IKE协议_ipsec esp

目录 IP安全概述 IPSec协议簇 IPSec的实现方式 AH&#xff08;Authentication Header&#xff0c;认证头&#xff09; ESP&#xff08;Encapsulating Security Payload&#xff0c;封装安全载荷&#xff09; IKE&#xff08;Internet Key Exchange&#xff0c;因特网密钥…

storm统计服务开启zookeeper、kafka 、Storm(sasl认证)

部署storm统计服务开启zookeeper、kafka 、Storm&#xff08;sasl认证&#xff09; 当前测试验证结果&#xff1a; 单独配置zookeeper 支持acl 设置用户和密码&#xff0c;在storm不修改代码情况下和kafka支持当kafka 开启ACL时&#xff0c;storm 和ccod模块不清楚配置用户和密…

2018年认证杯SPSSPRO杯数学建模A题(第二阶段)海豚与沙丁鱼全过程文档及程序

2018年认证杯SPSSPRO杯数学建模 基于聚类分析的海豚捕食合作策略 A题 海豚与沙丁鱼 原题再现&#xff1a; 沙丁鱼以聚成大群的方式来对抗海豚的捕食。由于水下光线很暗&#xff0c;所以在距离较远时&#xff0c;海豚只能使用回声定位方法来判断鱼群的整体位置&#xff0c;难…

第4章 C++的类

类的保留字&#xff1a;class、struct 或 union 可用来声明和定义类。类的声明由保留字class、struct或union加上类的名称构成。类的定义包括类的声明部分和类的由{}括起来的主体两部分构成。类的实现通常指类的函数成员的实现&#xff0c;即定义类的函数成员。 class 类名;//…

C#,字符串匹配(模式搜索)Sunday算法的源代码

Sunday算法是Daniel M.Sunday于1990年提出的一种字符串模式匹配算法。 核心思想&#xff1a;在匹配过程中&#xff0c;模式串并不被要求一定要按从左向右进行比较还是从右向左进行比较&#xff0c;它在发现不匹配时&#xff0c;算法能跳过尽可能多的字符以进行下一步的匹配&…

港科夜闻|香港科大团队研发多功能,可重构和抗破坏单线感测器阵列

关注并星标 每周阅读港科夜闻 建立新视野 开启新思维 1、香港科大团队研发多功能、可重构和抗破坏单线感测器阵列。研究人员开发出一种受人类听觉系统启发的感测器阵列设计技术。透过模仿人耳根据音位分布来区分声音的能力&#xff0c;这种新型感测器阵列方法可能优化感测器阵列…

Yolov8_使用自定义数据集训练模型1

前面几篇文章介绍了如何搭建Yolov8环境、使用默认的模型训练和推理图片及视频的效果、并使用GPU版本的torch加速推理、导出.engine格式的模型进一步利用GPU加速&#xff0c;本篇介绍如何自定义数据集&#xff0c;这样就可以训练出识别特定物体的模型。 《Yolov8_使用自定义数据…

innoDB存储引擎

1.逻辑存储结构 行数据->行->页->区->段->表空间 表空间(ibd文件)&#xff0c;一个mysql实例可以对应多个表空间&#xff0c;来存储记录&#xff0c;索引等数据。 段&#xff1a;分为数据段和索引段&#xff0c;回滚段&#xff0c;数据段就是B树的叶子节点&am…

HR3D+HRAuido+HRUI+HR3D_Plugins(游戏引擎源码)

国内知名游戏公司开发的游戏引擎&#xff0c;简洁高效&#xff0c;代码值得参考。包含了这几部分&#xff1a;HR3DHRAuidoHRUIHR3D_Plugins HR3DHRAuidoHRUIHR3D_Plugins&#xff08;游戏引擎源码&#xff09; 下载地址&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1…

使用xbindkeys设置鼠标侧键

1.安装如下包 sudo apt install xbindkeys xautomation 2.生成配置文件 xbindkeys --defaults > $HOME/.xbindkeysrc 3.确定侧键键号 在终端执行下面的代码&#xff1a; xev | grep button 此时会出现如下窗口&#xff0c;将鼠标指针移动到这个窗口上&#xff1a; 单…

【机器学习】调配师:咖啡的完美预测

有一天&#xff0c;小明带着一脸期待找到了你这位数据分析大师。他掏出手机&#xff0c;屏幕上展示着一份详尽的Excel表格。“看&#xff0c;这是我咖啡店过去一年的数据。”他滑动着屏幕&#xff0c;“每个月的销售量、广告投入&#xff0c;还有当月的气温&#xff0c;我都记录…

【MYSQL】事务隔离级别

脏读、幻读、不可重复读 脏读 一个事务正在对一条记录做修改&#xff0c;在这个事务完成并提交前&#xff0c;另一个事务也来读取同一条记录&#xff0c;读取了这些未提交的“脏”数据&#xff0c;并据此做进一步的处理&#xff0c;就会产生未提交的数据依赖关系。这种现象被形…

【控制篇 / 分流】(7.4) ❀ 01. 对指定IP网段访问进行分流 ❀ FortiGate 防火墙

【简介】公司有两条宽带&#xff0c;一条ADSL拨号用来上网&#xff0c;一条移动SDWAN&#xff0c;已经连通总部内网服务器&#xff0c;领导要求&#xff0c;只有访问公司服务器IP时走移动SDWAN&#xff0c;其它访问都走ADSL拨号&#xff0c;如果你是管理员&#xff0c;你知道有…

自定义 React Hooks:编写高效、整洁和可重用代码的秘密武器

欢迎来到神奇的 React 世界 大家好!在 React 的世界中,有一个强大的秘密武器,它往往隐藏在显而易见的地方,由于缺乏理解或熟悉而没有得到充分利用。 这个强大的工具,被称为自定义 React hooks,可以彻底改变我们编写 React 应用程序代码的方式。通过提取组件中的有状态逻辑,自…

查找局域网树莓派raspberry的mac地址和ip

依赖python库&#xff1a; pip install socket pip install scapy运行代码&#xff1a; import socket from scapy.layers.l2 import ARP, Ether, srpdef get_hostname(ip_address):try:return socket.gethostbyaddr(ip_address)[0]except socket.herror:# 未能解析主机名ret…