yolov5及yolov7实战之剪枝

之前有讲过一次yolov5的剪枝:yolov5实战之模型剪枝_yolov5模型剪枝-CSDN博客
当时基于的是比较老的yolov5版本,剪枝对整个训练代码的改动也比较多。最近发现一个比较好用的剪枝库,可以在不怎么改动原有训练代码的情况下,实现剪枝的操作,这篇文章就简单介绍一下,剪枝的概念以及为什么要剪枝可以参看上一篇,这里就不赘述了。

Torch-Pruning

VainF/Torch-Pruning: [CVPR 2023] Towards Any Structural Pruning; LLMs / Diffusion / Transformers / YOLOv8 / CNNs (github.com)
今天我们要用到的就是这个剪枝库,这个库集成了很多剪枝的方法,毕竟使用比较简单。

用法

这个剪枝库既有low level的剪枝,也就是手动控制剪枝哪些层,也有high level的剪枝,就是使用预设的剪枝算法,自动选择剪枝的部分。对于我们来说,更适合使用high level剪枝。具体的这里使用和上一篇yolov5里面的剪枝一样的算法,在这个库里叫BNScalePruner。

安装

首先我们需要安装上面提到的库,有两种方式来安装:

pip install torch-pruning

或源码安装(当碰到bug发布版本没修复,源码修复的时候):

pip install git+https://github.com/VainF/Torch-Pruning.git

稀疏化训练

为了更好的剪枝,我们在训练剪枝前的网络时,推荐开启稀疏化训练,利用这个库,我们可以很方便的实现这个操作。
首先在我们的训练代码中定义好剪枝器, 这里的opt.prune是我自己加的来控制是否开启稀疏化训练的标志:

# prune
if opt.prune:examle_input = torch.randn(1, 3, imgsz, imgsz).to(device)imp = tp.importance.BNScaleImportance()pruner = tp.pruner.BNScalePruner(model, examle_input, imp,reg=0.0001)

稀疏化训练主要需要设置reg参数,一般设置0.001~1e-6之间。
定义好剪枝器后,在训练代码的scaler.scale(loss).backward()之后,添加如下代码:

if opt.prune:pruner.regularize(model)

即可实现稀疏化训练。

剪枝

稀疏化训练后(也可以不做稀疏化训练),我们就可以进行剪枝操作了。这个库可以在训练中交互式进行多次剪枝,简单起见,我们这里分离剪枝和训练的代码,只进行剪枝操作。

import torch_pruning as tp
from models.experimental import attempt_load
import torchweights = "yolov7.pt"
model = attempt_load(weights, map_location=torch.device('cuda:0'), fuse=False)
for p in model.parameters():p.requires_grad = True
ignored_layers = []
from models.yolo import Detect, IDetect
from models.common import ImplicitA, ImplicitM
for m in model.modules():if isinstance(m, (Detect,IDetect)):ignored_layers.append(m.m)
unwrapped_parameters = []
for name, m in model.named_parameters():if isinstance(m, (ImplicitA,ImplicitM,)):unwrapped_parameters.append((name,1)) # pruning 1st dimension of implicit matrixprint(ignored_layers)
example_inputs = torch.rand(1, 3, 416, 416, device='cuda:0')
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(model, example_inputs, imp,ignored_layers=ignored_layers,unwrapped_parameters=unwrapped_parameters,global_pruning=True,ch_sparsity=0.3,round_to=8,)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_model = pruner.model
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"macs: {base_macs} -> {pruned_macs}")
print(f"nparams: {base_nparams} -> {pruned_nparams}")
macs_cutoff_ratio = (base_macs - pruned_macs) / base_macs
nparams_cutoff_ratio = (base_nparams - pruned_nparams) / base_nparams
print(f"macs cutoff ratio: {macs_cutoff_ratio}")
print(f"nparams cutoff ratio: {nparams_cutoff_ratio}")
save_path = weights.replace(".pt", "_pruned_bn_0.3.pt")torch.save({"model": pruned_model.module if hasattr(pruned_model, 'module') else pruned_model}, save_path)

去掉一些计算剪枝比例的,保存代码等代码外,剪枝操作其实由pruner.step()这一步完成。这里我们主要需要设置的参数是:

  • ch_sparsity: 可以理解成剪枝的比例,越大剪得越多
  • global_pruning: True表示整个模型的权重按一个整体排序后剪枝,False表示按分组内部按比例剪枝
  • round_to: 剪枝后的通道保留为多少的倍数,一般在硬件上,保留8的倍数

