YOLOv6-4.0部分代码阅读笔记-yolo.py

yolo.py

yolov6\models\yolo.py

目录

yolo.py

1.所需的库和模块

2.class Model(nn.Module): 

3.def make_divisible(x, divisor): 

4.def build_network(config, channels, num_classes, num_layers, fuse_ab=False, distill_ns=False): 

5.def build_model(cfg, num_classes, device, fuse_ab=False, distill_ns=False): 


1.所需的库和模块

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from yolov6.layers.common import *
from yolov6.utils.torch_utils import initialize_weights
from yolov6.models.efficientrep import *
from yolov6.models.reppan import *
from yolov6.utils.events import LOGGER

2.class Model(nn.Module): 

class Model(nn.Module):# 具有主干、颈部和头部的 YOLOv6 模型。# 默认部件是 EfficientRep Backbone , Rep-PAN 和 Efficient Decoupled Head 。'''YOLOv6 model with backbone, neck and head.The default parts are EfficientRep Backbone, Rep-PAN andEfficient Decoupled Head.'''# 1.config :模型配置对象,包含了模型结构和训练相关的参数。# 2.channels :输入图像的通道数,默认为3(彩色图像)。# 3.num_classes :目标检测任务中的类别数。# 4.fuse_ab :是否融合 A 和 B 特征层。# 5.distill_ns :是否进行蒸馏。def __init__(self, config, channels=3, num_classes=None, fuse_ab=False, distill_ns=False):  # model, input channels, number of classes#  调用父类的初始化方法,调用 nn.Module 的初始化方法,这是 PyTorch 中定义模型类时的标准做法。super().__init__()# Build network    建立网络# 从配置中获取头部网络的层数。num_layers = config.model.head.num_layers# 调用 build_network 函数构建模型的三个部分,并将返回的 骨干网络 、 颈部网络 和 头部网络 分别赋值给 self.backbone 、 self.neck 和 self.detect 。self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, num_layers, fuse_ab=fuse_ab, distill_ns=distill_ns)# Init Detect head    初始检测头# 从头部网络(detect)获取模型的步长(stride),步长是目标检测中的一个重要参数,它决定了检测的精度和召回率。self.stride = self.detect.stride# 对检测头的偏置进行初始化,这通常是根据特定的初始化策略来设置偏置值,以帮助模型在训练初期更快地收敛。self.detect.initialize_biases()# Init weights    初始权重# 对整个模型的权重进行初始化,这通常包括对卷积层、全连接层等的权重进行初始化,以确保模型在训练开始时有一个良好的起点。initialize_weights(self)def forward(self, x):export_mode = torch.onnx.is_in_onnx_export()x = self.backbone(x)x = self.neck(x)if export_mode == False:featmaps = []featmaps.extend(x)x = self.detect(x)return x if export_mode is True else [x, featmaps]# 这段代码定义了 Model 类的 _apply 方法,这个方法在 PyTorch 中用于对模型中的参数和缓冲区应用一个函数 fn 。这通常用于模型的变换,比如参数的修改或者参数的复制等。def _apply(self, fn):# 首先调用 nn.Module 的 _apply 方法,这个方法会对模型中的所有参数和缓存应用函数 fn 。这里的 self 被重新赋值为 _apply 方法的返回值,这意味着模型的参数已经被 fn 函数处理过了。self = super()._apply(fn)# 对检测头中的 stride 属性应用函数 fn 。 stride 是一个重要的参数,它决定了模型输出的空间分辨率。通过 fn 函数,可以对 stride 进行调整或变换。self.detect.stride = fn(self.detect.stride)# 对检测头中的 grid 属性应用函数 fn 。 grid 属性通常是一个列表或者张量,包含了在不同尺度上的目标位置信息。通过 map 函数, fn 被应用到 grid 的每一个元素上,然后返回一个新的列表。self.detect.grid = list(map(fn, self.detect.grid))# 返回经过参数变换后的模型实例。return self

3.def make_divisible(x, divisor): 

