YOLOv5 + SE注意力机制:提升目标检测性能的实践

一、引言

目标检测是计算机视觉领域的一个重要任务,广泛应用于自动驾驶、安防监控、工业检测等领域。YOLOv5作为YOLO系列的最新版本,以其高效性和准确性在实际应用中表现出色。然而,随着应用场景的复杂化,传统的卷积神经网络在处理复杂背景和多尺度目标时可能会遇到性能瓶颈。为此,引入注意力机制成为了一种有效的改进方法。本文将详细介绍如何在YOLOv5中引入SE(Squeeze-and-Excitation)注意力机制,通过修改模型配置文件和代码实现,提升模型性能,并对比训练效果。

YOLOv5是YOLO系列的最新版本,相较于之前的版本,YOLOv5在模型结构、训练策略和数据增强等方面进行了多项改进,显著提升了模型的性能和效率。其主要特点包括:

  • 模型结构优化:YOLOv5采用新的骨干网络(Backbone)和路径聚合网络(Neck),提高了特征提取和融合的能力。
  • 数据增强策略:引入了多种数据增强方法,如Mosaic、MixUp等,提升了模型的泛化能力。
  • 训练策略改进:采用动态标签分配策略(SimOTA),提高了训练效率和检测精度。

然而,随着任务复杂度的增加,传统的卷积神经网络在处理多尺度目标时的表现不够理想,SE注意力机制的引入为提升目标检测精度提供了新的思路。

二、YOLOv5与SE注意力机制

2.1 YOLOv5简介

YOLOv5以其高效性和准确性在目标检测中得到了广泛应用。其主要结构特点是:

  • Backbone:负责从输入图像中提取特征。
  • Neck:通过特征融合提高模型的多尺度感知能力。
  • Head:根据提取的特征进行预测。

2.2 SE注意力机制简介

SE(Squeeze-and-Excitation)注意力机制是一种轻量级的注意力模块,旨在通过显式地建模通道间的依赖关系,提升模型的表示能力。SE模块由两个关键部分组成:

  • Squeeze(压缩):通过全局平均池化操作,将特征图的空间维度压缩为1,生成通道描述符。
  • Excitation(激励):通过两个全连接层和一个Sigmoid激活函数生成通道权重,用于重新校准特征图的通道响应。

通过引入SE模块,YOLOv5能够更加关注重要的特征通道,抑制不重要的特征通道,从而提升模型性能。

三、YOLOv5 + SE注意力机制的实现

3.1 模型配置文件修改

首先,想要将SE注意力机制引入到Yolov5中去,需要修改以下几个文件:commom.py、yolo.py和yolov5s.yaml文件。需要修改YOLOv5的模型配置文件(yolov5_se.yaml),在Backbone和Neck中引入SE模块。注意将SE模块引入之后,需要更改层数的号码,SE注意力机制也可以加入到其他层中,比如head层的P3输出之前等等。以下是修改后的配置文件内容:

# YOLOv5 馃殌 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SENet,[1024]], #SEAttention #9[-1, 1, SPPF, [1024, 5]],  # 10]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3#[-1, 1, SENet,[1024]], #SEAttention #9[-1, 3, C3, [256, False]],  # 18 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4#[-1, 1, SENet,[1024]], #SEAttention #9[-1, 3, C3, [512, False]],  # 21 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5#[-1, 1, SENet,[1024]], #SEAttention #9[-1, 3, C3, [1024, False]],  # 24 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

3.2 SE注意力模块的代码实现

在YOLOv5的代码中,需要实现SE模块。以下是一个SEBlock的实现:

import torch
import torch.nn as nnclass SENet(nn.Module):#c1, c2, n=1, shortcut=True, g=1, e=0.5def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5 ):super(SENet, self).__init__()#c*1*1self.avgpool = nn.AdaptiveAvgPool2d(1)self.l1 = nn.Linear(c1, c1 // 16, bias=False)self.relu = nn.ReLU(inplace=True)self.l2 = nn.Linear(c1 // 16, c1, bias=False)self.sig = nn.Sigmoid()def forward(self, x):b, c, _, _ = x.size()y = self.avgpool(x).view(b, c)y = self.l1(y)y = self.relu(y)y = self.l2(y)y = self.sig(y)y = y.view(b, c, 1, 1)return x * y.expand_as(x)