微调

经过剪枝的网络,精度是下降比较明显的,需要再在数据上finetune一些epoch才能把精度拉回来。
yolov7默认是通过yaml文件创建模型结构,然后再载入权重进行训练的,而我们剪枝后的模型是没有模型结构文件的,因此需要对训练代码做一定的修改,具体而言,只是对模型的载入进行一点修改。其中opt.finetune是用来控制是否处于finetune模式的标志位。

if opt.finetune: # for model without cfgnew = torch.load(weights, map_location=device)  # createmodel = new["model"]print("Finetune Mode...")
elif pretrained:
...

比较简单的改法是这样,从checkpoint中载入结构和权重,还有一种方式则是修改yolov7的Model类,这个在后面讲yolov7剪枝后蒸馏的时候再讲,暂时用上面这种方式就可以了。

评测

我在自己的任务上的效果是yolov7剪枝50%,微调后基本上能达到剪枝前的map,没记错的话这是和稀疏化训练的比,毕竟开启稀疏化训练本身也会掉点。大家可以在自己的任务上尝试一下,总体上精度还是可以的

结语

这篇文章简述了以下yolov7的剪枝,yolov5也可用,希望对大家有帮助。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/97343.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

李沐深度学习记录5:13.Dropout

Dropout从零开始实现 import torch from torch import nn from d2l import torch as d2l# 定义Dropout函数 def dropout_layer(X, dropout):assert 0 < dropout < 1# 在本情况中&#xff0c;所有元素都被丢弃if dropout 1:return torch.zeros_like(X)# 在本情况中&…

超自动化加速落地,助力运营效率和用户体验显著提升|爱分析报告

RPA、iPaaS、AI、低代码、BPM、流程挖掘等在帮助企业实现自动化的同时&#xff0c;也在构建一座座“自动化烟囱”。自动化工具尚未融为一体&#xff0c;协同价值没有得到释放。Gartner于2019年提出超自动化&#xff08;Hyperautomation&#xff09;概念&#xff0c;主要从技术组…

【动手学深度学习】课程笔记 04 数据操作和数据预处理

目录 数据操作 N维数组样例 访问元素 数据操作实现 入门 运算符 广播机制 节省内存 转换为其他Python对象 数据预处理实现 数据操作 N维数组是机器学习和神经网路的主要数据结构。 N维数组样例 访问元素 数据操作实现 下面介绍一下本课程中需要用到的PyTorch相关操…

BJT晶体管

BJT晶体管也叫双极结型三极管&#xff0c;主要有PNP、NPN型两种&#xff0c;符号如下&#xff1a; 中间的是基极&#xff08;最薄&#xff0c;用于控制&#xff09;&#xff0c;带箭头的是发射极&#xff08;自由电子浓度高&#xff09;&#xff0c;剩下的就是集电极&#xff0…

no Go files in ...问题

golang项目&#xff0c;当我们微服务分模块开发时&#xff0c;习惯把main.go放在cmd目录下分模块放置&#xff0c;此时&#xff0c;我们在项目根目录下执行go test . 或go build . 时会报错“no Go files in ...”, 这是因为在.目录下找不到go程序&#xff0c;或者找不到程序入…

python之subprocess模块详解

介绍 subprocess是Python 2.4中新增的一个模块&#xff0c;它允许你生成新的进程&#xff0c;连接到它们的 input/output/error 管道&#xff0c;并获取它们的返回&#xff08;状态&#xff09;码。 这个模块的目的在于替换几个旧的模块和方法。 那么我们到底该用哪个模块、哪个…

SQL_ERROR_INFO: “Duplicate entry ‘9003‘ for key ‘examination_info.exam_id‘“

今天刷题的时候&#xff0c;往数据库中插入一条语句&#xff0c;但是这个语句已经存在于数据库中了&#xff0c;所以不能用insert into 语句来插入&#xff0c;应该使用replace into 来插入。 REPLACE INTO examination_info(exam_id,tag,difficulty,duration,release_time) V…

springboot整合pi支付开发

pi支付流程图&#xff1a; 使用Pi SDK功能发起支付由 Pi SDK 自动调用的回调函数&#xff08;让您的应用服务器知道它需要发出批准 API 请求&#xff09;从您的应用程序服务器到 Pi 服务器的 API 请求以批准付款&#xff08;让 Pi 服务器知道您知道此付款&#xff09;Pi浏览器向…

