Torch-Pruning 库入门级使用介绍

在这里插入图片描述

项目地址:https://github.com/VainF/Torch-Pruning

Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合可以达到良好的剪枝效果。

本博文结合项目官网案例,对信息进行结构话,抽离出剪枝技术说明、剪枝模型保存与加载、剪枝技术的基本使用,剪枝技术的具体使用案例。并结合外部信息,分析剪枝对模型性能精度的影响。

1、基本说明

1.1 项目安装

打开https://github.com/VainF/Torch-Pruning,下载项目
在这里插入图片描述
然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库
在这里插入图片描述
然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。
在这里插入图片描述
验证安装:执行命令python -c "import torch_pruning", 如果没有输出报错信息则表示安装成功。
在这里插入图片描述

1.2 DepGraph 技术说明

在结构修剪中,组被定义为深度网络中最小的可移除单元。每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入一种名为 DepGraph 的自动化机制来解决这一挑战,该机制可以轻松实现参数分组,并有助于修剪各种深度网络。
在这里插入图片描述

直接剪枝会会破坏layer间的依赖关系,会导致forward流程报错。具体如下面代码,移除model.conv1模块中的idxs为0与1的channel,导致后续的bn1层输入输入与参数格式对不上号,然后报错。

from torchvision.models import resnet18
import torch_pruning as tp
import torchmodel = resnet18().eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test

在这里插入图片描述
基本在后续层添加剪枝,运行代码也会保存,因为batchnorm的下一层要求的输出channel是64。

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

使用DepGraph剪枝代码如下,先使用tp.DependencyGraph().build_dependenc构建出依赖图,然后基于DG.get_pruning_group函数获取目标剪枝层的依赖关系组,最后在检验关系并进行剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )# 3. prune all grouped layers that are coupled with model.conv1 (included).
print(group)
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.group.prune()# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

代码执行后的输出如下所示,可以看到捕捉到group对应的依赖layer

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

1.3 剪枝模型的保存与加载

剪枝后的模型由于网络结构改变了,如果只保存模型参数,是无法支持原始网络结构,需要将模型结构连参数一并保存。加载时连同参数一起加载。

model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

或者基于tp库中tp.state_dict函数提取目标参数进行保存,并基于tp.load_state_dict函数将剪枝后的参数赋值到原始模型中形成剪枝模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')# create a new model, e.g. resnet18
new_model = resnet18().eval()# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

2、剪枝基本案例

2.1 具有目标结构的剪枝

以下代码使用TaylorImportance指标进行剪枝,设置忽略输出层的剪枝。并设置MagnitudePruner中对通道剪枝50%,一共分iterative_steps步完成剪枝,每一次剪枝都进行微调。
整体来说,具备目标结构的剪枝,效果是最差的。 基于https://blog.csdn.net/a486259/article/details/140407147 分析的数据得出的结论。

import torch
from torchvision.models import resnet18
import torch_pruning as tp#model = resnet18(pretrained=True)
model = resnet18()# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}#pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):if isinstance(imp, tp.importance.TaylorImportance):# Taylor expansion requires gradients for importance estimationloss = model(example_inputs).sum() # a dummy loss for TaylorImportanceloss.backward() # before pruner.step()pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)# finetune your model here# finetune(model)# ...

代码的输出信息如下所示,可以看到macs与nparams在逐步降低。最终输出的模型结构,所有的chanel都减半了,只有输出层例外。

iter 0 | rate:0.8092  0.8111
iter 1 | rate:0.6469  0.6445
iter 2 | rate:0.4971  0.4979
iter 3 | rate:0.3718  0.3695
iter 4 | rate:0.2674  0.2614
ResNet((conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=256, out_features=1000, bias=True)
)
PS D:\开源项目\Torch-Pruning-master>(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=256, out_features=1000, bias=True)
)(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=256, out_features=1000, bias=True)
)

2.2 自动结构剪枝

这里的自动结构是有一个预设目标,即将总体channel剪枝到原模型的多少,但没有预定的目标结构。可能有的laye通道剪枝数多,有的剪枝数少。 与2.1中的代码相比,主要是增加了参数 global_pruning=True。但这个剪枝方式比具有目标结构的剪枝更加有效。就像裁员一样,要求各个部门内裁员比例相同与在公司内控制裁员比例(各个部门裁员比例按重要度排列,裁员比例不一样),必然是第二种方式更有效。第一种方式,使低效率部门的靠前但无用员工保留下来了。

