YOLOv8-对注意力机制模型进行通道剪枝-同时实现涨点和轻量化【附代码】

文章目录

  • 前言
  • 视频效果
  • 文章概述
  • 必要环境
  • 一、训练自己的模型
    • 1、 训练命令
    • 2、 训练参数解析
  • 二、模型剪枝
    • 1、 对训练好的模型将进行剪枝
    • 2、 剪枝代码详解
      • 1.解析命令行参数
      • 2. 定义剪枝函数
      • 3. 定义剪枝结构
      • 4. 更新注意力机制
      • 5. 保存更新后的模型
      • 6. 主函数
  • 三、剪枝后的训练
    • 运行命令如下
  • 四、实验指标对比
  • 五、剪枝前后效果对比
  • 六、完整代码获取
  • 总结


前言

在上期博客中,我们实现了对YOLOv10模型的结构化通道剪枝,本篇文章将介绍如何对增加了MCA注意力机制的YOLOv8模型进行通道剪枝,并详细解读每个参数和模块的作用。
上期博客地址:YOLOv10结构化通道剪枝【附代码】


视频效果

b站链接:魔改YOLOv8 在参数量下降51.3%的情况下涨点1% (KITTI验证集)


文章概述

本篇博客将详细介绍如何对yolov8注意力机制模型进行通道剪枝,具体步骤包括参数解析、剪枝代码讲解、fine-tune训练,最后将对比剪枝前后模型在KITTI数据集上的表现,包括MAP、参数量和FPS等指标,以验证剪枝效果


必要环境

  1. 配置yolov10环境 可参考往期博客(v8和v10环境配置方法可通用)
    地址:搭建YOLOv10环境 训练+推理+模型评估

  2. 安装torch-pruning 0.2.7版本,安装命令如下

    pip install torch-pruning==0.2.7
    
  3. 结构化剪枝论文地址
    地址:Pruning Filters for Efficient ConvNets

  4. MCA注意力机制论文地址
    地址:Multidimensional collaborative attention in deep convolutional neural networks for image recognition


一、训练自己的模型

1、 训练命令

python 1_yolov8_train.py --mode train_ch --yaml_path yolov8n.yaml --epoch 200 --batch 32 --model_path ''

运行效果
在这里插入图片描述
可以看到正常训练时会打印模型在yaml文件中定义的网络结构

2、 训练参数解析

# 解析命令行参数
parser = argparse.ArgumentParser(description='Train or validate YOLO model.')
parser.add_argument('--mode', type=str, default='val', choices=['train_ori', 'train_ch', 'val'],help='Mode of operation.')
parser.add_argument('--yaml_path', type=str, default='yolov8n.yaml', help='Path to YAML file.')
parser.add_argument('--model_path', type=str, default=r'runs/kitti_ori/weights/best.pt', help='Path to model file.')
parser.add_argument('--data_path', type=str, default='./data.yaml', help='Path to data file.')
parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
parser.add_argument('--batch', type=int, default=16, help='Batch size.')
parser.add_argument('--workers', type=int, default=8, help='Number of workers.')
parser.add_argument('--device', type=str, default='0', help='Device to use.')
parser.add_argument('--name', type=str, default='', help='Name data file.')
args = parser.parse_args()

参数详解:

  1. –mode: 用于指定操作模式
    可选值为train_ori、train_ch和val。train_ori用于训练原始模型,train_ch用于训练改进后的模型(如增加注意力机制或增加检测头),val用于验证模型并计算精度指标

  2. –yaml_path: 指定改进网络结构的YAML文件路径
    当选择训练改进后的模型时,需要提供相应的网络结构文件路径,如训练带有MCA注意力机制的模型时,此处填写相应的YAML文件路径

  3. –model_path: 指定模型文件路径
    当mode不等于val时,该参数为预训练模型的路径(如训练8n模型时,此处填写yolov8n.pt路径)
    当mode等于val时,该参数为训练好的模型路径,用于计算指标,通常保存在runs目录下

  4. –data_path: 指定数据集文件路径
    该参数用于提供数据集的路径,对应一个YAML文件

  5. –epoch: 指定训练的轮数
    默认值为200,表示模型的训练轮次

  6. –batch: 指定批次大小
    默认值为16,表示每次训练迭代中所处理的样本数量

  7. –workers: 指定工作线程数
    默认值为8,表示用于数据加载的工作线程数量,windows系统这里改为0

  8. –device: 指定使用的设备
    默认值为0,表示使用的GPU设备编号

  9. –name: 指定保存模型文件夹的名称

二、模型剪枝

1、 对训练好的模型将进行剪枝

运行命令如下

