datawhale11月组队学习 模型压缩技术2:PyTorch模型剪枝教程

文章目录

    • 一、 prune模块简介
      • 1.1 常用方法
      • 1.2 剪枝效果
      • 1.3 二、三、四章剪枝测试总结
    • 二、局部剪枝(Local Pruning)
      • 2.1 结构化剪枝
        • 2.1.1 对weight进行随机结构化剪枝(random_structured)
        • 2.1.2 对weight进行迭代剪枝(范数结构化剪枝,ln_structured)
      • 2.2 非结构化剪枝
        • 2.2.1 对bias进行随机非结构化剪枝
        • 2.2.2 对多层网络进行范数非结构化剪枝(l1_unstructured)
      • 2.3 永久化剪枝(remove)
    • 三、全局剪枝(GLobal pruning)
    • 四、自定义剪枝(Custom pruning)

  • 《datawhale2411组队学习之模型压缩技术1:模型剪枝(上)》:介绍模型压缩的几种技术;模型剪枝基本概念、分类方式、剪枝标准、剪枝频次、剪枝后微调等内容
  • 《datawhale11月组队学习 模型压缩技术2:PyTorch模型剪枝教程》:介绍PyTorch的prune模块具体用法
  • 《datawhale11月组队学习 模型压缩技术3:2:4结构稀疏化BERT模型》:介绍基于模式的剪枝——2:4结构稀疏化及其在BERT模型上的测试效果

项目地址awesome-compression、在线阅读

一、 prune模块简介

PyTorch教程《Pruning Tutorial》、torch.nn.utils.prune文档

1.1 常用方法

Pytorch在1.4.0版本开始,加入了剪枝操作,在torch.nn.utils.prune模块中,主要有以下剪枝方法:

剪枝类型子类型剪枝方法
局部剪枝结构化剪枝随机结构化剪枝 (random_structured)
范数结构化剪枝 (ln_structured)
非结构化剪枝随机非结构化剪枝 (random_unstructured)
范数非结构化剪枝 (ln_unstructured)
全局剪枝非结构化剪枝全局非结构化剪枝 (global_unstructured)
自定义剪枝自定义剪枝 (Custom Pruning)

除此之外,模块中还有一些其它方法:

方法描述
prune.remove(module, name)剪枝永久化
prune.apply使用指定的剪枝方法对模块进行剪枝。
prune.is_pruned(module)检查给定模块的某个参数是否已被剪枝。
prune.custom_from_mask(module, name, mask)基于自定义的掩码进行剪枝,用于定义更加细粒度的剪枝策略。

1.2 剪枝效果

  • 参数变化

    • 剪枝前,weight 是模型的一个参数,意味着它是模型训练时优化的对象,可以通过梯度更新(通过 optimizer.step() 来更新它的值)。
    • 剪枝过程中,原始权重被保存到新的变量 weight_orig中,便于后续访问原始权重。
    • 剪枝后,weight是剪枝后的权重值(通过原始权重和剪枝掩码计算得出),但此时不再是参数,而是模型的属性(一个普通的变量)
  • 掩码存储:生成一个名为 weight_mask的剪枝掩码,会被保存为模块的一个缓冲区(buffer)。

  • 前向传递:PyTorch 使用 forward_pre_hooks 来确保每次前向传递时都会应用剪枝处理。每个被剪枝的参数都会在模块中添加一个钩子来实现这一操作。

1.3 二、三、四章剪枝测试总结

  1. weight进行剪枝,效果见1.2 章节。
  2. weight进行迭代剪枝,相当于把多个剪枝核(mask)序列化成一个剪枝核, 最终只有一个weight_origweight_maskhook也被更新。
  3. weight剪枝后,再对bias进行剪枝,weight_origweight_mask不变,新增bias_origbias_mask,新增bias hook
  4. 可以对多个模块同时进行剪枝,最后使用remove进行剪枝永久化
    使用remove函数后, weight_origbias_orig 被移除,剪枝后的weightbias 成为标准的模型参数。经过 remove 操作后,剪枝永久化生效。此时,剪枝掩码weight_mask 和 hook不再需要,named_buffers_forward_pre_hooks 都被清空。
  5. 局部剪枝需要根据自己的经验来决定对某一层网络进行剪枝,需要对模型有深入了解,所以全局剪枝(跨不同参数)更通用,即从整体网络的角度进行剪枝。采用全局剪枝时,不同的层被剪掉的百分比可能不同。
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'))# 应用20%全局剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