# 将给定的数值 x 调整为最接近的、可以被 divisor 整除的数值。
# 1.x :需要调整的原始数值。
# 2.divisor :用于调整 x 的除数,即 x 需要被这个数整除。
def make_divisible(x, divisor):# 向上修正 x 的值,使其可以被除数整除。# Upward revision the value x to make it evenly divisible by the divisor.# math.ceil(x)# 方法将 x 向上舍入到最接近的整数。# x / divisor :首先计算 x 除以 divisor 的结果,得到一个浮点数。# math.ceil(x / divisor) :使用 math.ceil 函数对上述结果进行向上取整,即取大于或等于该浮点数的最小整数。# math.ceil(x / divisor) * divisor :将向上取整后的结果乘以 divisor ,得到一个可以被 divisor 整除的数值。return math.ceil(x / divisor) * divisor

4.def build_network(config, channels, num_classes, num_layers, fuse_ab=False, distill_ns=False): 

def build_network(config, channels, num_classes, num_layers, fuse_ab=False, distill_ns=False):depth_mul = config.model.depth_multiple    # 控制网络结构深度的缩放因子width_mul = config.model.width_multiple    # 控制网络结构宽度的缩放因子num_repeat_backbone = config.model.backbone.num_repeats    # 主干网络每个stage中基础模块的重复个数channels_list_backbone = config.model.backbone.out_channels    # 主干网络每个stage中输出的通道数fuse_P2 = config.model.backbone.get('fuse_P2')    # 是否融合骨干网络P2层特征cspsppf = config.model.backbone.get('cspsppf')    # 是否使用CSPSPPF模块以替换SPPF模块num_repeat_neck = config.model.neck.num_repeats    # Neck网络连接每个特征层基础模块的重复个数channels_list_neck = config.model.neck.out_channels    # Neck网络连接每个特征层的上采样/下采样模块通道数use_dfl = config.model.head.use_dfl    # 是否使用 distributed focal loss ,若后续想继续蒸馏,需要设为True以保留DFL分支reg_max = config.model.head.reg_max    # 若 use_dfl 为 False, 则 reg_max 设为 0;若 use_dfl 为 True, 则reg_max 设为 16num_repeat = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in (num_repeat_backbone + num_repeat_neck)]# def make_divisible(x, divisor): -> 向上修正 x 的值,使其可以被除数整除。 -> return math.ceil(x / divisor) * divisor# 使用 make_divisible 函数调整骨干网络和颈部网络的通道数,使其可以被 8 整除。channels_list = [make_divisible(i * width_mul, 8) for i in (channels_list_backbone + channels_list_neck)]# 根据训练模式获取相应的网络模块。block = get_block(config.training_mode)# 根据配置动态加载骨干网络的类。BACKBONE = eval(config.model.backbone.type)    # 网络类型# 根据配置动态加载颈部网络的类。NECK = eval(config.model.neck.type)    # 检测器 Neck 的类别,目前可选用'RepPANNeck', 'CSPRepPANNeck','RepBiFPANNeck','CSPRepBiFPANNeck','RepBiFPANNeck6','CSPRepBiFPANNeck_P6' 6种if 'CSP' in config.model.backbone.type:    # 主干网络的类别,目前可支持'EfficientRep', 'CSPBepBackbone','EfficientRep6','CSPBepBackbone_P6' 4种if "stage_block_type" in config.model.backbone:stage_block_type = config.model.backbone.stage_block_typeelse:stage_block_type = "BepC3"  #default    默认backbone = BACKBONE(in_channels=channels,channels_list=channels_list,num_repeats=num_repeat,block=block,csp_e=config.model.backbone.csp_e,fuse_P2=fuse_P2,cspsppf=cspsppf,stage_block_type=stage_block_type)neck = NECK(channels_list=channels_list,num_repeats=num_repeat,block=block,csp_e=config.model.neck.csp_e,stage_block_type=stage_block_type)else:backbone = BACKBONE(in_channels=channels,channels_list=channels_list,num_repeats=num_repeat,block=block,fuse_P2=fuse_P2,cspsppf=cspsppf)neck = NECK(channels_list=channels_list,num_repeats=num_repeat,block=block)# 它处理了在模型中启用蒸馏(distillation)时头部网络的构建。if distill_ns:# 如果 distill_ns 参数为 True ,则从 yolov6.models.heads.effidehead_distill_ns 模块导入 Detect 类和 build_effidehead_layer 函数。这些是专门用于蒸馏的头部网络组件。from yolov6.models.heads.effidehead_distill_ns import Detect, build_effidehead_layerif num_layers != 3:# 如果 num_layers 不等于 3,使用 LOGGER.error 记录错误信息,并退出程序。这是因为在蒸馏模式下,模型的头部网络层数需要与特定的配置相匹配,这里假设蒸馏模型的头部网络层数固定为 3。# 错误:蒸馏模式不适合带有 P6 头的 n/s 型号。LOGGER.error('ERROR in: Distill mode not fit on n/s models with P6 head.\n')# exit()# 结束整个程序。# 在python中运行一段代码,如果在某处已经完成整次任务,可以用 exit 退出整个运行。并且还可以在 exit() 的括号里加入自己退出程序打印说明。exit()# def build_effidehead_layer(channels_list, num_anchors, num_classes, reg_max=16): -> return head_layers# 调用 build_effidehead_layer 函数来构建头部网络的层结构。它根据输入的通道数列表 channels_list 、锚框数量(这里为 1)、类别数 num_classes 和 reg_max 来构建头部网络层。head_layers = build_effidehead_layer(channels_list, 1, num_classes, reg_max=reg_max)# class Detect(nn.Module): -> 高效分离头,实现免费蒸馏。(适用于纳米/小型型号)。 -> def __init__(self, num_classes=80, num_layers=3, inplace=True, head_layers=None, use_dfl=True, reg_max=16):  # detection layer# 使用 Detect 类创建头部网络实例。这里传入类别数 num_classes 、头部网络层数 num_layers 、头部网络层结构 head_layers 和是否使用分布式焦点损失的标志 use_dfl 。head = Detect(num_classes, num_layers, head_layers=head_layers, use_dfl=use_dfl)# 这段代码处理了在模型中启用特征融合(fuse A and B features)时头部网络的构建。elif fuse_ab:# 如果 fuse_ab 参数为 True ,则从 yolov6.models.heads.effidehead_fuseab 模块导入 Detect 类和 build_effidehead_layer 函数。这些是专门用于融合特征的头部网络组件。from yolov6.models.heads.effidehead_fuseab import Detect, build_effidehead_layer# 从配置中获取初始化锚点(anchors),这些锚点用于目标检测中的目标框定位。anchors_init = config.model.head.anchors_inithead_layers = build_effidehead_layer(channels_list, 3, num_classes, reg_max=reg_max, num_layers=num_layers)# class Detect(nn.Module): -> 高效的分离头,用于融合锚框分支。 -> def __init__(self, num_classes=80, anchors=None, num_layers=3, inplace=True, head_layers=None, use_dfl=True, reg_max=16):  # detection layerhead = Detect(num_classes, anchors_init, num_layers, head_layers=head_layers, use_dfl=use_dfl)else:from yolov6.models.effidehead import Detect, build_effidehead_layerhead_layers = build_effidehead_layer(channels_list, 1, num_classes, reg_max=reg_max, num_layers=num_layers)# class Detect(nn.Module): -> 高效分离头。利用硬件感知设计,使用混合通道方法对解耦头进行优化。 -> def __init__(self, num_classes=80, num_layers=3, inplace=True, head_layers=None, use_dfl=True, reg_max=16):  # detection layer    检测层head = Detect(num_classes, num_layers, head_layers=head_layers, use_dfl=use_dfl)return backbone, neck, head

