对 MODNet 网络结构直接剪枝的探索

文章目录

  • 1 写在前面
  • 2 遇到问题
  • 3 解决方案
  • 4 探索过程
    • 4.1 方案一
    • 4.2 方案二
    • 4.3 方案三
  • 5 疑惑与思考
    • 5.1 Q1
    • 5.2 Q2

1 写在前面

在前面的文章中,笔者与小伙伴们分享了对 MODNet 主干网络部分以及其余分支分别剪枝的探索历程,即先分解、再处理、后融合的手法。然而,

马克思曾深刻地思考社会的全面性,强调社会是一个复杂而相互关联的整体。在他的《资本论》中,他突显了社会结构的总体性,认为理解社会现象需要考虑各个层面的相互作用。正如马克思强调的那样,我们不能孤立地看待社会中的一部分,而是应该以全面、整体的视角去考虑其内在关系和矛盾。

在这个思想的启发下,笔者将 MODNet 作为一个整体,直接做剪枝处理。同样,我们还是借助 NNI 工具库实现模型剪枝与加速。

2 遇到问题

在这里插入图片描述

由于 MODNet 中计算维度较大,超出了PyCharm默认的内存限制。

3 解决方案

  1. 修改 PyCharm 最大内存限制;
  2. 降低输入图像的分辨率;
  3. 对 MODNet 中的少部分层进行剪枝;

4 探索过程

4.1 方案一

打开 PyCharm 中的 VM options:
在这里插入图片描述
修改 XMX,即运用程序运行时的可用内存大小,本机的运行内存为12G,这里先对半设置 6002M。然而,运行后依旧提示内存不足,当设置为7000时,在运行中,PyCharm 自动关闭。因此,该方法无效,而且属于高危行为!

💥注意:不能绝对设置大小,考虑到计算机除 PyCharm 以外也有其他进程占用内存,因此,设置上限时需要综合考虑计算机的状况。

4.2 方案二

将 MODNet 输入尺寸从 512 降低为 256,成功剪枝!

由于使用 DataParallel 加载以后,在 GPU 上剪枝会提示显存不足!因此,笔者先将 MODNet 加载到 CPU 上,取消 DataParallel 加载,可以完成剪枝与模型加速!!

在这里插入图片描述

实际上,尽管模型是通过多卡训练保存得到,在使用 DataParallel 加载后,也可以直接转换到 CPU 上:

from src.models import modnet
import torchmodel = modnet.MODNet(backbone_pretrained=False)
model = torch.nn.DataParallel(model).to('cpu')pretrained_ckpt = torch.load('modnet_photographic_portrait_matting.ckpt')
model.load_state_dict(pretrained_ckpt, strict=False)print(next(model.parameters()).device)  # CPU

然而,剪枝出现了问题,即 module 的参数必须是在 CUDA 上操作,在 CPU 上无效:

在这里插入图片描述

该问题表明:利用to.(‘cpu’)的方式从 CUDA 转移到 CPU 本身是没问题的,但这不被 module 接受。换句话说,在 MODNet 模型外面,包裹着 module 模块。因此,如果要在 CPU 上完成剪枝,module 是首要解决的问题!


打印结构发现:

在这里插入图片描述

和原先进行对比:

在这里插入图片描述

多了 module 模块,这是一个值得思考的地方!🙄

再次了解 torch.nn.DataParallel():

在使用多卡训练时,该函数能够将 input 数据划分,进而送进不同的卡上训练;而模型的 module 会复制到不同的卡上。换句话说,具有相同 module 的不同卡会处理划分到的数据,当然,这是 forward 部分。而在 backpropagation 部分,不同卡的梯度会累加到原始的 module 上,被 cuda:0 计算。当训练完成,保存时也会采用model.module.state_dict(),而非单卡训练时的model.state_dict();在将参数加载时,结构中也必然存在module。


因此,这也就对 NNI 中看似排除某些层,实际上没有排除解释通了:原先并未在 torch.nn.DataParallel() 加载后观察结构情况,同时也因为GPU显存不够直接将模型转到了 CPU 上,导致在剪枝的 config_list 中没有指定正确的参数名。

用 NNI 输出 flops 进行对比:

在这里插入图片描述

在这里插入图片描述