最终各层剪枝比例为(随机的):

Sparsity in conv1.weight: 5.33%
Sparsity in conv2.weight: 17.25%
Sparsity in fc1.weight: 22.03%
Sparsity in fc2.weight: 14.67%
Global sparsity: 20.00%
  1. 自定义剪枝需要通过继承class BasePruningMethod()来定义,,其内部有若干方法: call, apply_mask, apply, prune, remove。其中,必须实现__init__compute_mask两个函数才能完成自定义的剪枝规则设定。此外,您必须指定要实现的修剪类型( global, structured, and unstructured)。

二、局部剪枝(Local Pruning)

  局部剪枝,指的是对网络的单个层或局部范围内进行剪枝。其中,非结构化剪枝会随机地将一些权重参数变为0,结构化剪枝则将某个维度某些通道的权重变成0。
总结一下2.1和2.2的效果:

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchsummary import summary# 1.定义一个经典的LeNet网络
class LeNet(nn.Module):def __init__(self, num_classes=10):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)self.fc2 = nn.Linear(in_features=120, out_features=84)self.fc3 = nn.Linear(in_features=84, out_features=num_classes)def forward(self, x):x = self.maxpool(F.relu(self.conv1(x)))x = self.maxpool(F.relu(self.conv2(x)))x = x.view(x.size()[0], -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)# 2.打印模型结构
summary(model, input_size=(1, 28, 28))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1            [-1, 6, 24, 24]             156MaxPool2d-2            [-1, 6, 12, 12]               0Conv2d-3             [-1, 16, 8, 8]           2,416MaxPool2d-4             [-1, 16, 4, 4]               0Linear-5                  [-1, 120]          30,840Linear-6                   [-1, 84]          10,164Linear-7                   [-1, 10]             850
================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 0.17
Estimated Total Size (MB): 0.22
----------------------------------------------------------------
# 3.打印模型的状态字典,状态字典里包含了所有的参数
print(model.state_dict().keys())
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
# 4.打印第一个卷积层的参数
module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],[ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],[ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],[-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],[-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],......[[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],[ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],[-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],[ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],[ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',requires_grad=True))]
# 5.打印module中属性张量named_buffers,此时为空列表
print(list(module.named_buffers()))
[]

2.1 结构化剪枝

2.1.1 对weight进行随机结构化剪枝(random_structured)

  对LeNet的conv1层的weight参数进行随机结构化剪枝,其中 amount是一个介于0.0-1.0的float数值,代表比例, 或者一个正整数,代表剪裁掉多少个参数.

prune.random_structured(module, name="weight", amount=2, dim=0)
# 1.再次打印模型的状态字典,发现conv1层多了weight_orig和weight_mask
print(model.state_dict().keys())
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
# 2. 剪枝后,原始的weight变成了weight_orig,并存放在named_parameters中
print(list(module.named_parameters()))
[('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],[ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],[ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],[-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],[-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],......[[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],[ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],[-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],[ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],[ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',requires_grad=True))]
# 3. 剪枝掩码矩阵weight_mask存放在模块的buffer中
print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]]]))]
# 4. 剪枝操作后的weight已经不再是module的参数, 而只是module的一个属性.
print(module.weight)
tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],[[[-0.0540, -0.1928, -0.0355, -0.0075, -0.1481],[ 0.0135,  0.0192,  0.0082, -0.0120, -0.0164],[-0.0435, -0.1488,  0.1092, -0.0041,  0.1960],[-0.1045, -0.0136,  0.0398, -0.1286,  0.0617],[-0.0091,  0.0466,  0.1827,  0.1655,  0.0727]]],[[[ 0.1216, -0.0833, -0.1491, -0.1143,  0.0113],[ 0.0452,  0.1662, -0.0425, -0.0904, -0.1235],[ 0.0565,  0.0933, -0.0721,  0.0909,  0.1837],[-0.1739,  0.0263,  0.1339,  0.0648, -0.0382],[-0.1667,  0.1478,  0.0448, -0.0892,  0.0815]]],[[[ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000],[-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]],[[[ 0.1278,  0.1037, -0.0323, -0.1504,  0.1080],[ 0.0266, -0.0996,  0.1499, -0.0845,  0.0609],[-0.0662, -0.1405, -0.0586, -0.0615, -0.0462],[-0.1118, -0.0961, -0.1325, -0.0417, -0.0741],[ 0.1842, -0.1040, -0.1786, -0.0593,  0.0186]]],[[[-0.0889, -0.0737, -0.1655, -0.1708, -0.0988],[-0.1787,  0.1127,  0.0706, -0.0352,  0.1238],[-0.0985, -0.1929, -0.0062,  0.0488, -0.1152],[-0.1659, -0.0448,  0.0821, -0.0956, -0.0262],[ 0.1928,  0.1767, -0.1792, -0.1364,  0.0507]]]],grad_fn=<MulBackward0>)

  对于每一次剪枝操作,PyTorch 会为剪枝的参数(如 weight)添加一个 forward_pre_hook。这个钩子会在每次进行前向传递计算之前,自动应用剪枝掩码(即将某些权重置为零),这保证了剪枝后的权重在模型计算时被正确地使用。

