模型剪枝入门

一、定义

1.定义
2. 案例1
3. 全局剪枝案例
4. 全局剪枝案例
5. 自定义剪枝
6. 特定网络剪枝
7. 多参数模块剪枝
8. torch.nn.utils.prune 解读

二、实现

  1. 定义
    在这里插入图片描述
  2. 接口:
import torch.nn.utils.prune as prune
  1. 案例1
import torch.nn as nn
import torch.nn.utils.prune as pruneimport torchdef prune_first_layer(module, inputs):# 对权重矩阵进行L1剪枝,保留80%的元素prune.l1_unstructured(module,"weight", amount=0.8)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 10)def forward(self, x):x = x.view(-1, 784)x = self.fc1(x)x = nn.functional.relu(x)x = self.fc2(x)x = nn.functional.relu(x)x = self.fc3(x)return xnet = Net()# 将钩子函数与第一个全连接层关联起来
handle = net.fc1.register_forward_pre_hook(prune_first_layer)   #fc1 执行之前进行剪枝# 进行前向传递
output = net(torch.randn(1, 784))
# 移除钩子函数
handle.remove()
  1. 全局剪枝案例
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸self.conv1 = nn.Conv2d(1, 6, 3)self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet().to(device=device)# 首先打印初始化模型的状态字典
print(model.state_dict().keys())
print('*'*50)# 构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight'))# 调用prune中的全局剪枝函数global_unstructured执行剪枝操作, 此处针对整体模型中的20%参数量进行剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)# 最后打印剪枝后的模型的状态字典
print(model.state_dict().keys())print("Sparsity in conv1.weight: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0))/ float(model.conv1.weight.nelement())))print("Sparsity in conv2.weight: {:.2f}%".format(100. * float(torch.sum(model.conv2.weight == 0))/ float(model.conv2.weight.nelement())))print("Sparsity in fc1.weight: {:.2f}%".format(100. * float(torch.sum(model.fc1.weight == 0))/ float(model.fc1.weight.nelement())))print("Sparsity in fc2.weight: {:.2f}%".format(100. * float(torch.sum(model.fc2.weight == 0))/ float(model.fc2.weight.nelement())))print("Sparsity in fc3.weight: {:.2f}%".format(100. * float(torch.sum(model.fc3.weight == 0))/ float(model.fc3.weight.nelement())))print("Global sparsity: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0)+ torch.sum(model.conv2.weight == 0)+ torch.sum(model.fc1.weight == 0)+ torch.sum(model.fc2.weight == 0)+ torch.sum(model.fc3.weight == 0))/ float(model.conv1.weight.nelement()+ model.conv2.weight.nelement()+ model.fc1.weight.nelement()+ model.fc2.weight.nelement()+ model.fc3.weight.nelement())))# 当采用全局剪枝策略的时候(假定20%比例参数参与剪枝),
# 仅保证模型总体参数量的20%被剪枝掉,
# 具体到每一层的情况则由模型的具体参数分布情况来定.
import torchfrom torch.nn.utils import prune# 
定义模型model = torch.nn.Sequential(   torch.nn.Linear(100, 50), torch.nn.ReLU(),  torch.nn.Linear(50, 10))# 剪枝网络结构
prune.random_unstructured(model, amount=0.2)
# 剪枝权重
prune.l1_unstructured(model, amount=0.2)
  1. 自定义剪枝
#用户自定义剪枝(Custom pruning).
# 剪枝模型通过继承class BasePruningMethod()来执行剪枝,
# 内部有若干方法: call, apply_mask, apply, prune, remove等等.
# 一般来说, 用户只需要实现__init__, 和compute_mask两个函数即可完成自定义的剪枝规则设定.
import time
# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class myself_pruning_method(prune.BasePruningMethod):PRUNING_TYPE = "unstructured"# 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数def compute_mask(self, t, default_mask):mask = default_mask.clone()# 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉mask.view(-1)[::2] = 0return mask# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def myself_unstructured_pruning(module, name):myself_pruning_method.apply(module, name)return module# 实例化模型类
model = LeNet().to(device=device)start = time.time()
# 调用自定义剪枝方法的函数, 对model中的第三个全连接层fc3中的偏置bias执行自定义剪枝
myself_unstructured_pruning(model.fc3, name="bias")# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc3.bias_mask)# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')# 打印出来的bias_mask张量, 完全是按照预定义的方式每隔一位遮掩掉一位,
#  0和1交替出现, 后续执行remove操作的时候,
# 原始的bias_orig中的权重就会同样的被每隔一位剪枝掉一位.
  1. 特定网络剪枝
