【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)

文章目录

      • 0. 前言
      • 1. 剪枝`prune`主要功能分类
      • 2. `torch.nn.utils.prune`中的方法介绍
      • 3. PyTorch实例
        • 3.1 `BasePruningMethod`
        • 3.2`PruningContainer`
        • 3.3 `identity`
        • 3.4`random_unstructured`
        • 3.5`l1_unstructured`
      • 4. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

PyTorch中的torch.nn.utils.prune模块是一个专门用于神经网络模型剪枝的工具集。模型剪枝是一种减少神经网络参数数量的技术,其目标是在保持模型性能的同时减少计算成本内存占用。这对于部署模型到资源受限的设备(如移动设备或嵌入式系统)特别有用。

本文将通过实例介绍torch.nn.utils.prune模块中的各个方法,由于内容较多分为上、下两篇。本篇主要介绍非结构化剪枝。

1. 剪枝prune主要功能分类

torch.nn.utils.prune模块提供了一系列的剪枝方法,包括但不限于:

  1. 无结构剪枝:这种剪枝方法可以独立地移除网络中的权重,而不考虑权重之间的结构关系。例如,L1UnstructuredRandomUnstructured 就是两种无结构剪枝方法,它们分别根据权重的绝对值大小和随机选择的方式移除权重。

  2. 结构化剪枝:与无结构剪枝相反,结构化剪枝会移除整个的结构单位(如整个神经元或通道),而不是单独的权重。RandomStructuredLnStructured 就是这样的例子,它们可以移除整个的通道。

  3. 自定义剪枝CustomFromMask 方法允许用户自定义剪枝策略,通过提供一个掩码来指定哪些权重应该被保留或移除。

  4. 剪枝管理:除了剪枝方法本身,torch.nn.utils.prune还提供了工具来管理和应用剪枝,例如,prune.global_unstructuredprune.remove 方法。前者允许跨多个层执行全局剪枝,而后者则用于移除剪枝操作,恢复原始权重或应用剪枝掩码。

2. torch.nn.utils.prune中的方法介绍

下面是本文将介绍的torch.nn.utils.prune中的方法:

  • BasePruningMethod: 抽象基类,用于创建新的剪枝类。
  • PruningContainer:允许组合多种不同的剪枝策略,并按顺序应用这些策略。
  • identity: 实现了一个不剪枝任何单元仅生成一个全为一的掩码的实用剪枝方法。
  • random_unstructured: 随机剪枝张量中的单元。
  • l1_unstructured: 根据L1范数(绝对值)剪枝张量中的单元。

3. PyTorch实例

为了介绍这些剪枝方法,我们将首先定义一个简单的模型,并使用torch.nn.utils.prune模块中的各种剪枝方法来处理这个模型的权重。我们将以一个简单的卷积层为例,然后应用上述提到的每种剪枝方法。

首先,让我们导入必要的库并定义一个包含单个卷积层的模型:

import torch
import torch.nn as nn
from torch.nn.utils import prunetorch.manual_seed(888)
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()# 创建一个简单的卷积层self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)model = SimpleModel()

这里有一个值得注意的地方就是prune的导入:如果不写from torch.nn.utils import prune,而直接在代码中使用torch.nn.utils.prune.xxxx(),会报错↓
在这里插入图片描述
这个报错我不太能理解,不知道会不会在后续版本中更正。

接下来,我们将逐一介绍并应用每种剪枝方法:

3.1 BasePruningMethod

这是一个抽象类,可以理解为自定义剪枝的类。

class BasePruningMethod(ABC):r"""Abstract base class for creation of new pruning techniques.Provides a skeleton for customization requiring the overriding of methodssuch as :meth:`compute_mask` and :meth:`apply`."""
3.2PruningContainer

一开始我觉得这个方法和nn.Sequential差不多,但是实际并不是!

PruningContainer通常不会直接由用户实例化,而是作为torch.nn.utils.prune中其他剪枝方法的基础。当调用如l1_unstructuredrandom_unstructuredln_structured等剪枝方法时,内部会创建一个PruningContainer实例,并且将特定的剪枝方法添加到容器中。

3.3 identity

这个方法不会剪枝(改变)任何权重,它只会生成一个全为1的掩码。

print("Weight before Identity pruning:")
print(model.conv.weight)
prune.identity(model.conv, name="weight")
print("Weight after Identity pruning:")
print(model.conv.weight)
print("mask:")
print(model.conv.weight_mask)

输出为:

Weight before Identity pruning:
Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]]], requires_grad=True)
Weight after Identity pruning:
tensor([[[[-0.3017,  0.1290, -0.2468],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.1403]]]], grad_fn=<MulBackward0>)
mask:
tensor([[[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]]])
3.4random_unstructured