python 3_yolov8_pruning.py --model_path weights/kitti_baseline/weights/best.pt --prune_type l1 --prune_ratio 0.4

运行效果
在这里插入图片描述

运行成功后会输出剪枝后的网络结构,以及剪枝前后模型的参数量对比

2、 剪枝代码详解

1.解析命令行参数

解析命令行参数的,其方便各位在命令行中指定模型路径、剪枝策略以及剪枝比例等参数

# 解析命令行参数
def parse_args():parser = argparse.ArgumentParser(description="Prune YOLOv8 model.")parser.add_argument("--model_path", type=str,default=r"weights/kitti_baseline/weights/best.pt",help="Path to the YOLOv8 model.")parser.add_argument("--prune_type", type=str, default="l2", choices=["l1", "l2", "random"],help="Pruning strategy to use.")parser.add_argument("--prune_ratio", type=float, default=0.4, help="Pruning ratio.")args = parser.parse_args()return args

参数详解:

  1. –model_path: 指定需要剪枝的模型路径
  2. –prune_type: 指定剪枝策略,可选方案为 l1, l2, random,默认使用 l1策略
  3. –prune_ratio: 指定剪枝比例,默认值为0.4,表示对定义的卷积层减掉40%的通道数

2. 定义剪枝函数

用于根据指定的修剪策略和比例对给定的模型进行修剪

def prune_model(model, prune_type, prune_ratio, input_tensor):strategy = {'l1': tp.strategy.L1Strategy(),'l2': tp.strategy.L2Strategy(),'random': tp.strategy.RandomStrategy()}.get(prune_type, tp.strategy.RandomStrategy())dependency_graph = tp.DependencyGraph().build_dependency(model, example_inputs=input_tensor)included_layers = get_included_layers(model)original_params = tp.utils.count_params(model)pruning_plans = [dependency_graph.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=prune_ratio))for m in model.modules() if isinstance(m, nn.Conv2d) and m in included_layers]

关键步骤详解:

  1. 策略选择
    根据 prune_type 参数, 选择对应的剪枝策略,如果 prune_type 不是预定义的值, 则默认使用随机剪枝策略

  2. 构建依赖图
    使用 tp.DependencyGraph().build_dependency 函数构建模型的依赖关系图, 以便后续进行剪枝操作

  3. 获取包含的层
    使用 get_included_layers 函数获取需要进行剪枝的层, 即模型中的 nn.Conv2d 层

  4. 计算原始参数数量
    使用 tp.utils.count_params 函数计算模型的原始参数数量

  5. 制定剪枝计划
    对于每个需要剪枝的 nn.Conv2d 层, 使用对应的剪枝策略计算剪枝的索引, 并生成剪枝计划

3. 定义剪枝结构

从指定模型中, 找出所有可以进行剪枝操作的层, 并将它们添加到 included_layers 列表中

def get_included_layers(model):included_layers = []for layer in model.model:if isinstance(layer, Conv):included_layers.append(layer.conv)...if isinstance(layer, Detect):...return included_layers

关键模块详解:

  1. model: 指定yolov8模型,函数将遍历这个模型的层来识别可剪枝的部分
  2. included_layers: 用于存储可以进行剪枝操作的层,函数会将这些层添加到这个列表中
  3. 定义模型中不同类型的层,函数会根据层的类型采取不同的处理方式,将可剪枝的部分添加到 included_layers 列表中

4. 更新注意力机制

由于torch_pruning中的某些bug,剪枝后会使注意力机制中某些模块的通道数变为负数,为了确保剪枝后的网络能够正确工作,我们需要更新这些层的通道数

def replace_conv_macayer(original_layer, new_in_channels, new_out_channels):# 获取原始层的参数original_weight = original_layer.conv.weight.dataoriginal_bias = original_layer.conv.bias.data if original_layer.conv.bias is not None else None# 创建一个新的卷积层new_conv_layer = nn.Conv2d(in_channels=new_in_channels, out_channels=new_out_channels,kernel_size=original_layer...)# 复制权重...return new_conv_layer

关键模块详解:

  1. original_layer: 原始的卷积层。这是一个包含卷积层的对象,通常是一个网络中的某个层。
  2. new_in_channels: 新的输入通道数。
  3. new_out_channels: 新的输出通道数。
  4. 返回值:最终会返回一个新的卷积层,该层具有更新后的输入和输出通道数,并且尽可能保留了原始层的权重和偏置。

5. 保存更新后的模型

**剪枝操作完成后,我们需要将剪枝后的模型保存,以便后续使用 **

# 保存更新后的模型
def save_pruned_model(model, ckpt, prune_type):param_dict = {'model': model,'ema': ckpt['ema'],...}torch.save(param_dict, f'prune_model_{prune_type}.pt')