3.3 使用SE注意力模块

为了在YOLOv5的Backbone和Neck中引入SE模块,可以对Yolo.py文件原有的parse_model进行修改,以下是修改后的Bottleneck模块:

def parse_model(d, ch):  # model_dict, input_channels(3)# Parse a YOLOv5 model.yaml dictionaryLOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')if act:Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()LOGGER.info(f"{colorstr('activation:')} {act}")  # printna = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchorsno = na * (nc + 5)  # number of outputs = anchors * (classes + 5)layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch outfor i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, argsm = eval(m) if isinstance(m, str) else m  # eval stringsfor j, a in enumerate(args):with contextlib.suppress(NameError):args[j] = eval(a) if isinstance(a, str) else a  # eval stringsn = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gainif m in {Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x,SENet,}:c1, c2 = ch[f], args[0]if c2 != no:  # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x, CBAMBottleneck, CABottleneck, CBAMC3, SENet, CANet, CAC3, CBAM, ECANet, GAMNet}:args.insert(2, n)  # number of repeatsn = 1elif m is nn.BatchNorm2d:args = [ch[f]]elif m is Concat:c2 = sum(ch[x] for x in f)# TODO: channel, gw, gdelif m in {Detect, Segment}:args.append([ch[x] for x in f])if isinstance(args[1], int):  # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)if m is Segment:args[3] = make_divisible(args[3] * gw, 8)elif m is Contract:c2 = ch[f] * args[0] ** 2elif m is Expand:c2 = ch[f] // args[0] ** 2else:c2 = ch[f]m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # modulet = str(m)[8:-2].replace('__main__.', '')  # module typenp = sum(x.numel() for x in m_.parameters())  # number paramsm_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number paramsLOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # printsave.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelistlayers.append(m_)if i == 0:ch = []ch.append(c2)return nn.Sequential(*layers), sorted(save)

3.4 模型训练与效果对比

完成模型配置文件和代码的修改后,可以开始训练模型。推荐使用

COCO数据集或自定义数据集进行训练和验证。或者其他的自定义数据集也可以,在这里我使用自定义数据集camel_elephant_training进行100个epoch训练,该数据集仅仅有骆驼和大象两个种类。

训练完成后,可以通过AP(平均精度)指标来评估引入SE注意力机制前后的模型性能。一般情况下,引入SE模块后,YOLOv5在复杂背景和多尺度目标的检测中表现更为出色。

训练之后的结果如下:

由于时间有限我仅仅训练了100个epoch,正常情况下应设置150~200epoch,从train/obj_loss来看,仍然有下降的空间。

3.5 训练步骤

  1. 配置训练环境,确保已安装YOLOv5和相关依赖。
  2. 下载COCO数据集或使用自定义数据集进行训练。
  3. 修改训练脚本,加载修改后的模型配置文件yolov5_se.yaml
  4. 开始训练并监控训练过程中的损失和精度。
  5. 完成训练后,使用验证集评估效果。

3.6 模型部署

将训练好的数据权重通过export.py文件转换成.onnx格式,可以部署到任意平台上。