5.def build_model(cfg, num_classes, device, fuse_ab=False, distill_ns=False): 

def build_model(cfg, num_classes, device, fuse_ab=False, distill_ns=False):# class Model(nn.Module): -> 具有主干、颈部和头部的 YOLOv6 模型。默认部件是 EfficientRep Backbone , Rep-PAN 和 Efficient Decoupled Head 。# -> def __init__(self, config, channels=3, num_classes=None, fuse_ab=False, distill_ns=False):  # model, input channels, number of classesmodel = Model(cfg, channels=3, num_classes=num_classes, fuse_ab=fuse_ab, distill_ns=distill_ns).to(device)return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/884424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Serverless + AI 让应用开发更简单

本文整理自 2024 云栖大会,阿里云智能高级技术专家,史明伟演讲议题《Serverless AI 让应用开发更简单》 随着云计算和人工智能(AI)技术的飞速发展,企业对于高效、灵活且成本效益高的解决方案的需求日益增长。本文旨在…

conda迁移虚拟环境路径

方法一:使用软连接 ln -s ~/Anaconda3/envs /new/path/envs 方法二:修改~/.condarc文件 1.打开~/.condarc文件 #添加下面参数 envs_dirs: - /newpath/anaconda3/envs pkgs_dirs: - /newpath/anaconda3/pkgs 2. source ~/.bashrc 3.查看是否成功con…