参数详解:

  1. model:剪枝后的模型**
  2. ckpt:模型训练状态和相关参数的字典,需要将必要部分写入到剪枝模型中**
  3. prune_type:剪枝类型,用于命名保存的模型文件**

6. 主函数

定义主函数,整合上述各个步骤,实现完整的剪枝流程

def main():args = parse_args()# 加载模型yolov8 = YOLO(args.model_path)# 使模型参数可训练for para in model.parameters():para.requires_grad = Truepruned_model, original_params = prune_model(model, args.prune_type, args.prune_ratio, input_tensor)# 更新模型中的注意力层update_model_attention_layers(pruned_model)# 保存更新后的模型save_pruned_model(pruned_model, ckpt, args.prune_type)pruned_params = tp.utils.count_params(model)percentage_reduction = ((original_params - pruned_params) / original_params) * 100logger.info(f"Params: {original_params * 4 / 1024 / 1024:.2f} MB => {pruned_params * 4 / 1024 / 1024:.2f} MB (Reduction: {percentage_reduction:.2f}%)")

关键模块解读:
1. parse_args():解析命令行参数。
2. YOLO(args.model_path):加载YOLOv8模型
3. prune_model():执行剪枝操作
4. save_pruned_model():保存剪枝后的模型
5. 计算剪枝前后参数的变化,并打印模型信息和参数减少的百分比

三、剪枝后的训练

运行命令如下

python 4_yolov8-finetune.py --finetune --epochs 200 --batch_size 16

运行效果
在这里插入图片描述
可以看到剪枝后训练不会打印模型在yaml文件中定义的网络结构

四、实验指标对比

如下表:

模型MAPPRImgszParam
YOLOv8n原模型86.796.874.864011.47
YOLOv8n+MCA88.495.778.564011.47
YOLOv8n+MCA+剪枝87.797.276.56405.59
  1. 由此可见在KITTI验证集上,MCA注意力机制+剪枝可以做到在参数量下降51%的情况下涨点1%
  2. 验证集是通过留出法将训练集按9:1比例进行划分所得

五、剪枝前后效果对比

剪枝前:
在这里插入图片描述
剪枝后:
在这里插入图片描述

实验设备为RTX2060,如上图所示 剪枝后的模型FPS更高,推理速度更快

六、完整代码获取

链接:YOLOv8结合MCA注意力机制+通道剪枝-同时实现涨点和轻量化


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/39424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【日常记录】【JS】动态执行JS脚本

文章目录 1、第一种方式:eval2、第二种方式:setTimeout3、第三种方式:创建script 标签插入body4、第四种方式:创建 Function5、对比6、 参考链接 1、第一种方式:eval 语法 eval(string)参数 string:一个…

获取目标机器的ssh反弹权限后,如何通过一台公网服务器的服务 jar 包进行偷梁换柱植入目录进行钓鱼,从而获取目标使用人的终端设备权限和个人信息?

网络攻防实战中获取目标机器的ssh反弹权限后,如何通过一台公网服务器的服务 jar 包进行偷梁换柱植入目录进行钓鱼,从而获取目标使用人的终端设备权限和个人信息? 具体流程如下: 1)获取了目标用户经常访问的一台服务器信息,并能反弹shell回来进行远程ssh链接; 2)分析…

Linux指定文件权限的两种方式-符号与八进制数方式示例

一、指定文件权限可用的两种方式: 对于八进制数指定的方式,文件权限字符代表的有效位设为‘1’,即“rw-”、“rw-”、“r--”,以二进制表示为“110”、“110”、“100”,再转换为八进制6、6、4,所以777代表…

Android 11.0 SettingsProvider 源码分析

文章目录 一、SettingsProvider 的概述二、SettingsProvider 的启动流程三、对 SettingsProvider 进行操作方法四、客制化示例 一、SettingsProvider 的概述 SettingsProvider 是一个为 Android 系统设置提供数据共享的 Provider,它包含全局、安全和系统级别的用户…

配置WLAN 示例

规格 仅AR129CVW、AR129CGVW-L、AR109W、AR109GW-L、AR161W、AR161EW、AR161FGW-L、AR161FW、AR169FVW、AR169JFVW-4B4S、AR169JFVW-2S、AR169EGW-L、AR169EW、AR169FGW-L、AR169W-P-M9、AR1220EVW和AR301W支持WLAN-FAT AP功能。 组网需求 如图1所示,企业使用WLAN…

【拓展】理解AppID、OpenID、UnionID

