【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 前言
  • 卷积层剪枝
  • 总结


前言

深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
“假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。


卷积层剪枝

可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:

  1. 初始化卷积层,并查看卷积层权重
    # 示例使用一个具有3个输入通道和5个输出通道的卷积层
    conv = nn.Conv2d(3, 5, 3)
    print("原始卷积层权重:")
    print(conv.weight.data)
    print(conv.weight.size())
    print("原始卷积层偏置:")
    print(conv.bias.data)
    print(conv.bias.size())
    
  2. 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
    # remove_zero_kernels方法内的代码
    weight = conv_layer.weight.data
    # 卷积核个数
    num_kernels = weight.size(0)
    # 随机对部分卷积置0
    pruned = torch.ones(num_kernels, 1, 1, 1)
    # 选择随着置0的卷积序号
    random_int = random.randint(1, num_kernels-1)
    for i in range(random_int):pruned[i, 0, 0, 0] = 0
    conv_layer.weight.data = weight * pruned
    weight = conv_layer.weight.data
    bias = conv_layer.bias.data
    
  3. 保存未被剪枝的卷积核的权重和偏置
    # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
    norms = torch.norm(weight.view(num_kernels, -1), dim=1)
    zero_kernel_indices = torch.nonzero(norms==0).squeeze()
    print(zero_kernel_indices)
    # 移除L2范数为零的卷积核
    new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
    new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
    
  4. 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
    # 构建新的卷积层
    if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_bias
    else:new_conv_layer = conv_layer
    

完整代码

import torch
import torch.nn as nn
import randomdef remove_zero_kernels(conv_layer):# 卷积核权重weight = conv_layer.weight.data# 卷积核个数num_kernels = weight.size(0)# 随机对部分卷积置0pruned = torch.ones(num_kernels, 1, 1, 1)# 选择随着置0的卷积序号random_int = random.randint(1, num_kernels-1)for i in range(random_int):pruned[i, 0, 0, 0] = 0conv_layer.weight.data = weight * prunedweight = conv_layer.weight.databias = conv_layer.bias.data# 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了norms = torch.norm(weight.view(num_kernels, -1), dim=1)zero_kernel_indices = torch.nonzero(norms==0).squeeze()print(zero_kernel_indices)# 移除L2范数为零的卷积核new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])# 构建新的卷积层if zero_kernel_indices.numel() > 0:# 输入channelin_channels = weight.size(1)# 输出channelout_channels = new_weight.size(0)# 卷积核大小kernel_size = weight.size(2)# 步长stride = conv_layer.stridepadding = conv_layer.paddingdilation = conv_layer.dilationgroups = conv_layer.groupsnew_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)new_conv_layer.weight.data = new_weightnew_conv_layer.bias.data = new_biaselse:new_conv_layer = conv_layerreturn new_conv_layer# 示例使用一个具有3个输入通道和5个输出通道的卷积层
conv = nn.Conv2d(3, 5, 3)
# print("原始卷积层权重:")
# print(conv.weight.data)
# print(conv.weight.size())
# print("原始卷积层偏置:")
# print(conv.bias.data)
# print(conv.bias.size())# 将置零的卷积核移除
new_conv = remove_zero_kernels(conv)
# print("原始卷积层权重:")
# print(new_conv.weight.data)
# print(new_conv.weight.size())
# print("原始卷积层偏置:")
# print(new_conv.bias.data)
# print(new_conv.bias.size())

总结

博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/129744.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

thinkphp的路径参数(RESTFul风格),把参数写在路径里

thinkphp官方文档 https://www.kancloud.cn/manual/thinkphp5_1/353969 有一个Blog控制器,里面的read方法是固定的,不能该 route.php里添加如下代码,访问 blog对应的就是 android/blog Route::resource(blog,android/blog);然后访问路径

设计模式——模板方法模式(Template Pattern)+ Spring相关源码

文章目录 一、模板方法模式定义二、例子2.1 菜鸟教程例子2.1.1 抽象类Game 定义了play方法的执行步骤。2.1.2 继承Game类并实现initialize、startPlay、endPlay方法。2.1.3 使用 2.2 JDK源码 —— Map 2.3 Spring源码 —— JdbcTemplate2.4 Spring源码 —— RestTemplate三、其…

NLP之Bert实现文本分类

文章目录 1. 代码展示2. 整体流程介绍3. 代码解读4. 报错解决4.1 解决思路4.2 解决方法 5. Bert介绍5.1 什么是BertBERT简介:BERT的核心思想:BERT的预训练策略:BERT的应用:为什么BERT如此受欢迎?总结: 1. 代…

windows使用YOLOv8训练自己的模型(0基础保姆级教学)

目录 前言 一、使用labelimg制作数据集 1.1、下载labelimg 1.2、安装库并启动labelimg 1.4、制作YOLO数据集 二、使用YOLOv8训练模型 2.1、下载库——ultralytics (记得换源) 2.2、数据模板下载 2.3、开始训练 1、启动train.py,进行…

QT+SQLite数据库配置和使用