import torch
from torchvision.models import resnet18
import torch_pruning as tp#model = resnet18(pretrained=True)
model = resnet18()# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 3 # progressive pruning
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # remove 50%的channelignored_layers=ignored_layers,global_pruning=True
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):if isinstance(imp, tp.importance.TaylorImportance):# Taylor expansion requires gradients for importance estimationloss = model(example_inputs).sum() # a dummy loss for TaylorImportanceloss.backward() # before pruner.step()pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)# finetune your model here# finetune(model)# ...

2.3 MagnitudePruner中的参数

指定特定层的剪枝比例 通过pruning_ratio_dict参数,指定model.layer2的剪枝比例为20%,这里适用于有先验经验的layer,控制对特定layer的剪枝比例。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)pruner = tp.pruner.MagnitudePruner(model,example_inputs,imp,pruning_ratio = 0.5,pruning_ratio_dict = {model.layer2: 0.2}
)
pruner.step()
print(model)

代码执行后的层为:ResNet{64, 128, 256, 512} => ResNet{32, 102, 128, 256}

设置最大剪枝比例 通过 max_pruning_ratio 参数设置最大剪枝比例,避免由于稀疏剪枝或者自动剪枝时某个层被严重剪枝或者移除。

剪枝次数与剪枝调度器 您打算分多轮修剪模型,iterative_steps 会很有用。默认情况下,修剪器会逐渐增加模型的稀疏度,直到达到所需的 pruning_ratio。如以下代码,分5次实现剪枝目标。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
)# prune the model, iteratively if necessary.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)print("Round %d/%d, Params: %.2f M" % (i+1, iterative_steps, nparams/1e6))# finetune your model here# finetune(model)# ...
print(model)

对应输出如下
Round 1/5, Params: 9.44 M
Round 2/5, Params: 7.45 M
Round 3/5, Params: 5.71 M
Round 4/5, Params: 4.20 M
Round 5/5, Params: 2.93 M

设置忽略的层 这主要是避免对输出层进行剪枝,修改模型的输出结构。使用代码如下,通过ignored_layers参数传入忽略的layer对象。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,pruning_ratio=0.5, # remove 50% channelsignored_layers=[model.conv1, model.fc] # ignore the first & last layers
)
pruner.step()
print(model)

channel取整 在很多的时候都认为channel为16的倍数,gpu运行效率最高。使用代码如下,通过round_to参数,保持channel是特定数的倍数。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}round_to=10 # round to 10x. Note: 10x is not a good practice.
)pruner.step()
print(model)

channel_groups 某些层(例如 nn.GroupNorm 和 nn.Conv2d)具有 group 参数,这会在层内引入额外的依赖项。修剪后,保持所有组的大小相同至关重要。为了满足这一要求,引入了参数 channel_groups 以启用对这些通道的手动分组。如以下代码,通过channel_groups参数,控制model.group_conv1中的参数为8个一组

