Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )print(pruning_group.details())  # or print(pruning_group)# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):# handle groups in sequential orderidxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True)# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM基本概念、命令、参数、GC日志总结

原文: 赵侠客 一、前言 NPE&#xff08;NullPointerException&#xff09;和OOM&#xff08;OutofMemoryError&#xff09;在JAVA程序员中扮演着重要的角色&#xff0c;它也是很多人始终摆脱不掉的梦魇&#xff0c;与NPE不同的是OOM一旦在生产环境中出现就意味着只靠代码已经无…

java集合题库详解

1. Arraylist与LinkedList区别 可以从它们的底层数据结构、效率、开销进行阐述哈 ArrayList是数组的数据结构&#xff0c;LinkedList是链表的数据结构。 随机访问的时候&#xff0c;ArrayList的效率比较高&#xff0c;因为LinkedList要移动指针&#xff0c;而ArrayList是基于索…

Java 客户端向服务端上传文件(TCP通信)

一、实验内容 编写一个客户端向服务端上传文件的程序&#xff0c;要求使用TCP通信的的知识&#xff0c;完成将本地机器输入的路径下的文件上传到D盘中名称为upload的文件夹中。并把客户端的IP地址加上count标识作为上传后文件的文件名&#xff0c;即IP&#xff08;count&#…

OpenSearch 与 Elasticsearch:哪个开源搜索引擎适合您?

当谈论到搜索引擎产品时&#xff0c;Elasticsearch 和 OpenSearch 是两个备受关注的选择。它们都以其出色的功能和灵活性而闻名&#xff0c;但在一些方面存在一些差异。在本文中&#xff0c;我们将从功能和延展性、工具与资源、价格和许可这三个角度对这两个产品进行论述。通过…

qt+opencv人脸人眼检测识别

项目运行涉及到opencv库&#xff0c;以及haarcascade_frontalface_default.xml和haarcascade_eye_tree_eyeglasses.xml。qt配置opencv可见先前文章&#xff0c;另外这两份OpenCV 中用于眼睛检测的级联分类器xml文件&#xff0c;是我在网上下载的。 把要使用到的文件都放到当前…

鸿蒙培训开发:就业市场的新热点~

金三银四在即&#xff0c;随着春节假期结束&#xff0c;各行各业纷纷复工复产&#xff0c;2024年的春季招聘市场也迎来了火爆的局面。最近发布的《2024年春招市场行情周报&#xff08;第一期&#xff09;》显示&#xff0c;尽管整体就业市场仍处于人才饱和状态&#xff0c;但华…

spring-cloud-openfeign 3.0.0(对应spring boot 2.4.x之前版本)之前版本feign整合ribbon请求流程

在之前写的文章配置基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136060312 下图为自己整理的

Excel 快速填充/输入内容

目录 一. Ctrl D/R 向下/右填充二. 批量输入内容 一. Ctrl D/R 向下/右填充 ⏹如下图所示&#xff0c;通过快捷键向下和向右填充数据 &#x1f914;当选中第一个单元格之后&#xff0c;可以按住Shift后&#xff0c;再选中最后一个单元格&#xff0c;可以选中第一个单元格和最…

自动驾驶技术解析与关键步骤

目录 前言1 自动驾驶主要技术流程1.1 车辆周围环境感知1.2 车辆和行人检测分析1.3 运动轨迹规划 2 关键技术概述2.1 车辆探测与图片输入2.2 行人检测2.3 运动规划2.4 电子地图2.5 轨迹预测2.6 交通灯分析2.7 故障检测 结语 前言 自动驾驶汽车作为未来交通领域的重要发展方向&a…

【Python】-入门:安装配置和IDLE的使用

Python的安装和配置 一、下载Python安装包 首先&#xff0c;你需要从Python的官方网站&#xff08;https://www.python.org/downloads/&#xff09;下载适合你操作系统的Python安装包。请注意&#xff0c;Python 2.x版本即将停止维护&#xff0c;因此推荐下载Python 3.x版本。…

【LGR-176-Div.2】[yLCPC2024] 洛谷 3 月月赛 I(A~C and G<oeis>)