结构中的 module 模块如何处理,笔者考虑了两种方案:

  1. 修改 NNI 的config_list 为 module. 进行剪枝;
  2. 去除 module;

由于第一种方案会涉及到上述提及的问题:从 CUDA 转到 CPU 不会被 module 接受。因此我们选择方案2。

加载的 ckpt 类型为 dict,因此,通过 items 获得 key 以及 value 后可以通过 replace 替换,如下:

model = modnet.MODNet(backbone_pretrained=False)
pretrained_ckpt = 'modnet_photographic_portrait_matting.ckpt'
# model = torch.nn.DataParallel(model)
model.load_state_dict({k.replace('module.', ''):v for k, v in torch.load(pretrained_ckpt).items()})print(model)
print(list(model.named_parameters()))

在这里插入图片描述

此时,module 模块成功去除,且惊奇发现,即使不通过 DataParallel 加载 model,每次打印得到的权重参数也保持一致了!至此,module 问题解决,在 CPU 上的剪枝也就顺其自然了~~

4.3 方案三

不论如何排除某些层,从 NNI 在控制台输出的 info 来看,每次都会从头到尾计算每一层的信息,并进行更新。此外,尽管排除了某些层,由于上一层的通道数变化会影响下一层的变化,因此还是会进行计算。

所以,通过排除某些层,或者是指定某些层进行剪枝的作法,就解决内存限制问题而言,并不合理!

5 疑惑与思考

5.1 Q1

在剪枝过后,模型精度相应也会降低,为了能恢复到原来的精度,或者是达到可接受的精度,应当进行微调(fine-tune),其中需要调整训练超参数,包括 epoch、learning rate、momentum 等;

🚩说明:剪枝模型重训练不光可以采用微调,也可从头训练。这一观点在《Rethinking the value of network pruning》一文中表明。另外,文章中通过实验说明了从头训练的效果优于微调,原因是模型剪枝后重要的是紧凑的结构,而不是原来的那些重要的权重。


笔者以 LeNet 为例做了一个剪枝后微调的实验,对比如下:

原模型剪枝后微调后
Accuracy91%85%94%

为什么剪枝模型在 fine-tune 后精度更高?

有这样一种解释,大模型提供了一个包含最优解的大的解空间,而对于小模型来说,这样的解空间较小,因此更容易找到 optimal solution,故精度会比大模型更高。


但是,有些时候或许 fine-tune 后精度还是很差,这是为什么?

注意,这里笔者思考的事为啥很差,而不是差,如果下降比例不大的话,其实是因为剪枝本身导致的信息损失,这和评价指标有关,是很难避免的。

  1. 过度剪枝: 剪枝了太多的参数,导致模型容量不足,无法捕获数据中的复杂模式,以至于模型欠拟合,表现较差。
  2. 超参数调整不当: 超参数的调整本身并不容易,不同的任务和数据集需要不同的超参数配置。笔者在调参时遇到不同的随机种子可以带来10%准确率差异!所以,模型可能无法收敛到良好的解决方案。
  3. 数据不足: 如果 fine-tuning 阶段的训练数据量不足,模型可能无法充分学习新任务的特征,导致性能下降。(这里又有一个问题,大模型小数据训练会如何?也不好训练,因为大模型可能会过于复杂,难以泛化到新的数据)
  4. 任务不适合: 原始模型可能不适合进行剪枝和 fine-tuning 的任务,某些任务可能需要更大的模型,或者可以考虑使用从头训练。(但一般fine-tune和从头训练不会差异巨大)

另外,笔者也抛出一个有意思的问题:剪枝时,我们通常会采用某一种评价准则去衡量权重的重要程度,如果重要程度低,我们认为是冗余权重,所以去除;而重要权重我们选择保留,那么,冗余权重真的冗余吗?是否也会对整个模型的评估起“支撑”作用?但笔者再想想,既然都是冗余了,那为啥还要保留?说明一定是对模型不再起作用的。所以问题就回到了这个评价准则,到底如何去设计评价准则,一直是模型剪枝的一个挑战。

5.2 Q2

在只保存剪枝后模型参数的情况下,需要修改相应的网络结构才能将参数填入结构中,那为何不可参考剪枝后的结构修改原先的结构,并进行训练?