一、简介 1.1 SQLite(sql)是一款开源轻量级的数据库软件,不需要server,可以集成在其他软件中,非常适合嵌入式系统。Qt5以上版本可以直接使用SQLite(Qt自带驱动)。 二、下载和配置 2.1 SQLite下载…

GitLab(2)——Docker方式安装Gitlab

目录 一、前言 二、安装Gitlab 1. 搜索gitlab-ce镜像 2. 下载镜像 3. 查看镜像 4. 提前创建挂载数据卷 5. 运行镜像 三、配置Gitlab文件 1. 配置容器中的/etc/gitlab/gitlab.rb文件 2. 重启容器 3. 登录Gitalb ① 查看初始root用户的密码 ② 访问gitlab地址&#…

微信小程序-form表单-获取用户输入文本框的值

微信小程序-form表单-获取用户输入文本框的值 data: {userName: ,userPwd:""},//获取用户输入的用户名 userNameInput:function(e) {this.setData({userName: e.detail.value}) }, passWdInput:function(e) {this.setData({userPwd: e.detail.value}) }, //获取用户输…

Pycharm出现的一些问题和解决办法

1.每次启动打开多个项目,速度很慢。改为每次启动询问打开哪个单一项目 Setting -> Appearance & Behavior -> System Settings -> Project -> 关闭Reopen projects on startop 2.一直显示《正在关闭项目closing project》,关不上 pycha…

Java后端开发——JDBC组件

JDBC(Java Database Connectivity)是Java SE平台的一种标准API,它提供了一种标准的方法来访问关系型数据库,使得Java程序能够与各种不同的数据库进行交互,这篇文章我们来进行实验体验一下。 自定义JDBC连接工具类 1.编…

【IDEA使用maven package时,出现依赖不存在以及无法从仓库获取本地依赖的问题】

Install Parent project C:\Users\lxh\.jdks\corretto-1.8.0_362\bin\java.exe -Dmaven.multiModuleProjectDirectoryD:\学习\projectFile\study\study_example_service "-Dmaven.homeD:\Program Files\JetBrains\IntelliJ IDEA2021\plugins\maven\lib\maven3" "…

devops

git/jenkins 版本控制系统 gitlab 代码开发完成> 运维部署上线> 监控性能> cicd流水线部署, git和版本控制系统 git rootserver02:~/1103# git log --oneline 07f230c (HEAD -> main) first commit rootserver02:~/1103# git log -p 07f230c co…

LLC讲解

【精选】开关电源-LLC基本原理_llc 开关电源-CSDN博客

操作系统的线程模型

操作系统的线程调度有几个重要的概念: 调度器(Thread Scheduler):内核通过操纵调度器对内核线程进行调度,并负责将线程的任务映射到各个处理器上内核线程(Kernel Level Thread):简称…

【GitLab CI/CD、SpringBoot、Docker】GitLab CI/CD 部署SpringBoot应用,部署方式Docker

介绍 本文件主要介绍如何将SpringBoot应用使用Docker方式部署,并用Gitlab CI/CD进行构建和部署。 环境准备 已安装Gitlab仓库已安装Gitlab Runner,并已注册到Gitlab和已实现基础的CI/CD使用创建Docker Hub仓库,教程中使用的是阿里云的Docker…

100量子比特启动实用化算力标准!玻色量子重磅发布相干光量子计算机

2023年5月16日,北京玻色量子科技有限公司(以下简称“玻色量子”)在北京正大中心成功召开了2023年首场新品发布会,重磅发布了自研100量子比特相干光量子计算机——“天工量子大脑”。 就在3个月前,因“天工量子大脑”在…

java pdf,word,ppt转图片

pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0…

【Orangepi Zero2 全志H616】资料及环境搭建

一、资料文档 二、MobaXterm远程连接工具 三、修改登录密码 四、修改内核日志等级 五、配置网络 六、SSH 访问 OrangePi ZERO 2 七、配置 vim 八、基于官方外设开发SDK 一、资料文档 官网资料下载 GitHub&#xff1a;新版本的 orangepi-build 源码 环境搭建&#xff1a;新手配…

vue+asp.net Web api前后端分离项目发布部署

一、前后端项目介绍 1.前端项目是使用vue脚手架进行创建的。 脚手架版本&#xff1a;vue/cli 5.0.8 编译器版本&#xff1a;vs code 1.82.2 2.后端是一个asp.net Core Web API 项目 后端框架版本&#xff1a;.NET 6.0 编译器版本&#xff1a;vs 2022 二、发布部署步骤 第…

安卓抓包之小黄鸟

下载安装 下载地址: https://download.csdn.net/download/yijianxiangde100/88496463 安装apk 即可。 证书配置:

【嵌入式】HC32F07X CAN通讯配置和使用配置不同缓冲器以连续发送

一 背景说明 使用小华&#xff08;华大&#xff09;的MCU HC32F07X实现 CAN 通讯配置和使用 二 原理分析 【1】CAN原理说明&#xff08;参考文章《CAN通信详解》&#xff09;&#xff1a; CAN是控制器局域网络(Controller Area Network, CAN)的简称&#xff0c;是一种能够实现…