pruner = tp.pruner.MagnitudePruner(model,example_inputs=example_inputs,importance=importance,iterative_steps=1,pruning_ratio=0.5,channel_groups = {model.group_conv1: 8} # For Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), groups=8))

额外参数剪枝 有些时候模型具备的可训练参数并非conv、fc等传统layer中,需要基于unwrapped_parameters参数将额外的可剪枝参数传入到剪枝器中。具体如下所示:

from torchvision.models.convnext import CNBlock, ConvNeXt
unwrapped_parameters = []
for m in model.modules():if isinstance(m, CNBlock):unwrapped_parameters.append( (m.layer_scale, 0) )pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,pruning_ratio=0.5, unwrapped_parameters=unwrapped_parameters 

限定剪枝范围 root_module_types 参数用于指定组的“根”或第一层。在许多情况下,我们专注于修剪线性层和卷积 (Conv) 层。要专门针对这些层启用修剪,我们可以使用以下参数:root_module_types=[nn.Conv2D, nn.Linear]。这可确保将修剪应用于所需的层。

pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,pruning_ratio=0.5, root_module_types=[nn.Conv2D, nn.Linear]

3、具体应用案例

3.1 timm模型剪枝

官方代码为:examples\timm_models\prune_timm_models.py
具体详情如下,这里有一个特殊用法,是通过num_heads参数实现对于transformer layer的支持

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn 
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparseparser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='global pruning')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()def main():timm_models = timm.list_models()if args.list_models:print(timm_models)if args.model is None: returnassert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)device = 'cuda' if torch.cuda.is_available() else 'cpu'model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)imp = tp.importance.GroupNormImportance()print("Pruning %s..."%args.model)input_size = model.default_cfg['input_size']example_inputs = torch.randn(1, *input_size).to(device)test_output = model(example_inputs)ignored_layers = []num_heads = {}for m in model.modules():if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:ignored_layers.append(model.head)print("Ignore classifier layer: ", m.head)# Attention layersif hasattr(m, 'num_heads'):if hasattr(m, 'qkv'):num_heads[m.qkv] = m.num_headsprint("Attention layer: ", m.qkv, m.num_heads)elif hasattr(m, 'qkv_proj'):num_heads[m.qkv_proj] = m.num_headsprint("========Before pruning========")print(model)base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)pruner = tp.pruner.MetaPruner(model, example_inputs, global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.importance=imp, # importance criterion for parameter selectioniterative_steps=1, # the number of iterations to achieve target pruning ratiopruning_ratio=args.pruning_ratio, # target pruning rationum_heads=num_heads,ignored_layers=ignored_layers,)for g in pruner.step(interactive=True):g.prune()for m in model.modules():# Attention layersif hasattr(m, 'num_heads'):if hasattr(m, 'qkv'):m.num_heads = num_heads[m.qkv]m.head_dim = m.qkv.out_features // (3 * m.num_heads)elif hasattr(m, 'qkv_proj'):m.num_heads = num_heads[m.qqkv_projkv]m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)print("========After pruning========")print(model)test_output = model(example_inputs)pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))if __name__=='__main__':main()

3.2 llm模型剪枝

在examples\LLMs\prune_llama.py中提供了一个对于llama模型的剪枝案例.
核心代码如下,可以看到也是基于num_heads记录transformer的结构信息,然后在剪枝后将num_heads数据赋值到对应模型参数上。与原始代码相比,这里删除了模型精度验证相关的代码。


# Code adapted from 
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wandaimport os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))import argparse
import os 
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import randomprint('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())def get_llm(model_name, cache_dir="./cache"):model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, device_map="auto")model.seqlen = model.config.max_position_embeddings return modeldef main():parser = argparse.ArgumentParser()parser.add_argument('--model', type=str, help='LLaMA model')parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')parser.add_argument("--cache_dir", default="./cache", type=str )parser.add_argument('--save', type=str, default=None, help='Path to save results.')parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')parser.add_argument("--eval_zero_shot", action="store_true")args = parser.parse_args()# Setting seeds for reproducibilitynp.random.seed(args.seed)torch.random.manual_seed(args.seed)model_name = args.model.split("/")[-1]print(f"loading llm model {args.model}")model = get_llm(args.model, args.cache_dir)       model.eval()tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)device = torch.device("cuda:0")if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.device = model.hf_device_map["lm_head"]print("use device ", device)############### Pruning##############print("----------------- Before Pruning -----------------")print(model)text = "Hello world."inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)import torch_pruning as tp num_heads = {}for name, m in model.named_modules():if name.endswith("self_attn"):num_heads[m.q_proj] = model.config.num_attention_headsnum_heads[m.k_proj] = model.config.num_key_value_headsnum_heads[m.v_proj] = model.config.num_key_value_headshead_pruning_ratio = args.pruning_ratiohidden_size_pruning_ratio = args.pruning_ratiopruner = tp.pruner.MagnitudePruner(model, example_inputs=inputs,importance=tp.importance.GroupNormImportance(),global_pruning=False,pruning_ratio=hidden_size_pruning_ratio,ignored_layers=[model.lm_head],num_heads=num_heads,prune_num_heads=True,prune_head_dims=False,head_pruning_ratio=head_pruning_ratio,)pruner.step()# Update model attributesnum_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )model.config.num_attention_heads = num_headsmodel.config.num_key_value_heads = num_key_value_headsfor name, m in model.named_modules():if name.endswith("self_attn"):m.hidden_size = m.q_proj.out_featuresm.num_heads = num_headsm.num_key_value_heads = num_key_value_headselif name.endswith("mlp"):model.config.intermediate_size = m.gate_proj.out_featuresprint("----------------- After Pruning -----------------")print(model)#ppl_test = eval_ppl(args, model, tokenizer, device)#print(f"wikitext perplexity {ppl_test}")if args.save_model:model.save_pretrained(args.save_model)tokenizer.save_pretrained(args.save_model)if __name__ == '__main__':main()