import argparse
import contextlib
import json
import os
import platform
import re
import subprocess
import sys
import time
import warnings
from pathlib import Pathimport pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobileFILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relativefrom models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_modeMACOS = platform.system() == 'Darwin'  # macOS environmentdef export_formats():# YOLOv5 export formatsx = [['PyTorch', '-', '.pt', True, True],['TorchScript', 'torchscript', '.torchscript', True, True],['ONNX', 'onnx', '.onnx', True, True],['OpenVINO', 'openvino', '_openvino_model', True, False],['TensorRT', 'engine', '.engine', False, True],['CoreML', 'coreml', '.mlmodel', True, False],['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],['TensorFlow GraphDef', 'pb', '.pb', True, True],['TensorFlow Lite', 'tflite', '.tflite', True, False],['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],['TensorFlow.js', 'tfjs', '_web_model', False, False],['PaddlePaddle', 'paddle', '_paddle_model', True, True],]return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])def try_export(inner_func):# YOLOv5 export decorator, i..e @try_exportinner_args = get_default_args(inner_func)def outer_func(*args, **kwargs):prefix = inner_args['prefix']try:with Profile() as dt:f, model = inner_func(*args, **kwargs)LOGGER.info(f'{prefix} export success 鉁?{dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')return f, modelexcept Exception as e:LOGGER.info(f'{prefix} export failure 鉂?{dt.t:.1f}s: {e}')return None, Nonereturn outer_func@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):# YOLOv5 TorchScript model exportLOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')f = file.with_suffix('.torchscript')ts = torch.jit.trace(model, im, strict=False)d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.htmloptimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)else:ts.save(str(f), _extra_files=extra_files)return f, None@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):# YOLOv5 ONNX exportcheck_requirements('onnx')import onnxLOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')f = file.with_suffix('.onnx')output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']if dynamic:dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)if isinstance(model, SegmentationModel):dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)elif isinstance(model, DetectionModel):dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)torch.onnx.export(model.cpu() if dynamic else model,  # --dynamic only compatible with cpuim.cpu() if dynamic else im,f,verbose=False,opset_version=opset,do_constant_folding=True,input_names=['images'],output_names=output_names,dynamic_axes=dynamic or None)# Checksmodel_onnx = onnx.load(f)  # load onnx modelonnx.checker.check_model(model_onnx)  # check onnx model# Metadatad = {'stride': int(max(model.stride)), 'names': model.names}for k, v in d.items():meta = model_onnx.metadata_props.add()meta.key, meta.value = k, str(v)onnx.save(model_onnx, f)# Simplifyif simplify:try:cuda = torch.cuda.is_available()check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))import onnxsimLOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')model_onnx, check = onnxsim.simplify(model_onnx)assert check, 'assert check failed'onnx.save(model_onnx, f)except Exception as e:LOGGER.info(f'{prefix} simplifier failure: {e}')return f, model_onnx@try_export
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):# YOLOv5 OpenVINO exportcheck_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/import openvino.inference_engine as ieLOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')f = str(file).replace('.pt', f'_openvino_model{os.sep}')cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"subprocess.run(cmd.split(), check=True, env=os.environ)  # exportyaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yamlreturn f, None@try_export
def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):# YOLOv5 Paddle exportcheck_requirements(('paddlepaddle', 'x2paddle'))import x2paddlefrom x2paddle.convert import pytorch2paddleLOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')f = str(file).replace('.pt', f'_paddle_model{os.sep}')pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # exportyaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yamlreturn f, None@try_export
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):# YOLOv5 CoreML exportcheck_requirements('coremltools')import coremltools as ctLOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')f = file.with_suffix('.mlmodel')ts = torch.jit.trace(model, im, strict=False)  # TorchScript modelct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)if bits < 32:if MACOS:  # quantization only supported on macOSwith warnings.catch_warnings():warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warningct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)else:print(f'{prefix} quantization only supported on macOS, skipping...')ct_model.save(f)return f, ct_model@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrtassert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'try:import tensorrt as trtexcept Exception:if platform.system() == 'Linux':check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')import tensorrt as trtif trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012grid = model.model[-1].anchor_gridmodel.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12model.model[-1].anchor_grid = gridelse:  # TensorRT >= 8check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12onnx = file.with_suffix('.onnx')LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')assert onnx.exists(), f'failed to export ONNX file: {onnx}'f = file.with_suffix('.engine')  # TensorRT engine filelogger = trt.Logger(trt.Logger.INFO)if verbose:logger.min_severity = trt.Logger.Severity.VERBOSEbuilder = trt.Builder(logger)config = builder.create_builder_config()config.max_workspace_size = workspace * 1 << 30# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation noticeflag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))network = builder.create_network(flag)parser = trt.OnnxParser(network, logger)if not parser.parse_from_file(str(onnx)):raise RuntimeError(f'failed to load ONNX file: {onnx}')inputs = [network.get_input(i) for i in range(network.num_inputs)]outputs = [network.get_output(i) for i in range(network.num_outputs)]for inp in inputs:LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')for out in outputs:LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')if dynamic:if im.shape[0] <= 1:LOGGER.warning(f"{prefix} WARNING 鈿狅笍 --dynamic model requires maximum --batch-size argument")profile = builder.create_optimization_profile()for inp in inputs:profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)config.add_optimization_profile(profile)LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')if builder.platform_has_fast_fp16 and half:config.set_flag(trt.BuilderFlag.FP16)with builder.build_engine(network, config) as engine, open(f, 'wb') as t:t.write(engine.serialize())return f, None@try_export
def export_saved_model(model,im,file,dynamic,tf_nms=False,agnostic_nms=False,topk_per_class=100,topk_all=100,iou_thres=0.45,conf_thres=0.25,keras=False,prefix=colorstr('TensorFlow SavedModel:')):# YOLOv5 TensorFlow SavedModel exporttry:import tensorflow as tfexcept Exception:check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")import tensorflow as tffrom tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2from models.tf import TFModelLOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')f = str(file).replace('.pt', '_saved_model')batch_size, ch, *imgsz = list(im.shape)  # BCHWtf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)keras_model.trainable = Falsekeras_model.summary()if keras:keras_model.save(f, save_format='tf')else:spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)m = tf.function(lambda x: keras_model(x))  # full modelm = m.get_concrete_function(spec)frozen_func = convert_variables_to_constants_v2(m)tfm = tf.Module()tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])tfm.__call__(im)tf.saved_model.save(tfm,f,options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())return f, keras_model@try_export
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlowimport tensorflow as tffrom tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')f = file.with_suffix('.pb')m = tf.function(lambda x: keras_model(x))  # full modelm = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))frozen_func = convert_variables_to_constants_v2(m)frozen_func.graph.as_graph_def()tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)return f, None@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):# YOLOv5 TensorFlow Lite exportimport tensorflow as tfLOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')batch_size, ch, *imgsz = list(im.shape)  # BCHWf = str(file).replace('.pt', '-fp16.tflite')converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]converter.target_spec.supported_types = [tf.float16]converter.optimizations = [tf.lite.Optimize.DEFAULT]if int8:from models.tf import representative_dataset_gendataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]converter.target_spec.supported_types = []converter.inference_input_type = tf.uint8  # or tf.int8converter.inference_output_type = tf.uint8  # or tf.int8converter.experimental_new_quantizer = Truef = str(file).replace('.pt', '-int8.tflite')if nms or agnostic_nms:converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)tflite_model = converter.convert()open(f, "wb").write(tflite_model)return f, None@try_export
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/cmd = 'edgetpu_compiler --version'help_url = 'https://coral.ai/docs/edgetpu/compiler/'assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on systemfor c in ('curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -','echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list','sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU modelf_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite modelcmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"subprocess.run(cmd.split(), check=True)return f, None@try_export
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):# YOLOv5 TensorFlow.js exportcheck_requirements('tensorflowjs')import tensorflowjs as tfjsLOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')f = str(file).replace('.pt', '_web_model')  # js dirf_pb = file.with_suffix('.pb')  # *.pb pathf_json = f'{f}/model.json'  # *.json pathcmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'subprocess.run(cmd.split())json = Path(f_json).read_text()with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending ordersubst = re.sub(r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, 'r'"Identity.?.?": {"name": "Identity.?.?"}, 'r'"Identity.?.?": {"name": "Identity.?.?"}, 'r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, 'r'"Identity_1": {"name": "Identity_1"}, 'r'"Identity_2": {"name": "Identity_2"}, 'r'"Identity_3": {"name": "Identity_3"}}}', json)j.write(subst)return f, Nonedef add_tflite_metadata(file, metadata, num_outputs):# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadatawith contextlib.suppress(ImportError):# check_requirements('tflite_support')from tflite_support import flatbuffersfrom tflite_support import metadata as _metadatafrom tflite_support import metadata_schema_py_generated as _metadata_fbtmp_file = Path('/tmp/meta.txt')with open(tmp_file, 'w') as meta_f:meta_f.write(str(metadata))model_meta = _metadata_fb.ModelMetadataT()label_file = _metadata_fb.AssociatedFileT()label_file.name = tmp_file.namemodel_meta.associatedFiles = [label_file]subgraph = _metadata_fb.SubGraphMetadataT()subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputsmodel_meta.subgraphMetadata = [subgraph]b = flatbuffers.Builder(0)b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)metadata_buf = b.Output()populator = _metadata.MetadataPopulator.with_model_file(file)populator.load_metadata_buffer(metadata_buf)populator.load_associated_files([str(tmp_file)])populator.populate()tmp_file.unlink()@smart_inference_mode()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'weights=ROOT / 'yolov5s.pt',  # weights pathimgsz=(640, 640),  # image (height, width)batch_size=1,  # batch sizedevice='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpuinclude=('torchscript', 'onnx'),  # include formatshalf=False,  # FP16 half-precision exportinplace=False,  # set YOLOv5 Detect() inplace=Truekeras=False,  # use Kerasoptimize=False,  # TorchScript: optimize for mobileint8=False,  # CoreML/TF INT8 quantizationdynamic=False,  # ONNX/TF/TensorRT: dynamic axessimplify=False,  # ONNX: simplify modelopset=12,  # ONNX: opset versionverbose=False,  # TensorRT: verbose logworkspace=4,  # TensorRT: workspace size (GB)nms=False,  # TF: add NMS to modelagnostic_nms=False,  # TF: add agnostic NMS to modeltopk_per_class=100,  # TF.js NMS: topk per class to keeptopk_all=100,  # TF.js NMS: topk for all classes to keepiou_thres=0.45,  # TF.js NMS: IoU thresholdconf_thres=0.25,  # TF.js NMS: confidence threshold
):t = time.time()include = [x.lower() for x in include]  # to lowercasefmts = tuple(export_formats()['Argument'][1:])  # --include argumentsflags = [x in include for x in fmts]assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleansfile = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights# Load PyTorch modeldevice = select_device(device)if half:assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model# Checksimgsz *= 2 if len(imgsz) == 1 else 1  # expandif optimize:assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'# Inputgs = int(max(model.stride))  # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection# Update modelmodel.eval()for k, m in model.named_modules():if isinstance(m, Detect):m.inplace = inplacem.dynamic = dynamicm.export = Truefor _ in range(2):y = model(im)  # dry runsif half and not coreml:im, model = im.half(), model.half()  # to FP16shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shapemetadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadataLOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")# Exportsf = [''] * len(fmts)  # exported filenameswarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarningif jit:  # TorchScriptf[0], _ = export_torchscript(model, im, file, optimize)if engine:  # TensorRT required before ONNXf[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)if onnx or xml:  # OpenVINO requires ONNXf[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)if xml:  # OpenVINOf[3], _ = export_openvino(file, metadata, half)if coreml:  # CoreMLf[4], _ = export_coreml(model, im, file, int8, half)if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formatsassert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'f[5], s_model = export_saved_model(model.cpu(),im,file,dynamic,tf_nms=nms or agnostic_nms or tfjs,agnostic_nms=agnostic_nms or tfjs,topk_per_class=topk_per_class,topk_all=topk_all,iou_thres=iou_thres,conf_thres=conf_thres,keras=keras)if pb or tfjs:  # pb prerequisite to tfjsf[6], _ = export_pb(s_model, file)if tflite or edgetpu:f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)if edgetpu:f[8], _ = export_edgetpu(file)add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))if tfjs:f[9], _ = export_tfjs(file)if paddle:  # PaddlePaddlef[10], _ = export_paddle(model, im, file, metadata)# Finishf = [str(x) for x in f if x]  # filter out '' and Noneif any(f):cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel))  # typedir = Path('segment' if seg else 'classify' if cls else '')h = '--half' if half else ''  # --half FP16 inference args = "# WARNING 鈿狅笍 ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \"# WARNING 鈿狅笍 SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'f"\nResults saved to {colorstr('bold', file.parent.resolve())}"f"\nDetect:          python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"f"\nValidate:        python {dir / 'val.py'} --weights {f[-1]} {h}"f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')  {s}"f"\nVisualize:       https://netron.app")return f  # return list of exported files/dirsdef parse_opt():parser = argparse.ArgumentParser()parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')parser.add_argument('--batch-size', type=int, default=1, help='batch size')parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')parser.add_argument('--half', action='store_true', help='FP16 half-precision export')parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')parser.add_argument('--keras', action='store_true', help='TF: use Keras')parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')parser.add_argument('--include',nargs='+',default=['torchscript'],help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')opt = parser.parse_args()print_args(vars(opt))return optdef main(opt):for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):run(**vars(opt))if __name__ == "__main__":opt = parse_opt()main(opt)

