【目标检测】yolov8结构及代码分析

yolov8代码:https://github.com/ultralytics/ultralytics

yolov8的整体结构如下图(来自mmyolo):

yolov8的配置文件:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

可以看出,主要包含Conv,C2f,SPPF,Concat,Detect几个模块。

一、Conv

Conv模块包含卷积层、BN层和激活函数层。

代码如下:

class Conv(nn.Module):"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):"""Initialize Conv layer with given arguments including activation."""super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):"""Apply convolution, batch normalization and activation to input tensor."""return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):"""Perform transposed convolution of 2D data."""return self.act(self.conv(x))

二、C2f

C2f就是模型结构图中的CSPLayer_2Conv,包含多个DarkNetBottleNeck。代码如下:

class C2f(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))class Bottleneck(nn.Module):"""Standard bottleneck."""def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):"""Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, andexpansion."""super().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, k[0], 1)self.cv2 = Conv(c_, c2, k[1], 1, g=g)self.add = shortcut and c1 == c2def forward(self, x):"""'forward()' applies the YOLO FPN to input data."""return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

三、SPPF

yolov8的SPPF实现和mmyolo的结构图有点区别,只有一个池化层代码如下:

class SPPF(nn.Module):"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""def __init__(self, c1, c2, k=5):"""Initializes the SPPF layer with given input/output channels and kernel size.This module is equivalent to SPP(k=(5, 9, 13))."""super().__init__()c_ = c1 // 2  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c_ * 4, c2, 1, 1)self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)def forward(self, x):"""Forward pass through Ghost Convolution block."""x = self.cv1(x)y1 = self.m(x)y2 = self.m(y1)return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))

四、Concat

和torch.cat功能几乎一致:

class Concat(nn.Module):"""Concatenate a list of tensors along dimension."""def __init__(self, dimension=1):"""Concatenates a list of tensors along a specified dimension."""super().__init__()self.d = dimensiondef forward(self, x):"""Forward pass for the YOLOv8 mask Proto module."""return torch.cat(x, self.d)

五、Detect

yolov8的检测头:

class Detect(nn.Module):"""YOLOv8 Detect head for detection models."""dynamic = False  # force grid reconstructionexport = False  # export modeshape = Noneanchors = torch.empty(0)  # initstrides = torch.empty(0)  # initdef __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc  # number of classesself.nl = len(ch)  # number of detection layersself.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4  # number of outputs per anchorself.stride = torch.zeros(self.nl)  # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""shape = x[0].shape  # BCHWfor i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapex_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.stridesif self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_sizey = torch.cat((dbox, cls.sigmoid()), 1)return y if self.export else (y, x)def bias_init(self):"""Initialize Detect() biases, WARNING: requires stride availability."""m = self  # self.model[-1]  # Detect() module# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequencyfor a, b, s in zip(m.cv2, m.cv3, m.stride):  # froma[-1].bias.data[:] = 1.0  # boxb[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/585386.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

偏好对齐RLHF-OpenAI·DeepMind·Anthropic对比分析

OpenAI paper: InstructGPT, Training language models to follow instructions with human feedback paper: Learning to summarize from human feedback Introducing ChatGPT 解密Prompt系列4介绍了InstructGPT指令微调的部分,这里只看偏好对齐的部分 样本构建…

Zookeeper-Zookeeper特性与节点数据类型详解

1.Zookeeper介绍 ZooKeeper 是一个开源的分布式协调框架,是Apache Hadoop 的一个子项目,主要用来解决分布式集群中应用系统的一致性问题。Zookeeper 的设计目标是将那些复杂目容易出错的分布式一致性服务封装起来,构成一高效可靠的原语集&…

Python列表的介绍与操作 增改查,连接,赋值,复制,清空

列表 在日常中我们通过给变量赋值来存储数据,比如 a "hello" b "world" c "你好啊" d "....."由于变量一次只能存储一个数据,但我们如果想一次存储多个数据,的话这样存储会很复杂,所以,我们可以通过列表 列表(List)是Python中的…

Linux驱动开发学习笔记6《蜂鸣器实验》

目录 一、蜂鸣器驱动原理 二、硬件原理分析 三、实验程序编写 1、 修改设备树文件 (1)添加pinctrl节点 (2)添加BEEP设备节点 (3)检查PIN 是否被其他外设使用 2、蜂鸣器驱动程序编写 3、编写测试AP…

基于js和html的骰子游戏

介绍: 1.游戏者选择“大”时,三个骰子点数之和为11-18时,游戏者获胜。2.游戏者选择“小”时,三个骰子点数之和为3-10时,游戏者获胜。3.如果游戏者选择具体点数,则根据三个骰子的点数计算,如果与…

Linux 安装Jupyter notebook 并开启远程访问

文章目录 安装Python安装pip安装Jupyter启动Jupyter Notebook1. 生成配置文件2. 创建密码3. 修改jupyter notebook的配置文件4. 启动jupyter notebook5. 远程访问jupyter notebook 安装Python 确保你的系统上已经安装了Python。大多数Linux发行版都预装了Python。你可以在终端…

【深度学习-图像分类】02 - AlexNet 论文学习与总结

论文地址:ImageNet Classification with Deep Convolutional Neural Networks 论文学习 1. 摘要 本研究训练了一个大型深度卷积神经网络(CNN),用于对ImageNet LSVRC-2010比赛中的1.2百万高分辨率图像进行分类,这些图…

MYSQL 深入探索系列六 SQL执行计划

概述 好久不见了,近期一直在忙项目的事,才有时间写博客,近期频繁出现sql问题,今天正好不忙咱们看看千万级别的表到底该如何优化sql。 案例 近期有个小伙伴生产环境收到了告警,有个6千万的日志表,查询耗时大…

“undefined reference to XXX“问题总结

"undefined reference to XXX"问题总结 引言 我们在Linux下用C/C工作的时候,经常会遇到"undefined reference to XXX"的问题,直白地说就是在链接(从.cpp源代码到可执行的ELF文件,要经过预处理->编译->链接三个阶…

1panel使用指南(一)面板安装

一、1panel简介 1Panel是杭州飞致云信息科技有限公司推出的产品 [1],帮助用户实现快速建站。 [2]是一款现代化、开源的Linux服务器运维管理面板,于2023年3月推出,深度集成WordPress和Halo,一键完成域名绑定、SSL证书配置等操作&a…

draw.io学习笔记

1、链接 1.1、自动连接图形 鼠标放在图形上,点击出现的箭头,会自动出常用图形 1.2、固定连接 如果拖动其中一个图形的话,固定链接的形状会是曲线连过去。 方法:不要点击左边图形鼠标放在边框上面左边出现绿圆点鼠标左键点击图形的…

【Linux】深度解剖环境变量

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:熟悉并掌握Linux的环境变量。 > 毒鸡汤&#x…

LeetCode刷题--- 单词搜索

个人主页:元清加油_【C】,【C语言】,【数据结构与算法】-CSDN博客 个人专栏 力扣递归算法题 http://t.csdnimg.cn/yUl2I 【C】 ​​​​​​http://t.csdnimg.cn/6AbpV 数据结构与算法 ​​​​http://t.csdnimg.cn/hKh2l 前言:这个专栏主要讲述…

LLM之RAG实战(十一)| 使用Mistral-7B和Langchain搭建基于PDF文件的聊天机器人

在本文中,使用LangChain、HuggingFaceEmbeddings和HuggingFace的Mistral-7B LLM创建一个简单的Python程序,可以从任何pdf文件中回答问题。 一、LangChain简介 LangChain是一个在语言模型之上开发上下文感知应用程序的框架。LangChain使用带prompt和few-…

CENTOS docker拉取私服镜像

概述 docker的应用越来越多,安装部署越来越方便,批量自动化的镜像生成和发布都需要docker镜像的拉取。 centos6版本太老,docker的使用过程中问题较多,centos7相对简单容易。 本文档主要介绍centos系统安装docker和拉取docker私…

一文了解无线通信 - NB-IOT、LoRa

NB-IOT、LoRa 目录概述需求: 设计思路实现思路分析 NB-IOT1.LoRa2.区别 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for change,chall…

简单了解SQL堆叠注入与二次注入(基于sqllabs演示)

1、堆叠注入 使用分号 ; 成堆的执行sql语句 以sqllabs-less-38为例 ?id1 简单测试发现闭合点为单引号 ?id1 order by 3 ?id1 order by 4使用order by探测发现只有三列(字段数) 尝试简单的联合注入查询 ?id-1 union select 1,database(),user()-…

开放网络+私有云=?星融元的私有云承载网络解决方案实例

在全世界范围内的云服务市场上,开放网络一直是一个备受关注的话题。相比于传统供应商的网络设备,开放网络具备软硬件解耦、云原生、可选组件丰富等优势,对云服务商和超大型企业有足够的吸引力。 SONiC作为开源的网络操作系统,使得…

uni-app uni-app内置组件

锋哥原创的uni-app视频教程: 2023版uniapp从入门到上天视频教程(Java后端无废话版),火爆更新中..._哔哩哔哩_bilibili2023版uniapp从入门到上天视频教程(Java后端无废话版),火爆更新中...共计23条视频,包括:第1讲 uni…

【JS笔记】JavaScript语法 《基础+重点》 知识内容,快速上手(二)

数组 什么是数组? 字面理解就是 数字的组合 其实不太准确,准确的来说数组是一个 数据的集合 也就是我们把一些数据放在一个盒子里面,按照顺序排好 [1, 2, 3, hello, true, false]这个东西就是一个数组,存储着一些数据的集合 …