# 5.打印_forward_pre_hooks
print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomStructured object at 0x7f04012f8ca0>)])

简单总结就是:

  • weight 不再是参数,它变成了一个属性,表示剪枝后的权重。
  • weight_orig 保存原始未剪枝的权重。
  • weight_mask 是一个掩码,表示哪些权重被剪去了(即哪些位置变为零)。
  • 钩子会保证每次前向传递时,weight 会根据 weight_mask 来计算出剪枝后的版本。
2.1.2 对weight进行迭代剪枝(范数结构化剪枝,ln_structured)

  一个模型的参数可以执行多次剪枝操作,这种操作被称为迭代剪枝(Iterative Pruning)。上述步骤已经对conv1进行了随机结构化剪枝,接下来对其再进行范数结构化剪枝,看看会发生什么?

# n代表范数,这里n=2表示l2范数
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)# 再次打印模型参数
print(" model state_dict keys:")
print(model.state_dict().keys())
print('*'*50)print(" module named_parameters:")
print(list(module.named_parameters()))
print('*'*50)print(" module named_buffers:")
print(list(module.named_buffers()))
print('*'*50)print(" module weight:")
print(module.weight)
print('*'*50)print(" module _forward_pre_hooks:")
print(module._forward_pre_hooks)
model state_dict keys:
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
module named_parameters:	# 原始参数weight_orig不变
...
...
module named_buffers:
[('weight_mask', tensor([[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]]]))]
**************************************************module weight:......
module _forward_pre_hooks:
OrderedDict([(1, <torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0>)])

  可见迭代剪枝相当于把多个剪枝核序列化成一个剪枝核, 新的 mask 矩阵与旧的 mask 矩阵的结合由PruningContainer的compute_mask方法处理,最后只有一个weight_orig和weight_mask。

  module._forward_pre_hooks是一个用于在模型的前向传播之前执行自定义操作的机制,这里记录了执行过的剪枝方法:

# 打印剪枝历史
for hook in module._forward_pre_hooks.values():if hook._tensor_name == "weight":  breakprint(list(hook))
[<torch.nn.utils.prune.RandomStructured object at 0x7f04012f8ca0>, <torch.nn.utils.prune.LnStructured object at 0x7f04c8675b80>]

2.2 非结构化剪枝

2.2.1 对bias进行随机非结构化剪枝

此时,我们也可以继续对偏置bias进行剪枝,看看module的参数、缓冲区、钩子和属性是如何变化的。