四、总结

本文介绍了如何在YOLOv5中引入SE注意力机制,包括模型配置文件的修改、代码实现、训练步骤以及效果对比。通过引入SE模块,YOLOv5在多尺度目标和复杂背景下的检测精度有所提升。未来,可以继续探索其他注意力机制(如CBAM、ECA等)的应用,以进一步提升YOLOv5的性能。感谢大家的支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/896549.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

跟我学C++中级篇——定时器的设计

一、定时器 谈到定时器&#xff0c;理论上讲是各种语言和各种设计都无法避开的一个技术点。对于定时器来说&#xff0c;表面上就是一种时间间隔的处理约定&#xff0c;但对程序来说&#xff0c;可能就是设计层面、接口层面和库或框架以及系统应用的一个大集合。不同的系统&…

智能机器人加速进化:AI大模型与传感器的双重buff加成

Deepseek不仅可以在手机里为你解答现在的困惑、占卜未来的可能&#xff0c;也将成为你的贴心生活帮手&#xff01; 2月21日&#xff0c;追觅科技旗下Dreamehome APP正式接入DeepSeek-R1大模型&#xff0c;2月24日发布的追觅S50系列扫地机器人也成为市面上首批搭载DeepSeek-R1的…

PostgreSQL10 逻辑复制实战:构建高可用数据同步架构!

