yolov5及yolov7实战之剪枝

之前有讲过一次yolov5的剪枝:yolov5实战之模型剪枝_yolov5模型剪枝-CSDN博客
当时基于的是比较老的yolov5版本,剪枝对整个训练代码的改动也比较多。最近发现一个比较好用的剪枝库,可以在不怎么改动原有训练代码的情况下,实现剪枝的操作,这篇文章就简单介绍一下,剪枝的概念以及为什么要剪枝可以参看上一篇,这里就不赘述了。

Torch-Pruning

VainF/Torch-Pruning: [CVPR 2023] Towards Any Structural Pruning; LLMs / Diffusion / Transformers / YOLOv8 / CNNs (github.com)
今天我们要用到的就是这个剪枝库,这个库集成了很多剪枝的方法,毕竟使用比较简单。

用法

这个剪枝库既有low level的剪枝,也就是手动控制剪枝哪些层,也有high level的剪枝,就是使用预设的剪枝算法,自动选择剪枝的部分。对于我们来说,更适合使用high level剪枝。具体的这里使用和上一篇yolov5里面的剪枝一样的算法,在这个库里叫BNScalePruner。

安装

首先我们需要安装上面提到的库,有两种方式来安装:

pip install torch-pruning

或源码安装(当碰到bug发布版本没修复,源码修复的时候):

pip install git+https://github.com/VainF/Torch-Pruning.git

稀疏化训练

为了更好的剪枝,我们在训练剪枝前的网络时,推荐开启稀疏化训练,利用这个库,我们可以很方便的实现这个操作。
首先在我们的训练代码中定义好剪枝器, 这里的opt.prune是我自己加的来控制是否开启稀疏化训练的标志:

# prune
if opt.prune:examle_input = torch.randn(1, 3, imgsz, imgsz).to(device)imp = tp.importance.BNScaleImportance()pruner = tp.pruner.BNScalePruner(model, examle_input, imp,reg=0.0001)

稀疏化训练主要需要设置reg参数,一般设置0.001~1e-6之间。
定义好剪枝器后,在训练代码的scaler.scale(loss).backward()之后,添加如下代码:

if opt.prune:pruner.regularize(model)

即可实现稀疏化训练。

剪枝

稀疏化训练后(也可以不做稀疏化训练),我们就可以进行剪枝操作了。这个库可以在训练中交互式进行多次剪枝,简单起见,我们这里分离剪枝和训练的代码,只进行剪枝操作。

import torch_pruning as tp
from models.experimental import attempt_load
import torchweights = "yolov7.pt"
model = attempt_load(weights, map_location=torch.device('cuda:0'), fuse=False)
for p in model.parameters():p.requires_grad = True
ignored_layers = []
from models.yolo import Detect, IDetect
from models.common import ImplicitA, ImplicitM
for m in model.modules():if isinstance(m, (Detect,IDetect)):ignored_layers.append(m.m)
unwrapped_parameters = []
for name, m in model.named_parameters():if isinstance(m, (ImplicitA,ImplicitM,)):unwrapped_parameters.append((name,1)) # pruning 1st dimension of implicit matrixprint(ignored_layers)
example_inputs = torch.rand(1, 3, 416, 416, device='cuda:0')
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(model, example_inputs, imp,ignored_layers=ignored_layers,unwrapped_parameters=unwrapped_parameters,global_pruning=True,ch_sparsity=0.3,round_to=8,)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_model = pruner.model
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"macs: {base_macs} -> {pruned_macs}")
print(f"nparams: {base_nparams} -> {pruned_nparams}")
macs_cutoff_ratio = (base_macs - pruned_macs) / base_macs
nparams_cutoff_ratio = (base_nparams - pruned_nparams) / base_nparams
print(f"macs cutoff ratio: {macs_cutoff_ratio}")
print(f"nparams cutoff ratio: {nparams_cutoff_ratio}")
save_path = weights.replace(".pt", "_pruned_bn_0.3.pt")torch.save({"model": pruned_model.module if hasattr(pruned_model, 'module') else pruned_model}, save_path)

去掉一些计算剪枝比例的,保存代码等代码外,剪枝操作其实由pruner.step()这一步完成。这里我们主要需要设置的参数是:

  • ch_sparsity: 可以理解成剪枝的比例,越大剪得越多
  • global_pruning: True表示整个模型的权重按一个整体排序后剪枝,False表示按分组内部按比例剪枝
  • round_to: 剪枝后的通道保留为多少的倍数,一般在硬件上,保留8的倍数