prune.random_unstructured(module, name="bias", amount=1)
# 再次打印模型参数
print(" model state_dict keys:")
print(model.state_dict().keys())
print('*'*50)print(" module named_parameters:")
print(list(module.named_parameters()))
print('*'*50)print(" module named_buffers:")
print(list(module.named_buffers()))
print('*'*50)print(" module bias:")
print(module.bias)
print('*'*50)print(" module _forward_pre_hooks:")
print(module._forward_pre_hooks)
model state_dict keys:
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
# weight_orig不变,添加了bias_origmodule named_parameters:  
[('weight_orig', Parameter containing:...
, requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0893, -0.1464, -0.1101, -0.0076,  0.1493, -0.0418],requires_grad=True))]
**************************************************
# weight_mask不变,添加了bias_maskmodule named_buffers:
[('weight_mask', 
...('bias_mask', tensor([1., 1., 0., 1., 1., 1.]))]
**************************************************module bias:
tensor([-0.0893, -0.1464, -0.0000, -0.0076,  0.1493, -0.0418],grad_fn=<MulBackward0>)
**************************************************module _forward_pre_hooks:
OrderedDict([(1, <torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0>), (2, <torch.nn.utils.prune.RandomUnstructured object at 0x7f04013a7d30>)])

  对bias进行剪枝后,会发现state_dictnamed_parameters中不仅仅有了weight_orig,也有了bias_orig。在named_buffers中, 也同时出现了weight_maskbias_mask。最后,因为我们在两种参数上进行剪枝,因此会生成两个钩子。

2.2.2 对多层网络进行范数非结构化剪枝(l1_unstructured)

  前面介绍了对指定的conv1层的weightbias进行了不同方法的剪枝,那么能不能支持同时对多层网络的特定参数进行剪枝呢?

# 对于模型多个模块进行bias剪枝
for n, m in model.named_modules():# 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝if isinstance(m, torch.nn.Conv2d):prune.l1_unstructured(m, name="bias", amount=0.2)# 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝# elif isinstance(module, torch.nn.Linear):#     prune.random_structured(module, name="weight", amount=0.4,dim=0)# 再次打印模型参数
print(" model state_dict keys:")
print(model.state_dict().keys())
print('*'*50)print(" module named_parameters:")
print(list(module.named_parameters()))
print('*'*50)print(" module named_buffers:")
print(list(module.named_buffers()))
print('*'*50)print(" module weight:")
print(module.weight)
print('*'*50)print(" module bias:")
print(module.bias)
print('*'*50)print(" module _forward_pre_hooks:")
print(module._forward_pre_hooks)
model state_dict keys:
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias_orig', 'conv2.bias_mask', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************module named_parameters:[('weight_orig', Parameter containing:...('bias_orig', Parameter containing:...
**************************************************
# # weight_mask不变,bias_mask更新
module named_buffers:
[('weight_mask', ...
('bias_mask', tensor([1., 1., 0., 0., 1., 1.]))]
**************************************************
# module weight不变
module weight:...
**************************************************
module bias:
tensor([-0.0893, -0.1464, -0.0000, -0.0000,  0.1493, -0.0418],grad_fn=<MulBackward0>)
**************************************************
module _forward_pre_hooks:
OrderedDict([(1, <torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0>), (3, <torch.nn.utils.prune.PruningContainer object at 0x7f04010c1100>)])

2.3 永久化剪枝(remove)

接下来对模型的weight和bias参数进行永久化剪枝操作prune.remove

# 对module的weight执行剪枝永久化操作remove
for n, m in model.named_modules():if isinstance(m, torch.nn.Conv2d):prune.remove(m, 'bias')# 对conv1的weight执行剪枝永久化操作remove
prune.remove(module, 'weight')
print('*'*50)# 将剪枝后的模型的状态字典打印出来
print(" model state_dict keys:")
print(model.state_dict().keys())
print('*'*50)# 再次打印模型参数
print(" model named_parameters:")
print(list(module.named_parameters()))
print('*'*50)# 再次打印模型mask buffers参数
print(" model named_buffers:")
print(list(module.named_buffers()))
print('*'*50)# 再次打印模型的_forward_pre_hooks
print(" model forward_pre_hooks:")
print(module._forward_pre_hooks)
**************************************************model state_dict keys:
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************model named_parameters:
[('bias', Parameter containing:
tensor([-0.0893, -0.1464, -0.0000, -0.0000,  0.1493, -0.0418],requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000]]],[[[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],[-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[-0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],[[[ 0.1216, -0.0833, -0.1491, -0.1143,  0.0113],[ 0.0452,  0.1662, -0.0425, -0.0904, -0.1235],[ 0.0565,  0.0933, -0.0721,  0.0909,  0.1837],[-0.1739,  0.0263,  0.1339,  0.0648, -0.0382],[-0.1667,  0.1478,  0.0448, -0.0892,  0.0815]]],[[[ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000],[-0.0000,  0.0000,  0.0000,  0.0000, -0.0000],[-0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[-0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]],[[[ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],[ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000]]],[[[-0.0889, -0.0737, -0.1655, -0.1708, -0.0988],[-0.1787,  0.1127,  0.0706, -0.0352,  0.1238],[-0.0985, -0.1929, -0.0062,  0.0488, -0.1152],[-0.1659, -0.0448,  0.0821, -0.0956, -0.0262],[ 0.1928,  0.1767, -0.1792, -0.1364,  0.0507]]]], requires_grad=True))]
**************************************************model named_buffers:
[]
**************************************************model forward_pre_hooks:
OrderedDict()

可见,执行remove操作后:

  • weight_origbias_orig 被移除,剪枝后的weightbias 成为标准的模型参数。经过 remove 操作后,剪枝永久化生效。
  • 剪枝掩码weight_maskbias_mask不再需要,named_buffers被清空
  • _forward_pre_hooks 也被清空(由于剪枝后的权重和偏置将直接反映在最终模型中,所以无须再借助外部的掩码或钩子函数来维护剪枝过程)。

三、全局剪枝(GLobal pruning)

  前面已经介绍了局部剪枝的四种方法,但这很大程度上需要根据自己的经验来决定对某一层网络进行剪枝。 更通用的剪枝策略是采用全局剪枝,即从整体网络的角度进行剪枝。采用全局剪枝时,不同的层被剪掉的百分比可能不同。

model = LeNet().to(device=device)# 1.打印初始化模型的状态字典
print(model.state_dict().keys())
print('*'*50)# 2.构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'))# 3. 全局剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)# 4. 打印剪枝后模型的状态字典
print(model.state_dict().keys())
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.weight', 'fc3.bias'])

打印一下各层被剪枝的比例:

print("Sparsity in conv1.weight: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0))/ float(model.conv1.weight.nelement())))print("Sparsity in conv2.weight: {:.2f}%".format(100. * float(torch.sum(model.conv2.weight == 0))/ float(model.conv2.weight.nelement())))print("Sparsity in fc1.weight: {:.2f}%".format(100. * float(torch.sum(model.fc1.weight == 0))/ float(model.fc1.weight.nelement())))print("Sparsity in fc2.weight: {:.2f}%".format(100. * float(torch.sum(model.fc2.weight == 0))/ float(model.fc2.weight.nelement())))print("Global sparsity: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0)+ torch.sum(model.conv2.weight == 0)+ torch.sum(model.fc1.weight == 0)+ torch.sum(model.fc2.weight == 0))/ float(model.conv1.weight.nelement()+ model.conv2.weight.nelement()+ model.fc1.weight.nelement()+ model.fc2.weight.nelement())))
Sparsity in conv1.weight: 5.33%
Sparsity in conv2.weight: 17.25%
Sparsity in fc1.weight: 22.03%
Sparsity in fc2.weight: 14.67%
Global sparsity: 20.00%