# 第一种: 对特定网络模块的剪枝(Pruning Model).import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸self.conv1 = nn.Conv2d(1, 6, 3)self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet().to(device=device)# 序列化一个剪枝模型(Serializing a pruned model):
# 对于一个模型来说, 不管是它原始的参数, 拥有的属性值, 还是剪枝的mask buffers参数
# 全部都存储在模型的状态字典中, 即state_dict()中.
# 将模型初始的状态字典打印出来
print(model.state_dict().keys())
print('*'*50)# 对模型进行剪枝操作, 分别在weight和bias上剪枝
module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)
prune.l1_unstructured(module, name="bias", amount=3)# 再将剪枝后的模型的状态字典打印出来
print(model.state_dict().keys())# 对模型执行剪枝remove操作.
# 通过module中的参数weight_orig和weight_mask进行剪枝, 本质上属于置零遮掩, 让权重连接失效.
# 具体怎么计算取决于_forward_pre_hooks函数.
# 这个remove是无法undo的, 也就是说一旦执行就是对模型参数的永久改变.# 打印剪枝后的模型参数
print(list(module.named_parameters()))
print('*'*50)# 打印剪枝后的模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)# 打印剪枝后的模型weight属性值
print(module.weight)
print('*'*50)# 打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)
print('*'*50)# 执行剪枝永久化操作remove
prune.remove(module, 'weight')
print('*'*50)# remove后再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)# remove后再次打印模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)# remove后再次打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)# 对模型的weight执行remove操作后, 模型参数集合中只剩下bias_orig了,
# weight_orig消失, 变成了weight, 说明针对weight的剪枝已经永久化生效.
# 对于named_buffers张量打印可以看出, 只剩下bias_mask了,
# 因为针对weight做掩码的weight_mask已经生效完毕, 不再需要保留了.
# 同理, 在_forward_pre_hooks中也只剩下针对bias做剪枝的函数了.
  1. 多参数模块剪枝
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸self.conv1 = nn.Conv2d(1, 6, 3)self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 第二种: 多参数模块的剪枝(Pruning multiple parameters).
model = LeNet().to(device=device)# 打印初始模型的所有状态字典
print(model.state_dict().keys())
print('*'*50)# 打印初始模型的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)# 对于模型进行分模块参数的剪枝
for name, module in model.named_modules():# 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝if isinstance(module, torch.nn.Conv2d):prune.l1_unstructured(module, name="weight", amount=0.2)# 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝elif isinstance(module, torch.nn.Linear):prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)# 打印多参数模块剪枝后模型的所有状态字典名称
print(model.state_dict().keys())
  1. torch.nn.utils.prune 解读:用于修剪模块参数的实用类和函数。
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48310.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全部由1组成的子矩形的数量

