Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )print(pruning_group.details())  # or print(pruning_group)# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):# handle groups in sequential orderidxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True)# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM基本概念、命令、参数、GC日志总结

原文: 赵侠客 一、前言 NPE&#xff08;NullPointerException&#xff09;和OOM&#xff08;OutofMemoryError&#xff09;在JAVA程序员中扮演着重要的角色&#xff0c;它也是很多人始终摆脱不掉的梦魇&#xff0c;与NPE不同的是OOM一旦在生产环境中出现就意味着只靠代码已经无…

java集合题库详解

1. Arraylist与LinkedList区别 可以从它们的底层数据结构、效率、开销进行阐述哈 ArrayList是数组的数据结构&#xff0c;LinkedList是链表的数据结构。 随机访问的时候&#xff0c;ArrayList的效率比较高&#xff0c;因为LinkedList要移动指针&#xff0c;而ArrayList是基于索…

深入理解数据压缩流程及 zlib 库中相关函数

数据压缩是一种常见的操作&#xff0c;可以有效地减小数据的体积&#xff0c;节省存储空间和网络带宽。在本文中&#xff0c;我们将深入讨论数据压缩的流程&#xff0c;并详细解释 zlib 库中相关函数的使用&#xff0c;包括 deflateInit2()、deflate()、deflateEnd() 等。我们将…

Java 客户端向服务端上传文件(TCP通信)

一、实验内容 编写一个客户端向服务端上传文件的程序&#xff0c;要求使用TCP通信的的知识&#xff0c;完成将本地机器输入的路径下的文件上传到D盘中名称为upload的文件夹中。并把客户端的IP地址加上count标识作为上传后文件的文件名&#xff0c;即IP&#xff08;count&#…

OpenSearch 与 Elasticsearch:哪个开源搜索引擎适合您?

当谈论到搜索引擎产品时&#xff0c;Elasticsearch 和 OpenSearch 是两个备受关注的选择。它们都以其出色的功能和灵活性而闻名&#xff0c;但在一些方面存在一些差异。在本文中&#xff0c;我们将从功能和延展性、工具与资源、价格和许可这三个角度对这两个产品进行论述。通过…

qt+opencv人脸人眼检测识别

项目运行涉及到opencv库&#xff0c;以及haarcascade_frontalface_default.xml和haarcascade_eye_tree_eyeglasses.xml。qt配置opencv可见先前文章&#xff0c;另外这两份OpenCV 中用于眼睛检测的级联分类器xml文件&#xff0c;是我在网上下载的。 把要使用到的文件都放到当前…

鸿蒙培训开发:就业市场的新热点~

金三银四在即&#xff0c;随着春节假期结束&#xff0c;各行各业纷纷复工复产&#xff0c;2024年的春季招聘市场也迎来了火爆的局面。最近发布的《2024年春招市场行情周报&#xff08;第一期&#xff09;》显示&#xff0c;尽管整体就业市场仍处于人才饱和状态&#xff0c;但华…

spring-cloud-openfeign 3.0.0(对应spring boot 2.4.x之前版本)之前版本feign整合ribbon请求流程

在之前写的文章配置基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136060312 下图为自己整理的

物联网与边缘计算的结合

目录 一、实时响应与决策 二、降低网络带宽需求和传输延迟 三、隐私保护与数据安全 四、系统可靠性与稳定性 总结 物联网与边缘计算的结合&#xff1a;为未来的智能化应用注入强大动力 随着科技的飞速发展&#xff0c;物联网与边缘计算的结合已经成为推动各行各业创新发展…

Excel 快速填充/输入内容

目录 一. Ctrl D/R 向下/右填充二. 批量输入内容 一. Ctrl D/R 向下/右填充 ⏹如下图所示&#xff0c;通过快捷键向下和向右填充数据 &#x1f914;当选中第一个单元格之后&#xff0c;可以按住Shift后&#xff0c;再选中最后一个单元格&#xff0c;可以选中第一个单元格和最…

python常识系列24-->python操作mysql之pymysql