微调

经过剪枝的网络,精度是下降比较明显的,需要再在数据上finetune一些epoch才能把精度拉回来。
yolov7默认是通过yaml文件创建模型结构,然后再载入权重进行训练的,而我们剪枝后的模型是没有模型结构文件的,因此需要对训练代码做一定的修改,具体而言,只是对模型的载入进行一点修改。其中opt.finetune是用来控制是否处于finetune模式的标志位。

if opt.finetune: # for model without cfgnew = torch.load(weights, map_location=device)  # createmodel = new["model"]print("Finetune Mode...")
elif pretrained:
...

比较简单的改法是这样,从checkpoint中载入结构和权重,还有一种方式则是修改yolov7的Model类,这个在后面讲yolov7剪枝后蒸馏的时候再讲,暂时用上面这种方式就可以了。

评测

我在自己的任务上的效果是yolov7剪枝50%,微调后基本上能达到剪枝前的map,没记错的话这是和稀疏化训练的比,毕竟开启稀疏化训练本身也会掉点。大家可以在自己的任务上尝试一下,总体上精度还是可以的

结语

这篇文章简述了以下yolov7的剪枝,yolov5也可用,希望对大家有帮助。
f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/97343.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

李沐深度学习记录5:13.Dropout

Dropout从零开始实现 import torch from torch import nn from d2l import torch as d2l# 定义Dropout函数 def dropout_layer(X, dropout):assert 0 < dropout < 1# 在本情况中&#xff0c;所有元素都被丢弃if dropout 1:return torch.zeros_like(X)# 在本情况中&…

超自动化加速落地,助力运营效率和用户体验显著提升|爱分析报告

RPA、iPaaS、AI、低代码、BPM、流程挖掘等在帮助企业实现自动化的同时&#xff0c;也在构建一座座“自动化烟囱”。自动化工具尚未融为一体&#xff0c;协同价值没有得到释放。Gartner于2019年提出超自动化&#xff08;Hyperautomation&#xff09;概念&#xff0c;主要从技术组…

【动手学深度学习】课程笔记 04 数据操作和数据预处理

目录 数据操作 N维数组样例 访问元素 数据操作实现 入门 运算符 广播机制 节省内存 转换为其他Python对象 数据预处理实现 数据操作 N维数组是机器学习和神经网路的主要数据结构。 N维数组样例 访问元素 数据操作实现 下面介绍一下本课程中需要用到的PyTorch相关操…

BJT晶体管

BJT晶体管也叫双极结型三极管&#xff0c;主要有PNP、NPN型两种&#xff0c;符号如下&#xff1a; 中间的是基极&#xff08;最薄&#xff0c;用于控制&#xff09;&#xff0c;带箭头的是发射极&#xff08;自由电子浓度高&#xff09;&#xff0c;剩下的就是集电极&#xff0…

no Go files in ...问题

golang项目&#xff0c;当我们微服务分模块开发时&#xff0c;习惯把main.go放在cmd目录下分模块放置&#xff0c;此时&#xff0c;我们在项目根目录下执行go test . 或go build . 时会报错“no Go files in ...”, 这是因为在.目录下找不到go程序&#xff0c;或者找不到程序入…

springboot整合pi支付开发

pi支付流程图&#xff1a; 使用Pi SDK功能发起支付由 Pi SDK 自动调用的回调函数&#xff08;让您的应用服务器知道它需要发出批准 API 请求&#xff09;从您的应用程序服务器到 Pi 服务器的 API 请求以批准付款&#xff08;让 Pi 服务器知道您知道此付款&#xff09;Pi浏览器向…

【排序算法】堆排序详解与实现

一、堆排序的思想 堆排序(Heapsort)是指利用堆积树&#xff08;堆&#xff09;这种数据结构所设计的一种排序算法&#xff0c;它是选择排序的一种。它是通过堆&#xff08;若不清楚什么是堆&#xff0c;可以看我前面的文章&#xff0c;有详细阐述&#xff09;来进行选择数据&am…

论文阅读-- A simple transmit diversity technique for wireless communications

一种简单的无线通信发射分集技术 论文信息&#xff1a; Alamouti S M. A simple transmit diversity technique for wireless communications[J]. IEEE Journal on selected areas in communications, 1998, 16(8): 1451-1458. 创新性&#xff1a; 提出了一种新的发射分集方…