四、自定义剪枝(Custom pruning)

  剪枝模型通过继承class BasePruningMethod()来执行剪枝, 内部有若干方法: call, apply_mask, apply, prune, remove等等。其中,必须实现__init__构造函数和compute_mask两个函数才能完成自定义的剪枝规则设定。 此外,您必须指定要实现的修剪类型( global, structured, and unstructured)。

# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class custom_prune(prune.BasePruningMethod):# 指定此技术实现的修剪类型(支持的选项为global、 structured和unstructured)PRUNING_TYPE = "unstructured"# 内部实现compute_mask函数, 定义剪枝规则, 本质上就是如何去mask掉权重参数def compute_mask(self, t, default_mask):mask = default_mask.clone()# 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数的50%被mask掉mask.view(-1)[::2] = 0return mask# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def custome_unstructured_pruning(module, name):custom_prune.apply(module, name)return module
import time
# 实例化模型类
model = LeNet().to(device=device)start = time.time()
# 调用自定义剪枝方法的函数, 对model中的第1个全连接层fc1中的偏置bias执行自定义剪枝
custome_unstructured_pruning(model.fc1, name="bias")# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc1.bias_mask)# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
5.576610565185547 ms

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPT1.0 和 GPT2.0 的联系与区别

随着自然语言处理技术的飞速发展&#xff0c;OpenAI 提出的 GPT 系列模型成为了生成式预训练模型的代表。作为 GPT 系列的两代代表&#xff0c;GPT-1 和 GPT-2 虽然在架构上有着继承关系&#xff0c;但在设计理念和性能上有显著的改进。本文将从模型架构、参数规模、训练数据和…