解释:因为大模型要比小模型好训练

  1. 更多的参数: 这意味着更大的模型可以学习更复杂的特征和模式,更多的参数允许模型更好地适应训练数据,捕获更多的细节和复杂性。
  2. 更好的表示能力: 大模型有更大的容量来表示数据的复杂关系,能够学习更抽象、更深层次的特征,使其在处理复杂任务时更为有效。
  3. 更好的泛化能力: 大模型在训练中可以学到更多的信息,从而提高了其在未见数据上的泛化能力,这意味着它们更有可能在面对新的、未知的数据时表现良好。

当然,由于训练大模型需要的训练时间长,计算资源和内存消耗也更大,所以需要根据实际的情况找到 trade-off

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/646498.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++:缺省参数函数重载

目录 C/C语言 函数调用的工作原理: 函数调用一般分为两个部分: 缺省参数: 缺省参数的分类: 全缺省参数 半缺省参数 注意事项: 缺省参数与C语言的调用参数对比: 函数重载: 函数重载…

pve8.1 安装、创建centos7虚拟机及配置

之前创建虚拟机centos7时,硬盘分配太大了,做成模板后无法进行修改了,安装完pve8.1后,强迫症犯了重新创建一下顺便记录一下配置过程。由于目前centos7还是生产用的比较多的版本所以本次还是安装centos7.9版本。 一、下载镜像 下载…

利用Redis List实现数据库分页快速查询的有效方法

目录 引言 传统数据库分页查询的挑战 Redis List的优势 利用Redis List实现分页查询 1. 数据准备 2. 分页查询 3. 分页缓存 4. 分页处理 结论 引言 随着Web应用程序的发展和用户数量的增加,数据库分页查询变得越来越常见。分页查询允许用户在大型数据集中…

JVM/GC复习

JVM/GC JVM(java虚拟机)MATjstack(将正在运行的JVM的线程进行快照并且打印出来)死锁VisualVM工具(监控线程内存使用情况)JMX GC垃圾回收算法1.引用计数法2.标记清除发3.标记压缩算法4.复制算法5.分代算法 收集器1.串行垃圾收集器2.并行垃圾收集器2.CMS垃圾收集器 3.G1垃圾收集器…

营销一体化平台如何助力企业增长?3个案例深度解析

无论大家怎么想,反对和批评的声音有多大,还是有很多企业从组织层面为CMO下了很多需要及时转化的KPI要求。 原因无外乎是增长乏力。再加上外部环境处在产业升级换代、科技革命在即的当口,企业比以往任何时候都意识到营销变革的重要性。 然而…

两相步进电机驱动原理

两相步进电机驱动 前言什么是步进电机驱动器细分控制电机内部结构图片步进电机驱动原理(重要)步进电机参数1、步距角:收到一个脉冲转动的角度2、细分数 :1/2&#xff0c…

清华大学对港澳台华侨生新增额外招生项目来啦

导读 众所周知的是,港澳台和华侨生录取清华大学和北京大学,除了港澳台联考,DSE申请等形式之外,那只有和普通内地高中生混在一起的录取方式。但是其实近些年来,清华大学也为尖子生开辟了新的录取方式,我们一…

Qt Quick程序的发布|Qt5中QML和Qt Quick 的更改

# Quick程序的发布旧版做法 # Qt5中QML和Qt Quick 的更改 1.QML语言的更改(Qt4->Qt5) 在QML语言中,只有少量更改会影响QML代码的迁移:无法直接导入单独的文件(例如:import"MyType.qml”),需要导人该文件所在的目录; JavaScript文件中的相对路径被解析…

线性代数:矩阵的定义

目录 一、定义 二、方阵 三、对角阵 四、单位阵 五、数量阵 六、行(列)矩阵 七、同型矩阵 八、矩阵相等 九、零矩阵 十、方阵的行列式 一、定义 二、方阵 三、对角阵 四、单位阵 五、数量阵 六、行(列)矩阵 七、同型矩…

手写一个图形验证码

文章目录 需求分析 需求 使用 JS 写一个验证码&#xff0c;并在前端进行校验 分析 新建文件 VueImageVerify.vue <template><div class"img-verify"><canvas ref"verify" :width"state.width" :height"state.height&qu…

河南嘉家购商贸有限公司获绿色积分信用认证

