理解模型结构的重要性
在进行深度学习模型的开发时，一个清晰的模型结构有助于理解网络是如何从输入数据中提取特征、并执行分类或回归任务的。对于像 YOLOv8 这样的复杂模型来说，理解每一层的作用以及层与层之间的连接尤为重要。
下面是我整合后的代码：
import ast
import contextlib
import glob
import math
import re
import urllib
import urllib.parse
from copy import deepcopy
from pathlib import Path

import torch
import yaml
from torch import nn

from ultralytics.nn.modules import (
    AIFI,
    C1,
    C2,
    C3,
    C3TR,
    DFL,
    OBB,
    SPP,
    SPPF,
    Bottleneck,
    BottleneckCSP,
    C2f,
    C3Ghost,
    C3x,
    Classify,
    Concat,
    Conv,
    ConvTranspose,
    DWConv,
    DWConvTranspose2d,
    Focus,
    GhostBottleneck,
    GhostConv,
    HGBlock,
    HGStem,
    Pose,
    RepC3,
    ResNetLayer,
    RTDETRDecoder,
    Segment,
)
from ultralytics.utils import downloads
from ultralytics.utils.tal import dist2bbox, make_anchors

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models.

    Local copy of the head, instrumented to print the shape of each detection
    branch output during ``forward`` so the network structure can be inspected.
    """

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init
    print_shapes = True  # print each branch's output shape inside forward()

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels.

        Args:
            nc (int): Number of classes.
            ch (tuple[int] | list[int]): Channel count of each feature map fed to the head.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities.

        Args:
            x (list[torch.Tensor]): One feature map per detection layer; entries are
                replaced in place by the concatenated box/class branch outputs.

        Returns:
            The list ``x`` in training mode; the decoded tensor ``y`` in export mode;
            otherwise the tuple ``(y, x)``.
        """
        # parse_model() attaches the 'from' indices as `self.f`; use them to label
        # which model layer fed each detection branch instead of a hard-coded offset.
        sources = getattr(self, "f", None)
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
            if self.print_shapes:
                src = sources[i] if isinstance(sources, (list, tuple)) else "?"
                print(f"Detect branch {i} (input from layer {src}) output shape: {tuple(x[i].shape)}")
        if self.training:  # Training path
            return x

        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            # Rebuild the anchor grid when the input shape changes; without this,
            # decode_bboxes() would operate on the empty placeholder tensors.
            self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = self.decode_bboxes(box)

        if self.export and self.format in ("tflite", "edgetpu"):
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            img_h = shape[2]
            img_w = shape[3]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * img_size)
            dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes):
        """Decode DFL box distributions into xywh boxes scaled by the per-anchor strides."""
        return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides


def make_divisible(x, divisor):
    """Round ``x`` up to the nearest multiple of ``divisor``.

    Args:
        x (int | float): The number to make divisible.
        divisor (int | torch.Tensor): The divisor; a tensor is reduced to its max.

    Returns:
        (int): The smallest multiple of ``divisor`` that is >= ``x``.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def url2file(url):
    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
    return Path(clean_url(url)).name


def clean_url(url):
    """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
    return urllib.parse.unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth


def colorstr(*inputs):
    """Wrap the last argument in ANSI escape codes named by the preceding arguments.

    With a single argument the string is colored blue and bold,
    e.g. ``colorstr("blue", "bold", "hello")`` == ``colorstr("hello")``.
    """
    *args, string = inputs if len(inputs) > 1 else ("blue", "bold", inputs[0])  # color arguments, string
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]


def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """Check file(s) for acceptable suffix; raises AssertionError on mismatch."""
    if file and suffix:
        if isinstance(suffix, str):
            suffix = (suffix,)
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower().strip()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                print(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_file(file, suffix="", download=True, hard=True):
    """Search/download file (if necessary) and return path.

    Returns ``[]`` when nothing is found and ``hard`` is False.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            print(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file
    else:  # search
        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file


def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """Search/download YAML file (if necessary) and return path, checking suffix."""
    return check_file(file, suffix, hard=hard)


def yaml_load(file="data.yaml", append_filename=False):
    """Load YAML data from a file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        (dict): YAML data and file name.
    """
    assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
    with open(file, errors="ignore", encoding="utf-8") as f:
        s = f.read()  # string

        # Remove special characters
        if not s.isprintable():
            s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)

        # Add YAML filename to dict and return
        data = yaml.safe_load(s) or {}  # always return a dict (yaml.safe_load() may return None for empty files)
        if append_filename:
            data["yaml_file"] = str(file)
        return data


def guess_model_scale(model_path):
    """Extract the scale letter (n, s, m, l or x) from a YOLO model YAML file name.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The scale character, or "" if the file name does not match ``yolov<digits><scale>``.
    """
    with contextlib.suppress(AttributeError):
        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
    return ""


def yaml_model_load(path):
    """Load a YOLOv8 model dictionary from a YAML file, adding 'scale' and 'yaml_file' keys."""
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        print(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary; reads 'nc', 'activation', 'scales', 'scale',
            'depth_multiple', 'width_multiple', 'kpt_shape', 'backbone' and 'head'.
            The 'args' lists inside 'backbone'/'head' are modified in place, so pass a copy.
        ch (int): Number of input channels.
        verbose (bool): Print a per-layer summary table.

    Returns:
        (tuple[nn.Sequential, list[int]]): The layers, and the sorted indices of layers
            whose outputs are consumed by later, non-adjacent layers.
    """
    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            print(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        # NOTE(review): eval() executes the YAML 'activation' string; only load trusted configs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            print(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        print(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                # Resolve names such as 'nc' from this function's locals, else literals like 'None'.
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
        ):
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (Detect, Segment, Pose, OBB):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        # Store the parameter count on the instance; assigning to `m` would mutate the class itself.
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            print(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


if __name__ == "__main__":
    cfg = "yolov8x-p2.yaml"
    # Use a name that does not shadow the `yaml` module used by yaml_load().
    cfg_dict = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
    model, save = parse_model(deepcopy(cfg_dict), ch=3, verbose=True)