这个方法会随机选择权重值进行剪枝:

prune.random_unstructured(model.conv, name="weight", amount=0.5) 
#amount参数指定的是要被剪枝(即置零)的权重比例。
print("Weight after RandomUnstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after RandomUnstructured pruning (50%):
tensor([[[[-0.0000,  0.0000, -0.0000],[ 0.2107,  0.1799,  0.1923],[ 0.1887, -0.0527,  0.0000]]]], grad_fn=<MulBackward0>)

可以明显看出,对比3.3节的输出结果,有4个(50%)参数被剪枝(置零)了。

3.5l1_unstructured

这个方法会根据权重的L1范数选择要剪枝的权重。

prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after L1Unstructured pruning (50%):
tensor([[[[-0.3017,  0.0000, -0.2468],[ 0.2107,  0.0000,  0.1923],[ 0.1887, -0.0000,  0.0000]]]], grad_fn=<MulBackward0>)Process finished with exit code 0

对比3.3输出的结果,可以看出L1范数(绝对值)最小的4个(50%)参数被剪枝(置零)了。

4. 总结

本文介绍了PyTorch中的prune模型剪枝模块的中的非结构化剪枝,下一篇将介绍结构化剪枝。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/860106.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI办公自动化:免费批量将英语电子书转成有声书

Edge-TTS是由微软推出的文本转语音Python库&#xff0c;通过微软Azure Cognitive Services转化文本为自然语音。可以作为付费文本转语音TTS服务的替代品&#xff0c;Edge-TTS支持40多种语言和300种声音&#xff0c;提供优质的语音输出 。 edge-tts支持英语、汉语、日语、韩语、…

基于Netron库的PyTorch 2.0模型可视化

【图书推荐】《从零开始大模型开发与微调&#xff1a;基于PyTorch与ChatGLM》_《从零开始大模型开发与微调:基于pytorch与chatglm》-CSDN博客 前面章节带领读者完成了基于PyTorch 2.0的MNIST模型的设计&#xff0c;并基于此完成了MNIST手写体数字的识别。此时可能有读者对我们…

C语言结构体包含结构体

C语言结构体可以包含另一个结构体&#xff1b; 下面通过一个例子看一下&#xff1b; struct Date {int day;int month;int year; };struct Person {char *name;struct Date birthday; }; ...... void CTestView::OnDraw(CDC* pDC) {CTestDoc* pDoc GetDocument();ASSERT_VAL…

C语言 | Leetcode C语言题解之第189题轮转数组

题目&#xff1a; 题解&#xff1a; void swap(int* a, int* b) {int t *a;*a *b, *b t; }void reverse(int* nums, int start, int end) {while (start < end) {swap(&nums[start], &nums[end]);start 1;end - 1;} }void rotate(int* nums, int numsSize, int…

国内邮件推送如何避免拦截?内容优化技巧?

国内邮件推送的平台怎么选择&#xff1f;如何提高邮件推送效果&#xff1f; 邮件营销是企业与客户沟通的重要方式&#xff0c;但在国内邮件推送过程中&#xff0c;邮件被拦截的问题屡见不鲜。为了确保邮件能够顺利送达目标用户&#xff0c;AokSend将探讨一些有效的策略&#x…

【Android】实现图片和视频混合轮播(无限循环、视频自动播放)

目录 前言一、实现效果二、具体实现1. 导入依赖2. 布局3. Banner基础配置4. Banner无限循环机制5. 轮播适配器6. 视频播放处理7. 完整源码 总结 前言 我们日常的需求基本上都是图片的轮播&#xff0c;而在一些特殊需求&#xff0c;例如用于展览的的数据大屏&#xff0c;又想展…

跟着DW学习大语言模型-什么是知识库,如何构建知识库

建立一个高效的知识库对于个人和组织来说非常重要。无论是为了个人学习和成长&#xff0c;还是为了组织的持续创新和发展&#xff0c;一个完善的知识管理系统都是不可或缺的。那么&#xff0c;如何建立一个高效的知识库呢&#xff1f; 在建立知识库之前&#xff0c;首先需要确定…

第3章 小功能大用处-事务与Lua

为了保证多条命令组合的原子性&#xff0c;Redis提供了简单的事务功能以及集成Lua脚本来解决这个问题。 首先简单介绍Redis中事务的使用方法以及它的局限性&#xff0c;之后重点介绍Lua语言的基本使用方法&#xff0c;以及如何将Redis和Lua脚本进行集成&#xff0c;最后给出Red…

项目实训-vue(十三)

项目实训-vue&#xff08;十三&#xff09; 文章目录 项目实训-vue&#xff08;十三&#xff09;1.概述2.处理按钮 1.概述 本篇博客将记录我在图片上传页面中的工作。 2.处理按钮 实现了图片的上传之后&#xff0c;还需要设置具体的上传按钮。 这段代码使用 Element UI 的 …

Spring学习02-[Spring容器核心技术IOC学习]

Spring容器核心技术IOC学习 什么是bean?如何配置bean?Component方式bean配合配置类的方式import导入方式 什么是bean? 被Spring管理的对象就是bean,和普通对象的区别就是里面bean对象里面的属性也被注入了。 如何配置bean? Component方式、bean配合配置类的方式、import…

C语言 | Leetcode C语言题解之第190题颠倒二进制位

题目&#xff1a; 题解&#xff1a; const uint32_t M1 0x55555555; // 01010101010101010101010101010101 const uint32_t M2 0x33333333; // 00110011001100110011001100110011 const uint32_t M4 0x0f0f0f0f; // 00001111000011110000111100001111 const uint32_t M8…

【containerd】Containerd高阶命令行工具nerdctl

前言 对于习惯了使用docker cli的用户来说&#xff0c;containerd的命令行工具ctr使用起来不是很顺手&#xff0c;此时别慌&#xff0c;还有另外一个命令行工具项目nerdctl可供我们选择。 nerdctl是一个与docker cli风格兼容的containerd的cli工具。 nerdctl已经作为子项目加入…

秋招突击——6/24——复习{完全背包问题——买书,状态转换机——股票买卖V}——新作{两数相除,LRU缓存实现}

文章目录 引言复习完全背包问题——买书个人实现 状态转换机——股票买卖V个人实现参考实现 新作两数相除个人实现 新作LRU缓存实现个人实现unordered_map相关priority_queue相关 参考实现自己复现 总结 引言 今天知道拼多多挂掉了&#xff0c;难受&#xff0c;那实习就是颗粒无…

汪汪队短视频:成都柏煜文化传媒有限公司

汪汪队短视频&#xff1a;萌宠与冒险的交织乐章 在数字时代的浪潮中&#xff0c;短视频以其短小精悍、内容丰富的特点&#xff0c;迅速占领了人们的闲暇时光。而在这些琳琅满目的短视频中&#xff0c;有一类作品以其独特的魅力吸引了无数观众的目光&#xff0c;那就是以萌宠为…

单门户上集成多种数据库查询入口

&#xff08;作者&#xff1a;陈玓玏&#xff09; 开源项目&#xff0c;欢迎star哦&#xff0c;https://github.com/tencentmusic/cube-studio 在一家公司&#xff0c;我们通常会有多种数据库&#xff0c;每种数据库因为其特性承担不同的角色&#xff0c;比如mysql这种轻量…

AI-024人工智能指数报告(三):经济

概述 人工智能融入经济会引发许多很迷人的问题。有人预测人工智能会推动生产力得到改进&#xff0c;但其影响程度仍未确定。其中一个主要关切是大规模劳动替代的可能性——工作究竟会在多大程度上被自动化还是人工智能主要起到增强作用&#xff1f;各个行业的企业已经在用各种…

基于FPGA的温湿度检测

初始化部分就不过多赘述&#xff0c;我会给出对应的文件&#xff0c;我只说明这部分里面涉及到使用的代码部分 1、数据的读取和校验 数据的读取和检验代码如下 always (posedge clk_us)if (data_temp[7:0] data_temp[39:32] data_temp[31:24] data_temp[23:16] data_te…

SpringBoot+Vue集成富文本编辑器

1.引入 我们常常在各种网页软件中编写文档的时候&#xff0c;常常会有富文本编辑器&#xff0c;就比如csdn写博客的这个页面&#xff0c;包含了富文本编辑器&#xff0c;那么怎么实现呢&#xff1f;下面来详细的介绍&#xff01; 2.安装wangeditor插件 在Vue工程中&#xff0c;…

基于 SpringBoot + Vue 的图书购物商城项目

本项目是一个基于 SpringBoot 和 Vue 的图书购物商城系统。系统主要实现了用户注册、登录&#xff0c;图书浏览、查询、加购&#xff0c;购物车管理&#xff0c;订单结算&#xff0c;会员折扣&#xff0c;下单&#xff0c;个人订单管理&#xff0c;书籍及分类管理&#xff0c;用…

PCL 使用列文伯格-马夸尔特法计算变换矩阵

目录 一、算法原理1、计算过程2、主要函数3、参考文献二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。 一、算法原理 1、计算过程 2、主要函数 void pcl