3.3 目标检测模型剪枝

在Torch-Pruning 库中提供了针对yolov8、yolov7、yolov5的剪枝案例。关于yolov8还提供了剪枝后的训练策略,其主要技巧在与对不可剪枝层的可剪枝话处理(C2f模块的剪枝,其含split操作,不利于剪枝索引)。后续会补充博客,说明对yolov8的剪枝使用。

4、其他信息

4.1 剪枝器中的评价指标

在torch_pruning\pruner\importance.py中有很多个剪枝评价指标

__all__ = [# Base Class"Importance",# Basic Group Importance"GroupNormImportance","GroupTaylorImportance","GroupHessianImportance",# Aliases"MagnitudeImportance","TaylorImportance","HessianImportance",# Other Importance"BNScaleImportance","LAMPImportance","RandomImportance",
]

整体来看是TaylorImportance最好,一直使用该值即可。
来看

4.2 剪枝对性能精度的影响

在博客https://blog.csdn.net/a486259/article/details/140407147?spm=1001.2014.3001.5501 中基本确定了剪枝50%,对模型精度是没有任何影响的。这里对Torch-Pruning 库相关的论文数据进行二次核验,以致于分析出剪枝中速度提升对精度的影响。

以DepGraph: Towards Any Structural Pruning数据为例,可以发现最高支持6x速度剪枝后保持模型性能。
在这里插入图片描述
以LLM-Pruner: On the Structural Pruning of Large Language Models 论文数据为例,可以发现使用Vector评价方法的剪枝,移除10%的参数,zero-shot下对模型精度影响不大。而图4更表明,剪枝方法正确的话,移除50%的参数对模型性能影响也不大。
在这里插入图片描述
以论文 Structural Pruning for Diffusion Models 的数据为分析,同样可以发现剪枝50%左右的通道,对结果影响不对。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TCP重传机制详解

1.什么是TCP重传机制 在 TCP 中,当发送端的数据到达接收主机时,接收端主机会返回⼀个确认应答消息,表示已收到消息。 但是如果传输的过程中,数据包丢失了,就会使⽤重传机制来解决。TCP的重传机制是为了保证数据传输的…

React安装(学习版)

1. 安装Node.js和npm 首先,确保你的电脑上已经安装了Node.js和npm(Node Package Manager)。你可以从 Node.js官网 下载安装包并按照提示进行安装。安装完成后,可以在命令行终端中验证Node.js和npm是否正确安装: node …

【Node.js】初识 Node.js

Node.js 概念 Node.js 是一个开源与跨平台的 JavaScript运行时环境 ,在浏览器外运行 V8 JavaScript 引擎(Google Chrome的内核),利用事件驱动、非阻塞和异步输入输出 等技术提高性能。 可以理解为 Node.js就是一个服务器端的、非阻塞式 l/O 的、事件驱…

01 MySQL

学习资料:B站视频-黑马程序员JavaWeb基础教程 文章目录 JavaWeb整体介绍 MySQL1、数据库相关概念2、MySQL3、SQL概述4、DDL:数据库操作5、DDL:表操作6、DML7、DQL8、约束9、数据库设计10、多表查询11、事务 JavaWeb整体介绍 JavaWeb Web:全球广域网&…

PostgreSQL的逻辑架构

一、PostgreSql的逻辑架构: 一个server可以有多个database;一个database有多个schema,默认的schema是public;schema下才是对象,其中对象包含:表、视图、触发器、索引等;与user之间的关系&#x…

Mysql笔记-20240718

零、 help、\h、? 调出帮助 mysql> \hFor information about MySQL products and services, visit:http://www.mysql.com/ For developer information, including the MySQL Reference Manual, visit:http://dev.mysql.com/ To buy MySQL Enterprise support, training, …

免费恢复软件有哪些?电脑免费使用的 5 大数据恢复软件

您是否在发现需要的文件时不小心删除了回收站中的文件?您一定对误操作感到后悔。文件永远消失了吗?还有机会找回它们吗?当然有!您可以查看这篇文章,挑选 5 款功能强大的免费数据恢复软件,用于 Windows 和 M…