PostgreSQL10 逻辑复制实战&#xff1a;打造高可用数据同步架构&#xff01; 概述 PostgreSQL 10 引入了逻辑复制&#xff08;Logical Replication&#xff09;&#xff0c;为数据库高可用和数据同步提供了更灵活的选择。PostgreSQL 复制机制主要分为物理复制和逻辑复制两种&…

LVS+Keepalived高可用群集配置案例

以下是一个 LVSKeepalived 高可用群集配置案例&#xff1a; 1、环境准备 LVS 主调度器&#xff08;lvs1&#xff09;&#xff1a;IP 地址为 192.168.8.101&#xff0c;心跳 IP 为 192.168.4.101LVS 备调度器&#xff08;lvs2&#xff09;&#xff1a;IP 地址为 192.168.8.102…

原生家庭独立的艺术:找到自我与家庭的平衡点

原生家庭独立的艺术&#xff1a;找到自我与家庭的平衡点 &#x1f331; 引言 &#x1f308; 小林刚刚和父母结束了一次激烈的电话对峙。父母坚持认为他应该回到家乡工作&#xff0c;“这样我们也能照顾你”&#xff0c;而他则努力解释自己在大城市的职业规划。挂掉电话后&…

Java进阶——注解一文全懂

Java注解&#xff08;Annotation&#xff09;是一种强大的元数据机制&#xff0c;为代码提供了附加信息&#xff0c;能简化配置、增强代码的可读性和可维护性。本文将深入探讨 Java 注解的相关知识。首先阐述了注解的基础概念&#xff0c;包括其本质、作用以及核心分类&#xf…