从0开始学PHP面向对象内容之(类,对象,构造/析构函数)

上期我们讲了面向对象的一些基本信息&#xff0c;这期让我们详细的了解一下 一、面向对象—类 1、PHP类的定义语法&#xff1a; <?php class className {var $var1;var $var2 "constant string";function classfunc ($arg1, $arg2) {[..]}[..] } ?>2、解…

(八)JavaWeb后端开发——Tomcat

目录 1.Web服务器概念 2.tomcat 1.Web服务器概念 服务器&#xff1a;安装了服务器软件的计算机服务器软件&#xff1a;接收用户的请求&#xff0c;处理请求&#xff0c;做出响应web服务器软件&#xff1a;在web服务器软件中&#xff0c;可以部署web项目&#xff0c;让用户通…

CSS3新增边框属性(五)

1、新增边框属性 1.1 border-radius 设置盒子的圆角。 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><title>border-radios</title><style>div {height: 400px;width: 400px;border: 1px so…

【Linux系列】Linux 和 Unix 系统中的`set`命令与错误处理

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

代码随想录刷题学习日记

仅为个人记录复盘学习历程&#xff0c;解题思路来自代码随想录 代码随想录刷题笔记总结网址:代码随想录 递归函数什么时候需要返回值&#xff1f;什么时候不需要返回值&#xff1f; 如果需要搜索整棵二叉树且不用处理递归返回值&#xff0c;递归函数就不要返回值。 如果需要…

Nuxt.js 应用中的 nitro:config 事件钩子详解

title: Nuxt.js 应用中的 nitro:config 事件钩子详解 date: 2024/11/2 updated: 2024/11/2 author: cmdragon excerpt: nitro:config 是 Nuxt 3 中的一个生命周期钩子,允许开发者在初始化 Nitro 之前自定义 Nitro 的配置。Nitro 是 Nuxt 3 的服务器引擎,负责处理请求、渲…

[论文阅读]LOGAN: Membership Inference Attacks Against Generative Models

LOGAN: Membership Inference Attacks Against Generative Models https://arxiv.org/abs/1705.07663v4 Proceedings on Privacy Enhancing Technologies &#xff08;PoPETs&#xff09;&#xff0c;第 2019 卷&#xff0c;第 1 期。 这篇文章是17年的一篇文章&#xff0c;…

使用Vite构建现代化前端应用

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 使用Vite构建现代化前端应用 引言 Vite 简介 安装 Vite 创建项目 启动开发服务器 项目结构 配置 Vite 开发模式 生产构建 使用插…

Node.js:模块 包

Node.js&#xff1a;模块 & 包 模块module对象 包npm安装包配置文件镜像源 分类 模块 模块化是指解决一个复杂问题时&#xff0c;自顶向下逐层把系统划分成若干模块的过程。对于整个系统来说&#xff0c;模块是可组合、分解和更换的单元。 简单来说&#xff0c;就是把一个…

使用 Faster Whisper 和 Gradio 实现实时语音转文字

随着人工智能技术的进步&#xff0c;语音识别已经成为最热门的研究领域之一。如何实现高效、准确的实时语音转文字功能&#xff0c;是许多开发者关注的重点。本文将介绍如何使用 Faster Whisper 和 Gradio 这两个强大工具&#xff0c;快速构建一个实时语音转文字应用。 Faster…

【Arduino】一分钟快速在vs code 编译开发Arduino

下载Arduino 对于一些开发者来说&#xff0c;Arduino开发较为不方便&#xff0c;不管从代码的阅读性、开发效率等等方面&#xff0c;vs code都要优于Arduino IDE开发&#xff0c;而且vs code开发可以使用插件&#xff0c;比如一些AI代码插件&#xff0c;可以加快开发速率&#…

WPF+MVVM案例实战(十九)- 自定义字体图标按钮的封装与实现(EF类)

文章目录 1、案例效果1、按钮分类2、E类按钮功能实现与封装1.文件创建与代码实现2、样式引用与封装 3、F类按钮功能实现与封装1、文件创建与代码实现2、样式引用与封装 3、按钮案例演示1、页面实现与文件创建2、运行效果如下 4、源代码获取 1、案例效果 1、按钮分类 在WPF开发…

Java基本语法和基础数据类型——针对实习面试

目录 Java基本语法和基础数据类型标识符和关键字有什么区别&#xff1f;Java关键字有哪些&#xff1f;Java基本数据类型有哪些&#xff1f;什么是自动装箱和拆箱&#xff1f;自动装箱&#xff08;Autoboxing&#xff09;自动拆箱&#xff08;Unboxing&#xff09; 自动装箱和拆…

c# 值类型

目录 1、c#类型2、值类型2.1 结构体2.2 枚举 1、c#类型 类型&#xff08;Type&#xff09;又叫数据类型&#xff08;Data Type&#xff09;。 A data type is a homogeneous collection of values,effectively prensented,equipped with a set of operations which manipulate…

【压力测试】如何确定系统最大并发用户数?

一、明确测试目的与了解需求 明确测试目的&#xff1a;首先需要明确测试的目的&#xff0c;即为什么要确定系统的最大并发用户数。这通常与业务需求、系统预期的最大用户负载以及系统的稳定性要求相关。 了解业务需求&#xff1a;深入了解系统的业务特性&#xff0c;包括用户行…

怎么在哔哩哔哩保存完整视频

哔哩哔哩(B站)作为一个集视频分享、弹幕互动于一体的平台&#xff0c;吸引了大量用户。许多人希望能够将自己喜欢的完整视频保存到本地&#xff0c;以便离线观看或分享。直接下载视频的功能并不总是可用&#xff0c;因此&#xff0c;本文将介绍几种在哔哩哔哩上保存完整视频的方…

【玉米叶部病害识别】Python+深度学习+人工智能+图像识别+CNN卷积神经网络算法+TensorFlow

一、介绍 玉米病害识别系统&#xff0c;本系统使用Python作为主要开发语言&#xff0c;通过收集了8种常见的玉米叶部病害图片数据集&#xff08;‘矮花叶病’, ‘健康’, ‘灰斑病一般’, ‘灰斑病严重’, ‘锈病一般’, ‘锈病严重’, ‘叶斑病一般’, ‘叶斑病严重’&#x…

PAT甲级-1048 Find Coins

题目 题目大意 给出硬币的个数n和要付费的钱m&#xff0c;接下来给出每个硬币的面值。要求从这些硬币中找到两个硬币v1, v2&#xff0c;使得v1 v2 m&#xff0c;且v1 < v2&#xff0c;输出v1 v2。如果不能找到这两个硬币&#xff0c;输出No Solution。 思路 刚开始用的…