论文:DepGraph: Towards Any Structural Pruning
工程:https://github.com/VainF/Torch-Pruning
算法和库的使用介绍:CVPR 2023 | DepGraph 通用结构化剪枝
1 TP的简介
该算法介绍了DepGraph 如何建模结构化剪枝中的层依赖,实现任意结构的剪枝。对应实现的库为 torch-pruning。
本篇博客对作者的介绍做一个自己的梳理和记录。
- Torch-Pruning的简单介绍
Torch-Pruning(TP)是一个结构化剪枝库,与现有框架(例如torch.nn.utils.prune)最大的区别在于,TP会物理地移除参数,同时自动裁剪其他依赖层。TP是一个纯 PyTorch 项目,实现了内置的计算图的追踪(Tracing)、依赖图(DepenednecyGraph, 见论文)、剪枝器等功能,同时支持 PyTorch 1.x 和 2.0 版本。- 用 Torch-Pruning 剪枝的好处
假设正在对一个卷积结构化剪枝,需要减去哪些内容,具体第几个卷积核、对应的偏置、BN中对应的维度、与其直接或间接相连的层的核的channel。我们要实现剪枝,需要对不同模型定制不同的代码实现。Torch-Pruning可让实现者跳脱出对层剪枝时最具体的操作,而关注于整体剪枝的设置。
2 TP的初尝试
2.1 初步尝试
以 ResNet-18 结构化剪枝为例,对【conv1】进行剪枝,同时处理对应的bn、紧临的卷积。
from torchvision.models import resnet18 import torch_pruning as tp import torchmodel = resnet18(pretrained=True).eval() tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # 剪枝前两个通道 tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1]) # 尝试修复bn tp.prune_conv_in_channels(model.layer1[0].conv1, idxs=[0,1]) # 尝试修复紧邻的conv output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络
会报错如下:问题出在残差结构上。残差的相加操作要求传入的两个tensor具有相同的空间尺寸,也就意味着剪枝后的Tensor通道数62和另一个tensor的通道数64不再匹配。
2.2 使用TP对 conv1进行剪枝
手动设置DependencyGraph是Torch-Pruning框架的底层算法,设计目标就是"自动寻找耦合层",并自动化处理。
使用TP对ResNet-18的conv1进行剪枝,代码如下:import torch from torchvision.models import resnet18 import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图 DG = tp.DependencyGraph() DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 获取与model.conv1存在依赖的所有层,并指定需要剪枝的通道索引(此处我们剪枝第[2,6,9]个通道) group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] ) print(model, group)# 3. 执行剪枝操作 if DG.check_pruning_group(group): # 避免将通道剪枝到0group.prune() print(model, group) output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络
上述过程一共三步:
- 1 对网络进行依赖图构建;
- 2 选取需要剪枝的层,指定剪枝通道,获得分组group;这里的group,是所有与conv1相依赖的层。
- 3 执行剪枝操作,按组移除通道。那剪枝过程具体操作了哪些层呢?
左图为剪枝前的conv1 的group,右图为剪枝后的conv1 的group。
怎么去看这个group呢,在下图右侧进行了简单的标注,可以发现conv1的group都会进行剪枝,从而适应conv1的卷积核的维度发生的变化
左图为剪枝前的resnet结构部分,右图为剪枝后的resnet结构部分。
2.3 使用TP对网络中每个层进行剪枝
在实际实现时,我们希望是对整个网络结构进行剪枝,而非特定的某几层,这就涉及到如何不重复地遍历网络中所有分组的问题。
DepGraph提供了接口DG.get_all_groups
来实现这以目标。该接口仅实现层的分组,并不会分辨通道的重要性。该接口包含两个参数
ignored_layers
:指定忽略 某些希望被剪枝的层。通常包括最后的分类层、以及报错的层(也可以使用其它正确的层进行替换)root_module_types
:指定了每个组的起始层的类型。比如想剪枝所有的卷基层,而不想剪枝全连接层,只需要只传入对应的卷积类即可。
值得注意的是,不同层可能出现在同一分组中,Depgraph会自动去除重复分组。
下面先提前设定好需要剪枝的通道,来展示
DG.get_all_groups
的使用:import torch import torch.nn as nn import torch_pruning as tp from torchvision.models import resnet18model = resnet18(pretrained=True).eval()# 1. 构建依赖图 DG = tp.DependencyGraph() DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 获取与model.conv1存在依赖的所有层,并指定需要剪枝的通道索引(此处我们剪枝第[2,6,9]个通道) Groups = DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear])# 3. 执行剪枝操作 for group in Groups:idxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络
但该段代码剪枝,在TP实际剪枝也是较少使用,这里是展示一个剪枝底层的基本操作。
3 TP对完整网络的剪枝
3.1 常用的结构化剪枝原理
结构化和非结构化剪枝方向,已发表的有较多的论文。但在工业上较常用的为结构化剪枝。实际中最常用的结构化剪枝方法:
- 利用权值进行filter剪枝:Pruning Filters for Efficient ConvNets
在这张图中,我们可以找到两个卷积参数矩阵(Kernel Matrix):第一个卷积层以 x i x_i xi 作为输入,输出特征图 x i + 1 x_{i+1} xi+1;第二个卷积层以 x i + 1 x_{i+1} xi+1作为输入,生成特征图 x i + 2 x_{i+2} xi+2。
在结构化剪枝中,这两个卷基层之间存在非常直观的依赖关系,即当我们调整第一层的输出通道时,第二个卷积层的输入通道也需要相应的进行调整,这使得蓝色高亮的参数需要同时被剪枝。
此外,作者指出网络中可能存在更复杂的依赖,例如残差结构依赖:
- 利用bn进行剪枝:Learning Efficient Convolutional Networks through Network Slimming。
BN会按通道对输入特征进行归一化,使得不同的特征处于比较接近的范围内。我们将缩放因子(从批量归一化层重用)与卷积层中的每个通道相关联。稀疏正则化在训练期间被施加在这些缩放因子上,以自动识别不重要的通道。缩放因子值较小(橙色)的通道将被修剪(左侧)。修剪后,我们获得紧凑模型(右侧),然后对其进行微调,以实现与正常训练的全网络相当(甚至更高)的精度。
在任何一个网络中,BN的scale参数都具备一定的绝对值大小(也就是不会过小),这意味着各个通道都具有不可忽略的重要性。解决这类问题的一种有效方法是使用稀疏训练,通过对scale参数施加正则化项来稀疏化一部分通道。在slimming论文中,作者对scale参数施加了一个额外的L1正则化项,从而实现了这一过程。整个流程如下所示:稀疏训练–>剪枝–>微调
3.2 TP剪枝示例
网络中存在大量复杂依赖的情况下,如何进行剪枝呢?
【1】计算网络每个group中每层的重要性
- Torch-Puning 库内置了处理依赖的功能,并提供了可扩展的接口用于自定义剪枝器。
类tp.importance.Importance
要求我们实现一个非常简单的接口__call__
- 入参为一个 group,包含了多个相互耦合的层。
- 输出为一个一维的重要性的得分向量,其含义是每个通道的重要性,因此他的维度和通道数量是相同的。
由于输入的group通常会包含多个可剪枝层,因此我们首先对这些层进行独立的重要性计算,然后通过求平均值得到最终结果。- Torch-Puning也提供了常用重要性评估策略:
tp.importance.MagnitudeImportance(p=2)
:p=2
表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性 。
tp.importance.BNScaleImportance()
:利用BN计算每个group中的每个层的权值的重要性
tp.importance.GroupNormImportance()
:与继承于MagnitudeImportance,且没做任何的添加和修改。
【2】对网络进行剪枝
- Torch-Pruning库定义了一个元剪枝器
tp.pruner.MetaPruner
,能够完成除了重要性评估之外的所有工作。一般常在自定义的重要性评估后,执行剪枝时使用- Torch-Puning也提供了常用的剪枝策略
tp.pruner.MagnitudePruner()
tp.pruner.BNScalePruner()
tp.pruner.GroupNormPruner()
Depgraph 提出的基于全局重要性的剪枝
【3】例子
为了增加难度,这里我们对一个DenseNet模型进行剪枝。
这里只展示了稀疏训练和微调使用的位置,仅剪枝部分能够有效跑通。import torch import torch.nn as nn import torch_pruning as tp from torchvision.models import densenet121model = densenet121(pretrained=True) example_inputs = torch.randn(1, 3, 224, 224)# 1. 使用我们上述定义的重要性评估 # imp = tp.importance.MagnitudeImportance(p=2) # imp = tp.importance.BNScaleImportance() imp = tp.importance.GroupNormImportance()# 2. 忽略无需剪枝的层,例如最后的分类层 ignored_layers = [] for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!# 3. 初始化剪枝器 iterative_steps = 5 # 迭代式剪枝,重复5次Pruning-Finetuning的循环完成剪枝。 # pruner = tp.pruner.MagnitudePruner( # pruner = tp.pruner.BNScalePruner( pruner = tp.pruner.GroupNormPruner(model,example_inputs, # 用于分析依赖的伪输入importance=imp, # 重要性评估指标iterative_steps=iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝ch_sparsity=0.5, # 目标稀疏性,这里我们移除50%的通道 ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers, # 忽略掉最后的分类层 )# 4. 稀疏训练(为了节省时间我们假装在训练,实际应用时只需要在optimizer.step前插入regularize即可) for _ in range(100):pass# optimizer.zero_grad() # ...# loss.backward()# pruner.regularize(model, reg=1e-5) # <== 插入该行进行稀疏化# optimizer.step()# 4. Pruning-Finetuning的循环 base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) for i in range(iterative_steps):pruner.step() # 执行裁剪,本例子中我们每次会裁剪10%,共执行5次,最终稀疏度为50%macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)print(" Iter %d/%d, Params: %.2f M => %.2f M" % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6))print(" Iter %d/%d, MACs: %.2f G => %.2f G"% (i+1, iterative_steps, base_macs / 1e9, macs / 1e9))# finetune your model here# finetune(model)# ... print(model)