<数据集>混凝土缺陷检测数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:7353张 标注数量(xml文件个数):7353 标注数量(txt文件个数):7353 标注类别数:6 标注类别名称:[exposed reinforcement, rust stain, Crack, Spalling, Efflorescence…

又缩水Unity7月闪促限时4折活动模块化角色模板编辑器场景美术插件拖尾怪物3D模型UI载具AI对话TPS飞机RPG和FPS202407

Flash Deals are Coming Back! 限时抢购又回来了! July 17, 2024 8:00:00 PT to July 24, 2024 7:59:00 PT 太平洋时间 2024 年 7 月 17 日 8:00:00 至 2024 年 7 月 24 日 7:59:00(太平洋时间)…

云计算实训室的核心功能有哪些?

在当今数字化转型浪潮中,云计算技术作为推动行业变革的关键力量,其重要性不言而喻。唯众,作为教育实训解决方案的领先者,深刻洞察到市场对云计算技能人才的迫切需求,精心打造了云计算实训室。这一实训平台不仅集成了先…

软件著作权申请教程(超详细)(2024新版)软著申请

目录 一、注册账号与实名登记 二、材料准备 三、申请步骤 1.办理身份 2.软件申请信息 3.软件开发信息 4.软件功能与特点 5.填报完成 一、注册账号与实名登记 首先我们需要在官网里面注册一个账号,并且完成实名认证,一般是注册【个人】的身份。中…

网安小贴士(17)认证技术原理应用

前言 认证技术原理及其应用是信息安全领域的重要组成部分,涉及多个方面,包括认证概念、认证依据、认证机制、认证类型以及具体的认证技术方法等。以下是对认证技术原理及应用的详细阐述: 一、认证概述 1. 认证概念 认证是一个实体向另一个实…

llama 2 改进之 RMSNorm

RMSNorm 论文:https://openreview.net/pdf?idSygkZ3MTJE Github:https://github.com/bzhangGo/rmsnorm?tabreadme-ov-file 论文假设LayerNorm中的重新居中不变性是可有可无的,并提出了均方根层归一化(RMSNorm)。RMSNorm根据均方根(RMS)将…

解决npm install(‘proxy‘ config is set properly. See: ‘npm help config‘)失败问题

摘要 重装电脑系统后,使用npm install初始化项目依赖失败了,错误提示:‘proxy’ config is set properly…,具体的错误提示如下图所示: 解决方案 经过报错信息查询解决办法,最终找到了两个比较好的方案&a…

HTTP协议、Wireshark抓包工具、json解析、天气爬虫

HTTP超文本传输协议 HTTP(Hyper Text Transfer Protocol): 全称超文本传输协议,是用于从万维网(WWW:World Wide Web )服务器传输超文本到本地浏览器的传送协议。 HTTP 协议的重要特点: 一发一收…

JVM:MAT内存泄漏检测原理

文章目录 一、介绍 一、介绍 MAT提供了称为支配树(Dominator Tree)的对象图。支配树展示的是对象实例间的支配关系。在对象引用图中,所有指向对象B的路径都经过对象A,则认为对象A支配对象B。 支配树中对象本身占用的空间称之为…

Spire.PDF for .NET【文档操作】演示:如何在 C# 中切换 PDF 层的可见性

我们已经演示了如何使用 Spire.PDF在 C# 中向 PDF 文件添加多个图层以及在 PDF 中删除图层。我们还可以在 Spire.PDF 的帮助下在创建新页面图层时切换 PDF 图层的可见性。在本节中,我们将演示如何在 C# 中切换新 PDF 文档中图层的可见性。 Spire.PDF for .NET 是一…

【LabVIEW作业篇 - 1】:中途停止for和while循环

文章目录 for循环while循环如何使用帮助 for循环 在程序框图中,创建for循环结构,选择for循环,鼠标右键-条件接线端,即出现像while循环中的小红圆心,其作用与while循环相同。 运行结果如下。(若随机数>…

分布式搜索引擎ES-Elasticsearch进阶

1.head与postman基于索引的操作 引入概念: 集群健康: green 所有的主分片和副本分片都正常运行。你的集群是100%可用 yellow 所有的主分片都正常运行,但不是所有的副本分片都正常运行。 red 有主分片没能正常运行。 查询es集群健康状态&…

ExoPlayer架构详解与源码分析(15)——Renderer

系列文章目录 ExoPlayer架构详解与源码分析(1)——前言 ExoPlayer架构详解与源码分析(2)——Player ExoPlayer架构详解与源码分析(3)——Timeline ExoPlayer架构详解与源码分析(4)—…