题目描述: 给定一个二维数组matrix,其中的值不是0就是1,返回全部由1组成的子矩阵的数量。 way: 假设我们遍历矩形的每一行,以当前遍历到的行作为地基,去看这一行的直方图(直方图介绍 ->直方…

10.3.3 QGIS点类型注释(Annotation)的应用与二次开发实现

文章目录 前言注释(Annotation)图层QGis中的注释(Annotation)图层二次开发实现线段类型注释(Annotation)点类型Item 总结 前言 介绍注释(Annotation)图层在QGis中的使用以及二次开发的实现说明:文章中的示例代码均来自开源项目qgis_cpp_api_apps 注释(Annotation)…

【Unity实战100例】Unity声音可视化多种显示效果

目录 一、技术背景 二、界面搭建 三、 实现 UIAudioVisualizer 基类 四、实现 AudioSampler 类 五、实现 IAudioSample 接口 六、实现MusicAudioVisualizer 七、实现 MicrophoneAudioManager 类 八、实现 MicrophoneAudioVisualizer 类 九、源码下载 Unity声音可视化四…

代码随想录算法训练营第九天 |LeetCode151.翻转字符串里的单词 卡码网:55.右旋转字符串

代码随想录算法训练营 Day 9 代码随想录算法训练营第九天 |LeetCode151.翻转字符串里的单词 卡码网:55.右旋转字符串 目录 代码随想录算法训练营前言LeetCode151.翻转字符串里的单词卡码网:55.右旋转字符串 一、LeetCode151.翻转字符串里的单词1.题目链…

laravel为Model设置全局作用域

如果一个项目中存在这么一个sql条件在任何情况下或大多数情况都会被使用,同时很容易被开发者遗忘,那么就非常适用于今天要提到的这个功能,Eloquent\Model的全局作用域。 首先看一个示例,有个数据表,结构如下&#xff1…

一款国外开发的高质量WordPress下载站模板主题

5play下载站是由国外站长开发的一款WordPress主题,主题简约大方,为v1.8版本, 该主题模板中包含了上千个应用,登录后台以后只需要简单的三个步骤就可以轻松发布apk文章, 我们只需要在WordPress后台中导入该主题就可以…

大模型应用如何点燃?

▎****尽管在中国,关于大模型的商业模式的讨论尚显早期,但智能体,尤其是专业智能体,蕴藏着巨大的潜力。 ChatGPT 还没有颠覆世界。 身处“第三次信息革命”,很多人被浓烈的FOMO(Fear of Missing Out&…

昇思25天学习打卡营第12天 | ResNet50图像分类

ResNet50在CIFAR-10数据集上的图像分类实践 在深入学习和实践使用ResNet50进行CIFAR-10数据集上的图像分类后,我对深度学习模型的构建、训练和优化有了更深刻的理解。本次学习经历涵盖了从理论探索到实际应用的全过程,以下是我的主要收获和反思。 1. 理…

(南京观海微电子)——电感的电路原理及应用区别

电感 电感是导线内通过交流电流时,在导线的内部及其周围产生交变磁通,导线的磁通量与生产此磁通的电流之比。 当电感中通过直流电流时,其周围只呈现固定的磁力线,不随时间而变化;可是当在线圈中通过交流电流时&am…

Jump Point Search(JPS)算法与A*算法

A* A*算法本质上讲是结合了DFS和BFS,针对当前起点先做一次BFS,再针对搜索的八个点做一次DFS BFS--广度优先算法(Breadth First Search) DFS A* 算法思想 A*的核心思想就是先进行一次BFS搜索,然后从这次BFS中找到距离…

python Requests库7种主要方法及13个控制参数(实例实验)

文章目录 一、Requests库的7种主要方法二、kwargs:控制访问的13个参数 一、Requests库的7种主要方法 序号方法说明1requests.request():提交一个request请求,作为其他请求的基础2requests.get():获取HTML网页代码的方法3requests.head()&…

基于重要抽样的主动学习不平衡分类方法ALIS

这篇论文讨论了数据分布不平衡对分类器性能造成的影响,并提出了一种新的有效解决方案 - 主动学习框架ALIS。 1、数据分布不平衡会影响分类器的学习性能。现有的方法主要集中在过采样少数类或欠采样多数类,但往往只采用单一的采样技术,无法有效解决严重的类别不平衡问题。 2、论…

9种二极管及其特点总结

二极管种类和特点 名字特点恒流二极管近些年出现,电压大于某个值,电流恒定,一般用于led普通二极管低频整流和续流,便宜,反向恢复时间us级别,PN结肖特基二极管比普通二极管反向关断更快,10ns级别…

智能硬件——0-1开发流程

文章目录 流程图1. 市场分析具体分析 2. 团队组建2. 团队组建早期团队配置建议配置一:基础型团队 (4人)配置二:扩展型团队 (6人)配置三:全面型团队 (7人) 3. 产品需求分析4. ID设计(Industrial Design, 工业设计)5. 结…

阿里云公共DNS免费版自9月30日开始限速 企业或商业场景需使用付费版

本周阿里云发布公告对公共 DNS 免费版使用政策进行调整,免费版将从 2024 年 9 月 30 日开始按照请求源 IP 进行并发数限制,单个 IP 的请求数超过 20QPS、UDP/TCP 流量超过 2000bps 将触发限速策略。 阿里云称免费版的并发数限制并非采用固定的阈值&…

Unity游戏开发入门:从安装到创建你的第一个3D场景

目录 引言 一、Unity的安装 1. 访问Unity官网 2. 下载Unity Hub 3. 安装Unity Hub并安装Unity编辑器 二、创建你的第一个项目 1. 启动Unity Hub并创建新项目 2. 熟悉Unity编辑器界面 3. 添加基本对象 4. 调整对象属性 5. 添加光源 三、运行与预览 引言 Unity&…

netty 自定义客户端连接池和channelpool

目录标题 客户端池化运行分析问题修复 客户端池化 通信完成之后,一般要关闭channel,释放内存。但是与一个服务器频繁的打开关闭浪费资源。 通过连接池,客户端和服务端之间可以创建多个 TCP 连接,提升消息的收发能力,同…

【深度学习】VGG-16原理及代码实现

1.原理及介绍 2.代码实现 2.1model.py import torch from torch import nn from torchsummary import summary import torch.nn.functional as Fclass VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.block1 nn.Sequential( # 用一个序列&#xf…

51单片机嵌入式开发:13、STC89C52RC 之 RS232与电脑通讯

STC89C52RC 之 RS232与电脑通讯 第十三节课,RS232与电脑通讯1 概述2 Uart介绍2.1 概述2.2 STC89C52UART介绍2.3 STC89C52 UART寄存器介绍2.4 STC89C52 UART操作 3 C51 UART总结 第十三节课,RS232与电脑通讯 1 概述 RS232(Recommended Stand…

Github报错:Kex_exchange_identification: Connection closed by remote host

文章目录 1. 背景介绍2. 排查和解决方案 1. 背景介绍 Github提交或者拉取代码时,报错如下: Kex_exchange_identification: Connection closed by remote host fatal: Could not read from remote repository.Please make sure you have the correct ac…