使用官方代码打印yolov8 PyTorch模型结构

理解模型结构的重要性

在进行深度学习模型的开发时，一个清晰的模型结构有助于理解网络是如何从输入数据中提取特征，并执行分类或回归任务的。对于YOLOv8这样的复杂模型来说，理解每个层的作用和相互间的连接尤为重要。

下面是我整合的代码:

import contextlib
import glob
import math
import re
import urllib
from copy import deepcopy
from pathlib import Path
import yaml
import torch
from torch import nn
from ultralytics.utils.tal import dist2bbox, make_anchors
from ultralytics.utils import downloads
from ultralytics.nn.modules import (
    RTDETRDecoder, Segment, Pose, OBB, Concat, ResNetLayer, HGBlock, HGStem, AIFI,
    BottleneckCSP, C1, C2, C3, C3Ghost, C3x, RepC3, C3TR, C2f, DWConvTranspose2d,
    Focus, Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck,
    SPP, SPPF, DWConv, DFL,
)

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels.

        Args:
            nc (int): Number of classes.
            ch (tuple): Input channel count of each feature map fed to the head; must be non-empty
                because ``ch[0]`` is used to size the hidden channels.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities.

        Args:
            x (list): Tensors from previous layers, one per detection layer. The list is modified
                in place: each entry is replaced by the concatenated cv2/cv3 output.

        Returns:
            The list ``x`` when ``self.training`` is set. Otherwise ``y`` when ``self.export`` is
            set, else the tuple ``(y, x)``, where ``y`` is the decoded boxes concatenated with the
            sigmoid of the class logits along dim 1.
        """
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
            # Report the per-level head output shape; the "+ 28" offset is only a display label.
            print(f'Layer {i + 28} output shape: {x[i].shape}')
        if self.training:  # Training path
            return x

        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            # Anchors must be (re)built before decoding; otherwise decode_bboxes() would run
            # against the empty class-level `anchors` / `strides` placeholders.
            self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        # NOTE(review): `self.format` is not defined in this class; it is only read when
        # `self.export` is True, so presumably an exporter sets it - confirm before exporting.
        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):
            # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = self.decode_bboxes(box)

        if self.export and self.format in ("tflite", "edgetpu"):
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            img_h = shape[2]
            img_w = shape[3]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * img_size)
            dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes):
        """Decode bounding boxes from DFL distances using the current anchors and strides."""
        return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides


def make_divisible(x, divisor):
    """Returns the nearest number that is divisible by the given divisor.

    Args:
        x (int): The number to make divisible.
        divisor (int | torch.Tensor): The divisor.

    Returns:
        (int): ``x`` rounded up to a multiple of ``divisor``.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def url2file(url):
    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
    return Path(clean_url(url)).name


def clean_url(url):
    """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
    # Imported explicitly: a bare `import urllib` does not guarantee `urllib.parse` is loaded.
    from urllib.parse import unquote

    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
    return unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth


def colorstr(*inputs):
    """Wrap a string in ANSI escape codes.

    The last positional argument is the string; any preceding arguments are color/style names.
    With a single argument the string is rendered blue and bold.

    Raises:
        KeyError: If a color/style name is not in the table below.
    """
    *args, string = inputs if len(inputs) > 1 else ("blue", "bold", inputs[0])  # color arguments, string
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]


def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """Check file(s) for acceptable suffix.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not in ``suffix``.
    """
    if file and suffix:
        if isinstance(suffix, str):
            suffix = (suffix,)
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower().strip()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                print(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_file(file, suffix="", download=True, hard=True):
    """Search/download file (if necessary) and return path.

    Returns:
        The input string when it is empty, exists locally or is a gRPC address; the local file
        name for a downloadable URL; otherwise the first glob match under ROOT, or ``[]`` when
        nothing matched and ``hard`` is False.

    Raises:
        FileNotFoundError: With ``hard=True``, when the search finds no match or several matches.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            print(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file


def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """Search/download YAML file (if necessary) and return path, checking suffix."""
    return check_file(file, suffix, hard=hard)


def yaml_load(file="data.yaml", append_filename=False):
    """Load YAML data from a file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        (dict): YAML data and file name.
    """
    assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
    with open(file, errors="ignore", encoding="utf-8") as f:
        s = f.read()  # string

        # Remove special characters
        if not s.isprintable():
            s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)

        # Add YAML filename to dict and return
        data = yaml.safe_load(s) or {}  # always return a dict (yaml.safe_load() may return None for empty files)
        if append_filename:
            data["yaml_file"] = str(file)
        return data