Java-06 深入浅出 MyBatis - 一对一模型 SqlMapConfig 与 Mapper 详细讲解测试

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 大数据篇正在更新&#xff01;https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了&#xff1a; MyBatis&#xff…

css使用弹性盒,让每个子元素平均等分父元素的4/1大小

css使用弹性盒&#xff0c;让每个子元素平均等分父元素的4/1大小 原本&#xff1a; ul {padding: 0;width: 100%;background-color: rgb(74, 80, 62);display: flex;justify-content: space-between;flex-wrap: wrap;li {/* 每个占4/1 */overflow: hidden;background-color: r…

佛山三水戴尔R740服务器黄灯故障处理

1&#xff1a;佛山三水某某大型商场用户反馈一台DELL PowerEdge R740服务器近期出现了黄灯警告故障&#xff0c;需要冠峰工程师协助检查故障灯原因。 2&#xff1a;工程师协助该用户通过笔记本网线直连到服务器尾部的IDRAC管理端口&#xff0c;默认ip 192.168.0.120 密码一般在…

[ 应急响应进阶篇-1 ] Windows 创建后门并进行应急处置(后门账户\计划任务后门\服务后门\启动项后门\粘贴键后门)

&#x1f36c; 博主介绍 &#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 _PowerShell &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 &#x1f389;点赞➕评论➕收藏 养成习…

力扣 LeetCode 513. 找树左下角的值(Day8:二叉树)

解题思路&#xff1a; 方法一&#xff1a;递归法&#xff08;方法二更好理解&#xff0c;个人更习惯方法二&#xff09; 前中后序均可&#xff0c;实际上没有中的处理 中左右&#xff0c;左中右&#xff0c;左右中&#xff0c;实际上都是左在前&#xff0c;所以遇到的第一个…

基于web的教务系统的实现(springboot框架 mysql jpa freemarker)

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

vue学习11.21

vue特点&#xff1a; 采用组件化开发&#xff0c;提高代码复用率和维护 声明式编码&#xff0c;不需要直接操作DOM元素 使用diff算法&#xff0c;把虚拟DOM变成真实DOM&#xff0c; 如果两个容器都用vue的实例&#xff0c;只选最上面的容器。 一个容器使用两个vue实例也不行…

【数据分享】中国汽车工业年鉴(1986-2023)

本年鉴是由工业和信息化部指导&#xff0c;中国汽车技术研究中心有限公司与中国汽车工业协会联合主办。《年鉴》是全面、客观记载中国汽车工业发展与改革历程的重要文献&#xff0c;内容涵盖汽车产业政策、标准、企业、市场以及全国各省市汽车工业发展情况&#xff0c;并调查汇…