【排序算法】堆排序详解与实现

一、堆排序的思想 堆排序(Heapsort)是指利用堆积树&#xff08;堆&#xff09;这种数据结构所设计的一种排序算法&#xff0c;它是选择排序的一种。它是通过堆&#xff08;若不清楚什么是堆&#xff0c;可以看我前面的文章&#xff0c;有详细阐述&#xff09;来进行选择数据&am…

论文阅读-- A simple transmit diversity technique for wireless communications

一种简单的无线通信发射分集技术 论文信息&#xff1a; Alamouti S M. A simple transmit diversity technique for wireless communications[J]. IEEE Journal on selected areas in communications, 1998, 16(8): 1451-1458. 创新性&#xff1a; 提出了一种新的发射分集方…

八大排序java

冒泡排序 /*** 冒泡排序&#xff1a;* 比较相邻的元素。如果第一个比第二个大&#xff0c;就交换他们两个。* 对每一对相邻元素作同样的工作&#xff0c;从开始第一对到结尾的最后一对。这步做完后&#xff0c;最后的元素会是最大的数。* 针对所有的元素重复以上的步骤&#x…

WEB各类常用测试工具

一、单元测试/测试运行器 1、Jest 知名的 Java 单元测试工具&#xff0c;由 Facebook 开源&#xff0c;开箱即用。它在最基础层面被设计用于快速、简单地编写地道的 Java 测试&#xff0c;能自动模拟 require() 返回的 CommonJS 模块&#xff0c;并提供了包括内置的测试环境 …

华为OD机试 - 最小步骤数(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入&#xff1a;4 8 7 5 2 3 6 4 8 12、输出&#xff1a;23、说明&#xff1a;4、思路分析 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《…

aarch64 平台 musl gcc 工具链手动编译方法

目标 手动编译一个 aarch64 平台的 musl gcc 工具链 musl libc 与 glibc、uclibc 等,都是 标准C 库, musl libc 是基于系统调用之上的 标准C 库,也就是用户态的 标准C 库。 musl libc 轻量、开源、免费,是一些 操作系统的选择,当前 Lite-OS 与 RT-Smart 等均采用自制的 mu…

【Vue面试题八】、为什么data属性是一个函数而不是一个对象?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;为什么data属性是一个函…

Spring实例化源码解析之Custom Events上集(八)

Events使用介绍 在ApplicationContext中&#xff0c;事件处理通过ApplicationEvent类和ApplicationListener接口提供。如果将实现ApplicationListener接口的bean部署到上下文中&#xff0c;每当一个ApplicationEvent被发布到ApplicationContext时&#xff0c;该bean将被通知。…

使用企业订货系统后的效果|软件定制开发|APP小程序搭建

使用企业订货系统后的效果|软件定制开发|APP小程序搭建 企业订货系统是一种高效的采购管理系统&#xff0c;它可以帮助企业更好地管理采购流程&#xff0c;降低采购成本&#xff0c;提高采购效率。 可以帮助企业提高销售效率和降低成本的软件工具。使用该系统后&#xff0c;企业…

如何使用 Tensor.art 实现文生图

摘要&#xff1a;Tensor.art 是一个基于 AI 的文本生成图像工具。本文介绍了如何使用 Tensor.art 来实现文生图的功能。 正文&#xff1a; 文生图是指将文本转换为图像的技术。它具有广泛的应用&#xff0c;例如在广告、教育和娱乐等领域。 Tensor.art 是一个基于 AI 的文本…

【SA8295P 源码分析】103 - QNX DDR RAM 内存布局分析

【SA8295P 源码分析】103 - QNX DDR RAM 内存布局分析 一、SA8295P QNX RAM 内存布局 (16G DDR)1.1 DDR 汇总描述1.2 QNX Meta reserved memory, DDR Rank01.3 Reserved for qnx1.4 Android GVM SysRam 相关内存(可修改)1.5 Reserved for qnx(不要修改)1.6 QNX SysRam 相关内…

强制删除文件?正确操作方法分享!

“我昨天在删除文件时有个文件一直删除不掉。想用强制删除的方法来把它删掉&#xff0c;应该怎么操作呢&#xff1f;谁能教教我呀&#xff1f;” 在使用电脑的过程中&#xff0c;我们有时候可能会发现文件无论怎么删除都无法删掉&#xff0c;如果我们想要强制删除文件但不知道怎…