def guess_model_scale(model_path):
    """Extract the scale character (n, s, m, l, or x) from a YOLO model YAML file name.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The scale character, or an empty string when the file stem does not match.
    """
    # re.search() returns None on no match, so .group() raises AttributeError, which is suppressed.
    with contextlib.suppress(AttributeError):
        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
    return ""


def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file."""
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        print(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary with "backbone" and "head" layer lists; the nested argument
            lists are modified in place, so pass a copy if the dictionary is reused.
        ch (int): Number of input channels.
        verbose (bool): Print one table row per layer.

    Returns:
        (tuple): ``(nn.Sequential of all layers, sorted list of layer indices to save)``.
    """
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    # NOTE: these locals (e.g. nc, kpt_shape) can be referenced by name from YAML string
    # arguments through the locals() lookup in the loop below - do not remove "unused" ones.
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            print(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # SECURITY: eval() executes the YAML "activation" string as code; only load trusted configs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            print(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        print(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                # Strings that are neither local names nor Python literals are kept unchanged.
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
        ):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect, Segment, Pose, OBB):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params (stored on the module class)
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            print(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


cfg = "yolov8x-p2.yaml"
# Named `model_cfg` rather than `yaml` so the imported `yaml` module used by yaml_load() stays usable.
model_cfg = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
model, save = parse_model(deepcopy(model_cfg), ch=3, verbose=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/2719.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ORACLE创建表空间及用户

创建用户之前需要先为用户创建表空间。 创建临时表空间：Oracle 临时表空间用于存储数据库操作过程中的临时数据，例如排序、哈希操作或者大查询的中间结果。临时表空间主要用于保证用户会话中的操作不会影响到系统的稳定性。 #用管理员登入数据库。 sq…

React的Key和diff

React的Key 先说说React组件中的Key，在渲染一个列表的时候，都要求设置一个唯一的Key，不然就会提示：Each child in a list should have a unique "key" prop. 意思是列表的每一个子元素都需要设置一个唯一的key值。在开发中一般会以id作为key。比如 const …

Java的八大基本数据类型和 println 的介绍

前言 如果你有C语言的基础，这部分内容就会很简单，但是会有所不同~~ 这是我将要提到的八大基本数据类型： 注意，Java的数据类型是有符号的！！！和C语言不同，Java不存在无符号的数据。 整…

Day:动态规划 LeetCode 123.买卖股票的最佳时机III 188.买卖股票的最佳时机IV

123. 买卖股票的最佳时机 III 给定一个数组，它的第 i 个元素是一支给定的股票在第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。你最多可以完成 两笔 交易。 注意：你不能同时参与多笔交易（你必须在再次购买前出售掉之前的股票…

安全开发实战(2)---域名反查IP

目录 安全开发专栏 前言 域名与ip的关系 域名反查ip的作用 1.2.1 One 1.2.2 Two 1.2.3 批量监测 ​总结 安全开发专栏 安全开发实战http://t.csdnimg.cn/25N7H 这步是比较关键的一步,一般进行cdn监测后,获取到真实ip地址后,或是域名时,然后进行域名反查IP地址,进行进…

基于Springboot的职称评审管理系统

基于Springboot+Vue的职称评审管理系统的设计与实现 开发语言：Java 数据库：MySQL 技术：Springboot+Mybatis 工具：IDEA、Maven、Navicat 系统展示 用户登录 首页 评审条件 论坛信息 系统公告 后台登录页面 用户管理 评审员管理 省份…

再谈C语言——理解指针(四)

assert断言 assert.h 头文件定义了宏 assert()，用于在运行时确保程序符合指定条件，如果不符合，就报错终止运行。这个宏常常被称为“断言”。 assert(p != NULL); 上面代码在程序运行到这一行语句时，验证变量 p 是否等于 NULL 。…

​LeetCode解法汇总2385. 感染二叉树需要的总时间

目录链接： 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目： https://github.com/September26/java-algorithms 原题链接：. - 力扣（LeetCode） 描述： 给你一棵二叉树的根节点 root …

Linux-System V共享内存

目录 System V共享内存共享内存概述创建或打开共享内存使用共享内存分离共享内存控制共享内存1、IPC_STAT2、IPC_SET 代码示例 System V共享内存 共享内存概述 共享内存是所有IPC中最快的一种。它之所以快是因为共享内存一旦映射到进程的地址空间，进程之间数据的传…

创建型设计模式

七大原则 1. 开闭原则（Open-Closed Principle, OCP） 详解：软件实体（类、模块、函数等）应该易于扩展，但是不易于修改。换句话说，当软件需求变化时，应该通过添加新代码来实现变化…

CSP初赛知识精讲--排列组合

第十一节 排列组合 基础知识 排列是指从给定个数的元素中取出指定个数的元素进行排序。  组合是指从给定个数的元素中仅仅取出指定元素个数的元素，不考虑排序。  排列组合问题的关键就是研究给定要求的排列和组合可能出现的情况的总数。 定义与公式  排列…

销冠必备:高效跟进客户的四个技巧

作为一名销售，高效而精准地跟进客户是取得成功的关键。今天，我将分享四个技巧，让你也能够高效地跟进客户。 1、善于发问 通过多询问客户，你可以更好地了解客户的需求和痛点。在与客户交流时，不要只是简单地回答问题…

LeetCode 0216.组合总和 III:回溯(剪枝) OR 二进制枚举

【LetMeFly】216.组合总和 III：回溯(剪枝) OR 二进制枚举 力扣题目链接：https://leetcode.cn/problems/combination-sum-iii/ 找出所有相加之和为 n 的 k 个数的组合，且满足下列条件： 只使用数字1到9 每个数字 最多使用一次 返…

业务复习知识点Oracle查询

业务数据查询-1 单表查询 数据准备 自来水收费系统建表语句.sql 简单条件查询 精确查询 需求：查询水表编号为 30408 的业主记录 查询语句： select * from t_owners where watermeter = 30408; 查询结果： 模糊查询 需求：查询业…

毕业设计注意事项(2024届更新中)

1.开题 根据学院发的开题报告模板完成，其中大纲部分可参考资料 2.毕设 根据资料中的毕设评价标准，对照工作量 3.论文 3.1 格式问题 非常重要，认真对比资料中我发的模板，格式有问题，答辩输一半！ 以word…

W801学习笔记十四:掌机系统——菜单——尝试打造自己的UI

未来将会有诸多应用，这些应用将通过菜单进行有序组织和管理。因此，我们需要率先打造好菜单。 LCD 驱动通常是直接写屏的，虽然速度较快，但用于界面制作则不太适宜。所以，最好能拥有一套 UI 框架。如前所述，…

【linux】编译器使用

目录 1. gcc，g++ 编译器使用 a. 有关gcc的指令（g++同理） 2. .o 文件和库的链接方式 a. 链接方式 b. 动态库 和 静态库 优缺点对比 c. debug 版本 和 release 版本 1. gcc，g++ 编译器使用 a. 有关gcc的指令（g++同理…

设计模式-创建型-抽象工厂模式-Abstract Factory

UML类图 工厂接口类 public interface ProductFactory {Phone phoneProduct();//生产手机Router routerProduct();//生产路由器 } 小米工厂实现类 public class XiaomiFactoryImpl implements ProductFactory {Overridepublic Phone phoneProduct() {return new XiaomiPhone…

Node.js -- fs模块

文章目录 1. 写入文件1.1 写入文件1.2 同步和异步1.3 文件追加写入1.4 流式写入1.5 文件写入的场景 2. 读取文件2.1 异步和同步读取2.2 读取文件应用场景2.3 流式读取2.4 fs 练习 -- 文件复制 3. 文件重命名和移动4. 文件删除5. 文件夹操作5.1 创建文件夹5.2 读取文件夹5.3 删除…

crossover和wine哪个好 wine和crossover有什么本质区别 苹果电脑运行Windows crossover24

CrossOver是Wine的延伸产品，CrossOver可以简单地理解为类虚拟机，那么wine是什么，许多小伙伴可能有些一知半解。CrossOver和wine哪个好，wine和CrossOver有什么本质区别呢？下文将围绕着这两个问题展开。 一、CrossOve…