Java项目实战II基于微信小程序的南宁周边乡村游平台(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着城市化…

深入理解Redis(七)----Redis实现分布式锁

基于Redis的实现方式 1、选用Redis实现分布式锁原因&#xff1a; &#xff08;1&#xff09;Redis有很高的性能&#xff1b; &#xff08;2&#xff09;Redis命令对此支持较好&#xff0c;实现起来比较方便 2、使用命令介绍&#xff1a; &#xff08;1&#xff09;SETNX SETNX …

如何创建一个项目用于研究element-plus的原理

需求&#xff1a;直接使用element-plus未封装成组件的源码&#xff0c;创建一个项目&#xff0c;可以使用任意的element-plus组件&#xff0c;可以深度研究组件的运行。例如研究某一个效果&#xff0c;如果直接在node_modules修改elment-plus打包之后的那些js、mjs代码&#xf…

SQL进阶技巧:如何进行数字范围统计?| 货场剩余货位的统计查询方法

目录 0 场景描述 1 剩余空位区间和剩余空位编号统计分析 2 查找已用货位区间 3 小结 0 场景描述 这是在做一个大型货场租赁系统时遇到的问题&#xff0c;在计算货场剩余存储空间时&#xff0c;不仅仅需要知道哪些货位是空闲的&#xff0c;还要能够判断出哪些货位之间是连…

菜鸟驿站二维码/一维码 取件识别功能

特别注意需要引入 库文 ZXing 可跳转&#xff1a; 记录【WinForm】C#学习使用ZXing.Net生成条码过程_c# zxing-CSDN博客 using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using static System.Net.…

2024年亚太地区数学建模大赛D题-探索量子加速人工智能的前沿领域

量子计算在解决复杂问题和处理大规模数据集方面具有巨大的潜力&#xff0c;远远超过了经典计算机的能力。当与人工智能&#xff08;AI&#xff09;集成时&#xff0c;量子计算可以带来革命性的突破。它的并行处理能力能够在更短的时间内解决更复杂的问题&#xff0c;这对优化和…

教程 - 在 Creo Elements/Pro 中使用 Mechanica 分析杆的 von-mises 应力?

这是教程。 步骤1&#xff1a; 在“零件”模式下启动 Creo Elements/Pro。 步骤2&#xff1a; 草绘>>顶平面并绘制一个直径为 20mm 的圆。 步骤3&#xff1a; 将其挤出 200 毫米。 步骤4&#xff1a; 应用>>机械. 步骤5&#xff1a; 单击“确定”&…

ssm框架-spring-spring声明式事务

声明式事务概念 声明式事务是指使用注解或 XML 配置的方式来控制事务的提交和回滚。 开发者只需要添加配置即可&#xff0c; 具体事务的实现由第三方框架实现&#xff0c;避免我们直接进行事务操作&#xff01; 使用声明式事务可以将事务的控制和业务逻辑分离开来&#xff0c;提…

基于单片机的多功能跑步机控制系统

本设计基于单片机的一种多功能跑步机控制系统。该系统以STM32单片机为主控制器&#xff0c;由七个电路模块组成&#xff0c;分别是&#xff1a;单片机模块、电机控制模块、心率检测模块、音乐播放模块、液晶显示模块、语音控制模块、电源模块。其中&#xff0c;单片机模块是整个…

写给Vue2使用者的Vue3学习笔记

&#x1f64b;‍请注意&#xff0c;由于本人项目中引入了unplugin-auto-import的依赖&#xff0c;所以所有的代码示例中均未手动引入各种依赖库&#xff08;ref、reactive、useRouter等等&#xff09; 初始环境搭建 npm init vuelatest模板语法 插值 同 Vue2 <span>…

C# 数据结构之【图】C#图

1. 图的概念 图是一种重要的数据结构&#xff0c;用于表示节点&#xff08;顶点&#xff09;之间的关系。图由一组顶点和连接这些顶点的边组成。图可以是有向的&#xff08;边有方向&#xff09;或无向的&#xff08;边没有方向&#xff09;&#xff0c;可以是加权的&#xff…