“实现绿色产业、打造完善的绿色产业链、走可持续发展共创共赢”。近日&#xff0c;河南嘉家购商贸有限公司获得绿色积分认证&#xff0c;确认了该企业在绿色消费积分领域的领先地位。 据了解&#xff0c;河南嘉家购商贸有限公司始终将绿色积分视为企业发展的核心要素。全面优化…

如何实现无公网ip远程访问本地websocket服务端【内网穿透】

文章目录 1. Java 服务端demo环境2. 在pom文件引入第三包封装的netty框架maven坐标3. 创建服务端,以接口模式调用,方便外部调用4. 启动服务,出现以下信息表示启动成功,暴露端口默认99995. 创建隧道映射内网端口6. 查看状态->在线隧道,复制所创建隧道的公网地址加端口号7. 以…

G1与ZGC

G1垃圾收集器(-XX:UseG1GC)详解 G1(Garbage-First)是一款面向服务器的垃圾收集器&#xff0c;主要针对配备多颗处理器及大容量内存的机器。以极高概率满足GC停顿时间要求的同时&#xff0c;还具备高吞吐量性能特性。 G1把内存区域划分为小格子(Region)&#xff0c;最多可以有2…

java常见的面试问题

目录 一、异常 1、 throw 和 throws 的区别&#xff1f; 2、 final、finally、finalize 有什么区别&#xff1f; 3、try-catch-finally 中哪个部分可以省略&#xff1f; 4、try-catch-finally 中&#xff0c;如果 catch 中 return 了&#xff0c;finally 还会执行吗&#…

大创项目推荐 题目:垃圾邮件(短信)分类 算法实现 机器学习 深度学习 开题

文章目录 1 前言2 垃圾短信/邮件 分类算法 原理2.1 常用的分类器 - 贝叶斯分类器 3 数据集介绍4 数据预处理5 特征提取6 训练分类器7 综合测试结果8 其他模型方法9 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于机器学习的垃圾邮件分类 该项目…

DP读书:在常工院的2023年度总结

DarrenPig的年度总结 这是最好的时代&#xff0c;这是最坏的时代。——狄更斯 这是最好的时代&#xff0c;这是最坏的时代。——狄更斯 这是最好的时代&#xff0c;这是最坏的时代。——狄更斯 一、2023我的感受 不就是2023吗&#xff0c;不就是一年的经历吗&#xff0c;大家…

Spring Boot 集成 API 文档 - Swagger、Knife4J、Smart-Doc

文章目录 1.OpenAPI 规范2.Swagger: 接口管理的利器3.Swagger 与 SpringFox&#xff1a;理念与实现4.Swagger 与 Knife4J&#xff1a;增强与创新5.案例&#xff1a;Spring Boot 整合 Swagger35.1 引入 Swagger3 依赖包5.2 优化路径匹配策略兼容 SpringFox5.3 配置 Swagger5.4 S…

硅像素传感器文献调研(九)3

欧洲X射线自由电子激光器抗辐射像素传感器的设计和初步试验 摘要 目前正在汉堡建造的欧洲X射线自由电子激光器的高强度和高重复率需要硅传感器&#xff0c;该传感器可以在高偏置电压下工作3年&#xff0c;承受高达1 GGy的X射线剂量。在AGIPD合作范围内&#xff0c;研究了由四家…

# [NOI2019] 斗主地 洛谷黑题题解

[NOI2019] 斗主地 题目背景 时限 4 秒 内存 512MB 题目描述 小 S 在和小 F 玩一个叫“斗地主”的游戏。 可怜的小 S 发现自己打牌并打不过小 F&#xff0c;所以他想要在洗牌环节动动手脚。 一副牌一共有 n n n 张牌&#xff0c;从上到下依次标号为 1 ∼ n 1 \sim n 1∼…

如何在Shopee平台上进行选品 :充分利用渠道获取灵感和数据支持

在Shopee平台上进行选品是一个关键的决策过程&#xff0c;它直接影响到卖家的销售业绩和店铺的发展。为了帮助卖家更好地进行选品&#xff0c;Shopee提供了多种渠道来获取灵感和数据支持。下面将介绍一些主要的选品渠道以及如何利用它们来进行选品。 先给大家推荐一款shopee知…