前言 pymsql是Python中操作MySQL的模块程序在运行时&#xff0c;数据都是在内存中的。当程序终止时&#xff0c;通常需要将数据保存在磁盘上。 安装模块 pip install PyMySql基本使用 ## 使用 connect 函数创建连接对象&#xff0c;此连接对象提供关闭数据库、事务回滚等操…

自动驾驶技术解析与关键步骤

目录 前言1 自动驾驶主要技术流程1.1 车辆周围环境感知1.2 车辆和行人检测分析1.3 运动轨迹规划 2 关键技术概述2.1 车辆探测与图片输入2.2 行人检测2.3 运动规划2.4 电子地图2.5 轨迹预测2.6 交通灯分析2.7 故障检测 结语 前言 自动驾驶汽车作为未来交通领域的重要发展方向&a…

【Python】-入门:安装配置和IDLE的使用

Python的安装和配置 一、下载Python安装包 首先&#xff0c;你需要从Python的官方网站&#xff08;https://www.python.org/downloads/&#xff09;下载适合你操作系统的Python安装包。请注意&#xff0c;Python 2.x版本即将停止维护&#xff0c;因此推荐下载Python 3.x版本。…

不同框架表示图像时维度顺序的区别:pytorch、kerastf、opencv、numpy、PIL

在PyTorch、Keras、OpenCV、NumPy和PIL这几个框架中&#xff0c;它们在表示图像时的维度存储顺序有所不同。下面我将逐一解释每个框架中图像维度的存储顺序&#xff1a; 1&#xff0c;PyTorch: PyTorch中图像的维度顺序通常遵循 [N, C, H, W] 的格式&#xff0c;也就是channe…

【LGR-176-Div.2】[yLCPC2024] 洛谷 3 月月赛 I(A~C and G<oeis>)

[yLCPC2024] A. dx 分计算 前缀和提前处理一下区间和&#xff0c;做到O&#xff08;1&#xff09;访问就可以过。 #include <bits/stdc.h> //#define int long long #define per(i,j,k) for(int (i)(j);(i)<(k);(i)) #define rep(i,j,k) for(int (i)(j);(i)>(k);…

Redis作为缓存的数据一致性问题

背景 使用Reids作为缓存的原因&#xff1a; 在高并发场景下&#xff0c;传统关系型数据库的并发能力相对比较薄弱&#xff08;QPS不能太大&#xff09;&#xff1b; 使用Redis做一个缓存。让用户请求先打到Redis上而不是直接打到数据库上。 但是如果出现数据更新操作&#xff…

【C/C++ 学习笔记】运算符

【C/C 学习笔记】运算符 视频地址: Bilibili 算术运算符 运算符含义备注 加号 − - −减号 ∗ * ∗乘号 / / /除号整数相除结果依然是整数&#xff08;直接舍去小数部分&#xff09;&#xff0c;小数相除还是小数 % 取模小数无法进行取模运算&#xff1b;对 0 取模会报错 …

Windows下同一电脑配置多个Git公钥访问不同的账号

前言 产生这个问题的原因是我在Gitee码云上有两个账号,为了方便每次不用使用http模式推拉代码,于是我就使用了ssh的模式,起初呢我用两台电脑分别连接两个账号,用起来也相安无事,近段时时间台式机在家里,我在外地出差了,就想着把ssh公钥同时添加到不同的账号里,结果却发现不能用…

超网、IP 聚合、IP 汇总分别是什么?三者有啥区别和联系?

一、超网 超网&#xff08;Supernet&#xff09;是一种网络地址聚合技术&#xff0c;它可以将多个连续的网络地址合并成一个更大的网络地址&#xff0c;从而减少路由表的数量和大小。超网技术可以将多个相邻的网络地址归并成一个更大的网络地址&#xff0c;这个更大的网络地址…

【Vue3 组合式 API: reactive 和 ref 函数】

文章目录 1. 什么是组合式 API&#xff1f;2. reactive 函数3. ref 函数4. reactive vs ref 1. 什么是组合式 API&#xff1f; 组合式 API 是 Vue 3 中的一种新特性&#xff0c;它允许我们通过函数来组织组件的逻辑&#xff0c;而不是依赖于选项式 API 中的选项对象。这使得代…