WEB各类常用测试工具

一、单元测试/测试运行器 1、Jest 知名的 Java 单元测试工具&#xff0c;由 Facebook 开源&#xff0c;开箱即用。它在最基础层面被设计用于快速、简单地编写地道的 Java 测试&#xff0c;能自动模拟 require() 返回的 CommonJS 模块&#xff0c;并提供了包括内置的测试环境 …

华为OD机试 - 最小步骤数(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入&#xff1a;4 8 7 5 2 3 6 4 8 12、输出&#xff1a;23、说明&#xff1a;4、思路分析 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《…

aarch64 平台 musl gcc 工具链手动编译方法

目标 手动编译一个 aarch64 平台的 musl gcc 工具链 musl libc 与 glibc、uclibc 等,都是 标准C 库, musl libc 是基于系统调用之上的 标准C 库,也就是用户态的 标准C 库。 musl libc 轻量、开源、免费,是一些 操作系统的选择,当前 Lite-OS 与 RT-Smart 等均采用自制的 mu…

【Vue面试题八】、为什么data属性是一个函数而不是一个对象?

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;为什么data属性是一个函…

Spring实例化源码解析之Custom Events上集(八)

Events使用介绍 在ApplicationContext中&#xff0c;事件处理通过ApplicationEvent类和ApplicationListener接口提供。如果将实现ApplicationListener接口的bean部署到上下文中&#xff0c;每当一个ApplicationEvent被发布到ApplicationContext时&#xff0c;该bean将被通知。…

使用企业订货系统后的效果|软件定制开发|APP小程序搭建

使用企业订货系统后的效果|软件定制开发|APP小程序搭建 企业订货系统是一种高效的采购管理系统&#xff0c;它可以帮助企业更好地管理采购流程&#xff0c;降低采购成本&#xff0c;提高采购效率。 可以帮助企业提高销售效率和降低成本的软件工具。使用该系统后&#xff0c;企业…

如何使用 Tensor.art 实现文生图

摘要&#xff1a;Tensor.art 是一个基于 AI 的文本生成图像工具。本文介绍了如何使用 Tensor.art 来实现文生图的功能。 正文&#xff1a; 文生图是指将文本转换为图像的技术。它具有广泛的应用&#xff0c;例如在广告、教育和娱乐等领域。 Tensor.art 是一个基于 AI 的文本…

强制删除文件?正确操作方法分享!

“我昨天在删除文件时有个文件一直删除不掉。想用强制删除的方法来把它删掉&#xff0c;应该怎么操作呢&#xff1f;谁能教教我呀&#xff1f;” 在使用电脑的过程中&#xff0c;我们有时候可能会发现文件无论怎么删除都无法删掉&#xff0c;如果我们想要强制删除文件但不知道怎…

Leetcode hot 100之回溯O(N!):选择/DFS

目录 框架&#xff1a;排列/组合/子集 元素无重不可复选 全排列 子集 组合&#xff1a;[1, n] 中的 k 个数 分割成回文串 元素无重不可复选&#xff1a;排序&#xff0c;多条值相同的只遍历第一条 子集/组合 先进行排序&#xff0c;让相同的元素靠在一起&#xff0c;如…

VMProtect使用教程(VC++MFC中使用)

VMProtect使用教程(VCMFC中使用) VMProtect是一种商业级别的代码保护工具&#xff0c;可以用于保护VC MFC程序。以下是使用VMProtect保护VC MFC程序的步骤&#xff1a; 1. 下载并安装VMProtect,C包含库及目录。 2. 在VC MFC项目中添加VMProtectSDK.h头文件&#xff0c;并在需…

Android ncnn-android-yolov8-seg源码解析 : 实现人像分割

1. 前言 上篇文章&#xff0c;我们已经将人像分割的ncnn-android-yolov8-seg项目运行起来了&#xff0c;后续文章我们会抽取出Demo中的核心代码&#xff0c;在自己的项目中&#xff0c;来接入人体识别和人像分割功能。 先来看下效果&#xff0c;整个图像的是相机的原图&#…

番外--Task2:

任务&#xff1a;root与普通用户的互切&#xff08;区别&#xff09;&#xff0c;启动的多用户文本见面与图形界面的互切命令&#xff08;区别&#xff09;。 输入图示命令&#xff0c;重启后就由图形界面转成文本登录界面&#xff1b; 输入图示命令&#xff0c;重启后就由文本…