DeepSeek 15天指导手册——从入门到精通 PDF(附下载)

DeepSeek使用教程系列--DeepSeek 15天指导手册——从入门到精通pdf下载&#xff1a; https://pan.baidu.com/s/1PrIo0Xo0h5s6Plcc_smS8w?pwd1234 提取码: 1234 或 https://pan.quark.cn/s/2e8de75027d3 《DeepSeek 15天指导手册——从入门到精通》以系统化学习路径为核心&…

【智能音频新风尚】智能音频眼镜+FPC,打造极致听觉享受!【新立电子】

智能音频眼镜&#xff0c;作为一款将时尚元素与前沿科技精妙融合的智能设备&#xff0c;这种将音频技术与眼镜形态完美结合的可穿戴设备&#xff0c;不仅解放了用户的双手&#xff0c;更为人们提供了一种全新的音频交互体验。新立电子FPC在智能音频眼镜中的应用&#xff0c;为音…

常用的 pip 命令

pip 是 Python 的包管理工具&#xff0c;可用于安装、卸载、更新和管理 Python 包。以下是一些常用的 pip 命令&#xff1a; 1. 安装包 安装最新版本的包 pip install package_namepackage_name 是你要安装的 Python 包的名称&#xff0c;例如 pip install requests 可以安装…