目录 历史背景AppIDAppSecretOpenIDUnionID三者区别使用方法AppIDOpenID/**UnionID**拓展 历史背景 基本概念介绍 | 微信开放文档 微信小程序:一文彻底搞懂openid和unionid-腾讯云开发者社区-腾讯云 用户进行小程序登陆时,需要获取用户信息,…

通用的ERP系统功能清单有哪些?

一、通用的ERP系统功能清单 通用的ERP(Enterprise Resource Planning,企业资源计划)系统是一套集成的业务应用程序,旨在帮助企业有效管理财务、销售、运营等关键业务流程。以下是一个清晰的ERP系统功能清单,涵盖了其主…

【Flutter】列表流畅性优化

前言 在日常APP的开发中,列表是使用频率最高的,这里讲述在Flutter中优化列表的滑动速度与流畅度,以来提高用户的体验。 方案 1、使用ListView.builder代替ListView ListView.builder在创建列表的时候要比ListView更高效,因为L…

工程技术类SCI,低分快刊首选期刊,无版面费!

1、期刊概况 【期刊简介】IF:1.0-2.0,JCR2区,中科院4区; 【检索情况】SCIE在检 【版面类型】正刊,仅少量版面; 【出刊频率】年刊 2、征稿范围 本刊主要是发表有关能源转型和可再生能源需求相关的研究文…

Snappy使用

Snappy使用 Snappy是谷歌开源的压缩和解压的开发包,目标在于实现高速的压缩而不是最大的压缩 项目地址:GitHub - google/snappy:快速压缩器/解压缩器 Cmake版本升级 该项目需要比较新的cmake,CMake 3.16.3 or higher is requi…

一首歌的时间 写成永远

大家好,我是秋意零。 就在,2024年6月20日。我本科毕业了,之前专科毕业挺有感触,也写了一篇文章进行记录。如今又毕业了,还是写一篇文章记录吧!! 专科毕业总结:大学三年总结&#xf…

【SpringBoot3学习 | 第1篇】SpringBoot3介绍与配置文件

文章目录 前言 一. SpringBoot3介绍1.1 SpringBoot项目创建1. 创建Maven工程2. 添加依赖(springboot父工程依赖 , web启动器依赖)3. 编写启动引导类(springboot项目运行的入口)4. 编写处理器Controller5. 启动项目 1.2 项目理解1. 依赖不需要写版本原因2. 启动器(Starter)3. Sp…

二刷 动态规划

什么是动态规划 Dynamic Programming DP 如果某一问题有很多重叠子问题,使用动态规划时最有效的 动态规划中每一个状态是由上一个状态推导出来的。 动规五部曲 1.确定dp数组以及下标的含义 2.确定递归公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数…

【java计算机毕设】仓库管理系统 MySQL springboot vue3 Maven 项目源码代码

目录 1项目功能 2项目介绍 3项目地址 1项目功能 【java计算机毕设】仓库管理系统MySQL springboot vue3 Maven小组项目设计源代码 2项目介绍 系统功能: vue3仓库管理系统,主要功能包含:个人信息管理,仓库管理,员工…

java设计模式(七)适配器模式(Adapter Pattern)

1、模式介绍: 适配器模式(Adapter Pattern)是一种结构型设计模式,它允许将一个类的接口转换成客户希望的另外一个接口。适配器模式通常用于需要复用现有的类,但是接口与客户端的要求不完全匹配的情况。它包括两种形式&…

【深度学习】注意力机制

https://blog.csdn.net/weixin_43334693/article/details/130189238 https://blog.csdn.net/weixin_47936614/article/details/130466448 https://blog.csdn.net/qq_51320133/article/details/138305880 注意力机制:在处理信息的时候,会将注意力放在需要…

gitee项目上不同的项目分别使用不用的用户上传

最近使用根据需要,希望不同的项目使用不同的用户上传,让不同的仓库展示不同的用户名!!! 第一步查看全局的用户信息: # 查看目前全局git配置信息 git config -l #会输出全局的git配置信息 第二步进入到要设…

大科技公司大量裁员背后的真相

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

冒泡排序、选择排序、菱形

冒泡排序、选择排序、菱形 文章目录 一、冒泡排序二、选择排序三、菱形 一、冒泡排序 思路: 外层(第一层)循环控制循环次数,和业务无关 内层(第二层)循环用于比较相邻的2个值的大小,根据小到大…

B站、小红书“崩”了!阿里云紧急回应

7月2日,“B站崩了”“小红书崩了”冲上微博热搜!据悉,“崩了”的原因是阿里云上海服务出现异常。 B站App无法使用浏览历史关注等内容,消息界面、更新界面、客服界面均不可用,用户也无法评论和发弹幕,视频评…