YOLO物体检测-系列教程6:YOLOV3源码解读4之 YOLO层

🎈🎈🎈YOLO 系列教程 总目录

上篇内容:

YOLOV3项目实战1之 整体介绍与数据处理

YOLOV3提出论文:《Yolov3: An incremental improvement》

6、yolo层

6.1 yolo层

class YOLOLayer(nn.Module):"""Detection layer"""def __init__(self, anchors, num_classes, img_dim=416):def compute_grid_offsets(self, grid_size, cuda=True):def forward(self, x, targets=None, img_dim=None):

6.2 构造函数

    def __init__(self, anchors, num_classes, img_dim=416):super(YOLOLayer, self).__init__()self.anchors = anchorsself.num_anchors = len(anchors)self.num_classes = num_classesself.ignore_thres = 0.5self.mse_loss = nn.MSELoss()self.bce_loss = nn.BCELoss()self.obj_scale = 1self.noobj_scale = 100self.metrics = {}self.img_dim = img_dimself.grid_size = 0  # grid size

6.3 偏移量计算

    def compute_grid_offsets(self, grid_size, cuda=True):self.grid_size = grid_sizeg = self.grid_sizeFloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensorself.stride = self.img_dim / self.grid_size# Calculate offsets for each gridself.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor)self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor)self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors])self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1))self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1))

6.4 前向传播

    def forward(self, x, targets=None, img_dim=None):# Tensors for cuda supportprint (x.shape)FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensorLongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensorByteTensor = torch.cuda.ByteTensor if x.is_cuda else torch.ByteTensorself.img_dim = img_dimnum_samples = x.size(0)grid_size = x.size(2)prediction = (x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous())print (prediction.shape)# Get outputsx = torch.sigmoid(prediction[..., 0])  # Center xy = torch.sigmoid(prediction[..., 1])  # Center yw = prediction[..., 2]  # Widthh = prediction[..., 3]  # Heightpred_conf = torch.sigmoid(prediction[..., 4])  # Confpred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.# If grid size does not match current we compute new offsetsif grid_size != self.grid_size:self.compute_grid_offsets(grid_size, cuda=x.is_cuda) #相对位置得到对应的绝对位置比如之前的位置是0.5,0.5变为 11.5,11.5这样的# Add offset and scale with anchors #特征图中的实际位置pred_boxes = FloatTensor(prediction[..., :4].shape)pred_boxes[..., 0] = x.data + self.grid_xpred_boxes[..., 1] = y.data + self.grid_ypred_boxes[..., 2] = torch.exp(w.data) * self.anchor_wpred_boxes[..., 3] = torch.exp(h.data) * self.anchor_houtput = torch.cat( (pred_boxes.view(num_samples, -1, 4) * self.stride, #还原到原始图中pred_conf.view(num_samples, -1, 1),pred_cls.view(num_samples, -1, self.num_classes),),-1,)if targets is None:return output, 0else:iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(pred_boxes=pred_boxes,pred_cls=pred_cls,target=targets,anchors=self.scaled_anchors,ignore_thres=self.ignore_thres,)# iou_scores:真实值与最匹配的anchor的IOU得分值 class_mask:分类正确的索引  obj_mask:目标框所在位置的最好anchor置为1 noobj_mask obj_mask那里置0,还有计算的iou大于阈值的也置0,其他都为1 tx, ty, tw, th, 对应的对于该大小的特征图的xywh目标值也就是我们需要拟合的值 tconf 目标置信度# Loss : Mask outputs to ignore non-existing objects (except with conf. loss)loss_x = self.mse_loss(x[obj_mask], tx[obj_mask]) # 只计算有目标的loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])loss_h = self.mse_loss(h[obj_mask], th[obj_mask])loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask]) loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj #有物体越接近1越好 没物体的越接近0越好loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask]) #分类损失total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls #总损失# Metricscls_acc = 100 * class_mask[obj_mask].mean()conf_obj = pred_conf[obj_mask].mean()conf_noobj = pred_conf[noobj_mask].mean()conf50 = (pred_conf > 0.5).float()iou50 = (iou_scores > 0.5).float()iou75 = (iou_scores > 0.75).float()detected_mask = conf50 * class_mask * tconfprecision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)self.metrics = {"loss": to_cpu(total_loss).item(),"x": to_cpu(loss_x).item(),"y": to_cpu(loss_y).item(),"w": to_cpu(loss_w).item(),"h": to_cpu(loss_h).item(),"conf": to_cpu(loss_conf).item(),"cls": to_cpu(loss_cls).item(),"cls_acc": to_cpu(cls_acc).item(),"recall50": to_cpu(recall50).item(),"recall75": to_cpu(recall75).item(),"precision": to_cpu(precision).item(),"conf_obj": to_cpu(conf_obj).item(),"conf_noobj": to_cpu(conf_noobj).item(),"grid_size": grid_size,}return output, total_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/80447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