学习threejs,使用ShaderMaterial自定义着色器材质

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.ShaderMaterial1.1.1…

从暴力破解到时空最优:LeetCode算法设计核心思维解密

一、算法优化金字塔模型&#xff08;时间复杂度/空间复杂度协同优化&#xff09; 1.1 复杂度分析的本质 大O记号的三层认知&#xff1a; ① 理论复杂度边界&#xff08;理想模型&#xff09; ② 硬件架构影响&#xff08;缓存命中率/分支预测&#xff09; ③ 语言特性损耗&am…

Typora的Github主题美化

[!note] Typora的Github主题进行一些自己喜欢的修改&#xff0c;主要包括&#xff1a;字体、代码块、表格样式 美化前&#xff1a; 美化后&#xff1a; 一、字体更换 之前便看上了「中文网字计划」的「朱雀仿宋」字体&#xff0c;于是一直想更换字体&#xff0c;奈何自己拖延症…

用大白话解释搜索引擎Elasticsearch是什么,有什么用,怎么用

Elasticsearch是什么&#xff1f; Elasticsearch&#xff08;简称ES&#xff09;就像一个“超级智能的图书馆管理系统”&#xff0c;专门帮你从海量数据中快速找到想要的信息。它底层基于倒排索引技术&#xff08;类似书籍的目录页&#xff09;&#xff0c;能秒级搜索和分析万…

神经网络 - 激活函数(Sigmoid 型函数)

激活函数在神经元中非常重要的。为了增强网络的表示能力和学习能力&#xff0c;激活函数需要具备以下几点性质: (1) 连续并可导(允许少数点上不可导)的非线性函数。可导的激活函数可以直接利用数值优化的方法来学习网络参数. (2) 激活函数及其导函数要尽可能的简单&#xff0…

Spring 源码硬核解析系列专题(六):Spring MVC 的请求处理源码解析

在前几期中,我们探讨了 Spring 的 IoC 容器、Bean 创建、AOP、事务管理以及 Spring Boot 的自动装配,这些为 Spring MVC 的运行奠定了基础。作为 Spring 生态中处理 Web 请求的核心模块,Spring MVC 通过 DispatcherServlet 实现了灵活的请求分发与处理。本篇将深入 Dispatch…

Docker容器日常维护常用命令大全

友情提示&#xff1a;本文内容由银河易创&#xff08;https://ai.eaigx.com&#xff09;AI创作平台deepseek-v3模型生成&#xff0c;文中所有命令未进行验证&#xff0c;仅供参考。请根据具体情况和需求进行适当的调整和验证。 引言 Docker作为当前最流行的容器化技术&#xf…

Pytest测试用例执行跳过的3种方式

文章目录 1.前言2.使用 pytest.mark.skip 标记无条件跳过3.使用 pytest.mark.skipif 标记根据条件跳过4. 执行pytest.skip()方法跳过测试用例 1.前言 在实际场景中&#xff0c;我们可能某条测试用例没写完&#xff0c;代码执行时会报错&#xff0c;或者是在一些条件下不让某些…

GitHub 语析 - 基于大模型的知识库与知识图谱问答平台

语析 - 基于大模型的知识库与知识图谱问答平台 GitHub 地址&#xff1a;https://github.com/xerrors/Yuxi-Know &#x1f4dd; 项目概述 语析是一个强大的问答平台&#xff0c;结合了大模型 RAG 知识库与知识图谱技术&#xff0c;基于 Llamaindex VueJS FastAPI Neo4j 构…

vue学习七

十四 pinia 官网&#xff1a;安装 | Pinia 中文文档 集中式状态管理&#xff0c;与vuex相似&#xff0c;提供变量存储便于数据共享。 从概念上类似于php中的session吧…… 适用于少量数据的共享&#xff0c;可操作数据都是先定义后使用。 适用于判断用户是否登录&#xff…

【Prometheus】prometheus服务发现与relabel原理解析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全…