[yLCPC2024] A. dx 分计算 前缀和提前处理一下区间和&#xff0c;做到O&#xff08;1&#xff09;访问就可以过。 #include <bits/stdc.h> //#define int long long #define per(i,j,k) for(int (i)(j);(i)<(k);(i)) #define rep(i,j,k) for(int (i)(j);(i)>(k);…

Redis作为缓存的数据一致性问题

背景 使用Reids作为缓存的原因&#xff1a; 在高并发场景下&#xff0c;传统关系型数据库的并发能力相对比较薄弱&#xff08;QPS不能太大&#xff09;&#xff1b; 使用Redis做一个缓存。让用户请求先打到Redis上而不是直接打到数据库上。 但是如果出现数据更新操作&#xff…

Windows下同一电脑配置多个Git公钥访问不同的账号

前言 产生这个问题的原因是我在Gitee码云上有两个账号,为了方便每次不用使用http模式推拉代码,于是我就使用了ssh的模式,起初呢我用两台电脑分别连接两个账号,用起来也相安无事,近段时时间台式机在家里,我在外地出差了,就想着把ssh公钥同时添加到不同的账号里,结果却发现不能用…

超网、IP 聚合、IP 汇总分别是什么?三者有啥区别和联系?

一、超网 超网&#xff08;Supernet&#xff09;是一种网络地址聚合技术&#xff0c;它可以将多个连续的网络地址合并成一个更大的网络地址&#xff0c;从而减少路由表的数量和大小。超网技术可以将多个相邻的网络地址归并成一个更大的网络地址&#xff0c;这个更大的网络地址…

Git使用教程:入门到精通

Git使用教程&#xff1a;入门到精通 一、Git安装根据需求选择电脑位数安装&#xff1b;20231023210945建议这里先新建一个文件夹如&#xff1a;D:/Git&#xff1b;专门来存放Git安装包和后续Git代码&#xff0c;方便管理&#xff1b; 二、Git使用前的配置需要先创建自己的Gitee…

贪心算法(蓝桥杯 C++ 题目 代表 注解)

介绍&#xff1a; 贪心算法&#xff08;Greedy Algorithm&#xff09;是一种在每一步选择中都采取当前状态下最好或最优&#xff08;即最有利&#xff09;的选择&#xff0c;从而希望最终能够得到全局最好或最优的结果的算法。它通常用来解决一些最优化问题&#xff0c;如最小生…

❤ Vue3项目搭建系统篇(二)

❤ Vue3项目搭建系统篇&#xff08;二&#xff09; 1、安装和配置 Element Plus&#xff08;完整导入&#xff09; yarn add element-plus --savemain.ts中引入&#xff1a; // 引入组件 import ElementPlus from element-plus import element-plus/dist/index.css const ap…

剑指offer经典题目整理(二)

一、斐波那契数列&#xff08;fib&#xff09; 1.链接 斐波那契数列_牛客题霸_牛客网 (nowcoder.com) 2.描述 斐波那契数列就是数列中任意一项数字&#xff0c;都会等于前两项之和&#xff0c;满足f(n) f(n-1) f(n-2) 的一个数列&#xff0c;例如&#xff1a;1 1 2 3 5 8…

VMware虚拟机安装Ubuntu kylin22.04系统教程(附截图详细步骤)

一、版本信息 虚拟机产品&#xff1a;VMware Workstation 17 Pro 虚拟机版本&#xff1a;17.0.0 build-20800274 ISO映像文件&#xff1a;ubuntukylin-22.04-pro-amd64.iso 二、安装步骤 打开虚拟机&#xff0c;点击创建新的虚拟机&#xff1a; 选择自定义&#xff1a; 硬…

HarmonyOS NEXT应用开发之MpChart图表实现案例

介绍 MpChart是一个包含各种类型图表的图表库&#xff0c;主要用于业务数据汇总&#xff0c;例如销售数据走势图&#xff0c;股价走势图等场景中使用&#xff0c;方便开发者快速实现图表UI。本示例主要介绍如何使用三方库MpChart实现柱状图UI效果。如堆叠数据类型显示&#xf…