竞赛 基于机器视觉的银行卡识别系统 - opencv python

1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的银行卡识别算法设计 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng…

Vue知识系列(7)每天10个小知识点

目录 系列文章目录Vue知识系列(1)每天10个小知识点Vue知识系列(2)每天10个小知识点Vue知识系列(3)每天10个小知识点Vue知识系列(4)每天10个小知识点Vue知识系列(5&#x…

Android 格式化存储之Formatter

格式化存储相关的数值时,可以用 android.text.format.Formatter 。 Formatter.formatFileSize(Context context, long sizeBytes) 源码说明,在 Android O 后,存储单位的进制是 1000 ,Android N 之前单位进制是 1024 。 /*** Fo…

【ARM Coresight 系列文章 21 -- SoC-400 介绍 】

文章目录 1. Coresight SoC-4001.1 DAP 组件1.2 SWJ-DP1.3 DAPBUS互联1.4 AXI-AP1.5 APB-AP2. 互联2.1 APB互联组件2.2 ATB互联组件2.2.1 replicator2.2.2 funnel2.2.3 upsizer2.2.4 downsizer2.2.5 asynchronous bridge2.2.6 synchronous bridge3. Timestamp 组件4. ECT组件&l…

C【动态内存管理】

1. 为什么存在动态内存分配 int val 20;//在栈空间上开辟四个字节 char arr[10] {0};//在栈空间上开辟10个字节的连续空间 2. 动态内存函数的介绍 2.1 malloc&#xff1a;stdlib.h void* malloc (size_t size); int* p (int*)malloc(40); #include <stdlib.h> #incl…

Web服务(Web Service)

简介 Web服务&#xff08;Web Service&#xff09;是一种Web应用开发技术&#xff0c;用XML描述、发布、发现Web服务。它可以跨平台、进行分布式部署。 Web服务包含了一套标准&#xff0c;例如SOAP、WSDL、UDDI&#xff0c;定义了应用程序如何在Web上实现互操作。 Web服务的服…

非对称加密系统和LINUX实践

对称加密和非对称加密 非对称加密: 非对称加密是一种加密技术,它使用一对密钥来进行数据的加密和解密,这一对密钥分别称为公钥(public key)和私钥(private key)。这两个密钥是数学相关的,并且彼此相关,但不能相互推导出来。 以下是非对称加密的基本工作原理: 公钥…

类与对象的创建

package com.mypackage.oop.later;//学生类 //类里面只存在属性和方法 public class Student {//属性&#xff1a;字段//在类里面方法外面定义一个属性&#xff08;或者说是变量&#xff09;&#xff0c;然后在方法里面对他进行不同的实例化String name; //会有一个默认值&…

Android studio 断点调试、日志断点

目录 参考文章参考文章1、运行调试2、调试操作3、断点类型行断点的使用场景属性断点的使用场景异常断点的使用场景方法断点的使用场景条件断点日志断点 4、断点管理区 参考文章 参考文章 1、运行调试 开启 Debug 调试模式有两种方式&#xff1a; Debug Run&#xff1a;直接…

windows下C++的反射功能

概述 c/c如果在日志中查看某个结构体/类的每个变量名&#xff0c;变量值信息&#xff0c;只能通过printf逐个格式化&#xff0c;非常繁琐&#xff0c;如何做到类似protobuff转json的序列化功能呢&#xff1f;该dll库先通过分析pdb文件获取结构体/类的变量名称、变量地址&#…

vue3将页面导出成PDF文件(完美解决图片、表格内容分割问题)

vue3将页面导出成PDF文件&#xff08;完美解决图片、表格内容分割问题&#xff09; 1、安装依赖2、在utils中创建htmlToPDF.js文件3、在vue中引入并使用 1、安装依赖 npm install --save html2canvas // 页面转图片 npm install jspdf --save // 图片转pdf2、在utils中创建h…

数据驱动成功:小程序积分商城的数据分析

在当今数字化时代&#xff0c;数据被认为是企业成功的关键。小程序积分商城是一种流行的营销工具&#xff0c;可帮助企业吸引和留住客户&#xff0c;并提供有关客户行为和偏好的宝贵数据。本文将深入探讨如何通过数据分析实现小程序积分商城的成功&#xff0c;包括数据的收集、…

Linux内核 6.6版本将遏制NVIDIA驱动的不正当行为

Linux 内核开发团队日前宣布&#xff0c;即将发布的 Linux 6.6 版本将增强内核模块机制&#xff0c;以更好地防御 NVIDIA 闭源驱动的不正当行为。 Linux 内核开发团队日前宣布&#xff0c;即将发布的 Linux 6.6 版本将增强内核模块机制&#xff0c;以更好地防御 NVIDIA 闭源驱…

linux shell操作- 02 常用命令及案例

文章目录 常用命令 续 常用命令 续 定时任务 通过文本编辑cron任务&#xff0c;实现定时操作 分 小时 天 月 星期 绝对路径sh or cmd* 表示每个xxx&#xff0c;如每个小时每小时的第三分钟执行cmd-> 03 * * * * /home/lauf/scraw.sh每天的第5、8个小时执行-> 00 5,8 * *…

Golang反射相关知识总结

1. Golang反射概述 Go语言的反射&#xff08;reflection&#xff09;是指在运行时动态地获取类型信息和操作对象的能力。在Go语言中&#xff0c;每个值都是一个接口类型&#xff0c;这个接口类型包含了这个值的类型信息和值的数据&#xff0c;因此&#xff0c;通过反射&#x…

C/C++—Inline关键词

1、引入 inline 关键字的原因 在 c/c 中&#xff0c;为了解决一些频繁调用的小函数大量消耗栈空间&#xff08;栈内存&#xff09;的问题&#xff0c;特别的引入了 inline 修饰符&#xff0c;表示为内联函数。 在系统下&#xff0c;栈空间是有限的&#xff0c;假如频繁大量的…

大二上学期学习计划

这个学期主要学习的技术有SpringBoot&#xff0c;Vue&#xff0c;MybatisPlus&#xff0c;redis&#xff0c;还有要坚持刷题&#xff0c;算法不能落下&#xff0c;要坚持一天至少刷2道题目&#xff0c;如果没有布置任务就刷洛谷上面的&#xff0c;有任务的话就尽量完成任务&…

win11 Windows hello录入指纹失败解决方法

刚换了xps&#xff0c;启用了administrator账号&#xff0c;win11专业版&#xff0c;发现使用Windows hello录入指纹时&#xff0c;只要一录指纹就立即出错 尝试卸载重装设备驱动--无效 把Windows update更新到最新--无效 最后查到&#xff0c;是Windows对administrator账户进…

在MuJoCo环境下详细实现PPO算法与Hopper-v2应用教程: 深度学习强化学习实战指南

第一部分:简介与MuJoCo环境的配置 1.简介 强化学习已经在许多任务中展现了其强大的能力,从简单的游戏到复杂的机器人控制。今天,我们将集中讨论PPO(Proximal Policy Optimization)算法,一个已经被证明在多种任务中具有卓越性能的强化学习算法。特别地,我们将在MuJoCo模…

【React】React入门

目录 一、何为React二、React与传统MVC的关系三、React的特性1、声明式编程①、实现标记地图 2、高效灵活3、组件式开发(Component)①、函数式组件②、类组件&#xff08;有状态组件&#xff09;③、一个组件该有的特点 4、单向式响应的数据流 四、虚拟DOM1、传统DOM更新①、举…