2024强化学习的结构化剪枝模型RL-Pruner原理及实践

[2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration

目录

  • [2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration
    • 一、论文说明
    • 二、原理
    • 三、实验与分析
      • 1、环境配置
        • 在Windows配置git bash链接conda环境
      • 2、项目代码运行
        • 1、训练预训练权重
        • 2、模型压缩
        • 3、模型验证
    • 四、总结

一、论文说明

论文标题:使用强化学习进行结构化剪枝用于卷积神经网路压缩和加速

机构:伊利诺伊大学厄巴纳-香槟分校

论文链接:https://arxiv.org/pdf/2411.06463

代码链接:https://github.com/Beryex/RLPruner-CNN

论文简介: 卷积神经网络(ConvolutionalNeural Networks, CNNs)近年来表现出卓越的性能。压缩这些模型不仅减少了存储需求,使其在边缘设备上的部署变得可行,还加速了推理,从而降低了延迟和计算成本。结构化剪枝,它在层级上去除过滤器,直接修改了模型架构。这种方法实现了更紧凑的架构,同时保持目标准确性,确保压缩模型具有较好的兼容性和硬件效率。所提方法基于一个关键观察:

  • 1、神经网络中不同层的过滤器对模型性能的重要性各不相同。

  • 2、当修剪的过滤器数量固定时,不同层之间的最佳修剪分配是不均匀的,以最小化性能损失

  • 3、对修剪敏感的层应该占据更小的修剪分配比例。

为了利用这一洞察,文中提出了RL-Pruner,它使用强化学习来学习最佳修剪分配。RL-Pruner可以自动提取输入模型中过滤器之间的依赖关系并执行修剪,无需特定于模型的修剪实现。在GoogleNet、ResNet和MobileNet 等模型上进行了实验,将所提方法与其他结构化剪枝方法进行了比较,以验证其有效性。

在这里插入图片描述

二、原理

RL-Pruner 首先在模型中的层之间构建依赖图,然后分几个步骤进行剪枝。在每个步骤中:1) 基于基础分布生成一个新的剪枝稀疏分布作为动作 ,这作为策略;2)根据相应的稀疏度,使用泰勒准则(Taylorcriterion)对每一层进行剪枝;3) 评估压缩后的模型以获得奖励,并将动作和奖励存储在回放(replay)经验池中。每个步骤后,基础分布根据经验池更新,如果计算资源足够,则对压缩模型应用后训练阶段,使用知识蒸馏(knowledge distillation),其中原始模型作为教师。具体框图如图2所示。

三、实验与分析

1、环境配置

实验平台及软件

  • Windows 10
  • git bash
  • conda环境

这里主要介绍如何在windows系统上让git bash链接conda环境。

在Windows配置git bash链接conda环境

由于工程代码中需要使用bash命令运行代码,因此需要保证git bash能调用conda环境运行对应的脚本文件。

C:\Users\username\.bashrc文件内设置conda.sh位置(文中示例为:D:\\Anaconda3\\etc\\profile.d\\conda.sh),并激活配置。在git bash界面输入具体命令如下:

echo "D:\\Anaconda3\\etc\\profile.d\\conda.sh" >> ~/.bashrc  
source ~/.bashrc

然后关闭git bash界面,再重新打开一个git bash界面,最后输入命令激活conda环境conda activate 虚拟环境名字。如果命令提示中出现如下图所示的字样,即为配置成功,否则根据提示的要求进行配置,比如输入conda init,重新打开一个新的git bash界面。

在这里插入图片描述

2、项目代码运行

克隆项目文件,具体命令如下:

git clone https://github.com/Beryex/RLPruner-CNN.git --depth 1
cd RLPruner-CNN

安装python第三方包,具体命令如下(如果之前有conda环境,可以不用进行下面这一步,等报错了再根据提示安装对应的包即可):

conda create -n RLPruner python=3.10 -y
conda activate RLPruner
pip install -r requirements.txt

官方代码提供了一步到位的运行脚本,从预训练模型、模型压缩到模型验证,仅需在命令行中输入如下代码:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

为了更好地了解每一步的设置,下面内容将分为预训练模型、模型压缩、模型验证三个步骤进行介绍。

1、训练预训练权重

训练模型得到对应的预训练权重,这里以resnet32googlenet为例,在git bash输入具体命令(默认使用cuda)如下:

./scripts/train.sh googlenet cifar100
./scripts/train.sh resnet32 cifar100

或者使用参考指定配置命令:

python -m train --model ${MODEL} --dataset ${DATASET} --device cuda \--output_dir ${PRETRAINED_MODEL_DIR} \--log_dir ${LOG}

其中,
${MODEL}代表backbone的类型([“vgg11”, “vgg13”, “vgg16”, “vgg19”, “resnet18”, “resnet34”, “resnet50”, “resnet101”, “resnet152”, “resnet8”, “resnet14”, “resnet20”, “resnet32”, “resnet44”, “resnet56”, “resnet110”, “densenet121”, “densenet161”, “densenet169”, “densenet201”, “mobilenetv3_small”, “mobilenetv3_large”, “googlenet”]);
${DATASET}代表数据集名称,如cifar10或者cifar100。
${PRETRAINED_MODEL_DIR}代表输出权重文件路径,默认在pretrained_model文件夹下;
${LOG}代表输出日志路径,默认在log文件夹下。

在CIFAR100数据集上训练resnet32的结果(最佳准确率:0.706)如下图所示。
在这里插入图片描述
在CIFAR100数据集上训练googlenet的结果(最佳准确率:0.774)如下图所示。
在这里插入图片描述

2、模型压缩

模型结构化剪枝这里以0.2的稀疏度,taylor剪枝策略和Q_FLOP_coef=0,Q_Para_coef=0的参数进行测试。在git bash输入具体命令(默认使用cuda)如下:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

同理,也可以使用参考指定配置命令:

python -m compress --model ${MODEL} --dataset ${DATASET} --device cuda \--sparsity ${SPARSITY} --prune_strategy ${prune_strategy} --ppo \--Q_FLOP_coef ${Q_FLOP_coef} --Q_Para_coef ${Q_Para_coef} \--pretrained_pth ${PRETRAINED_MODEL_PTH} \--compressed_dir ${COMPRESSED_MODEL_DIR} \--checkpoint_dir ${CKPT_DIR} \--log_dir ${LOG} --save_model

测试结果如下图所示:
在这里插入图片描述

3、模型验证

在数据集上验证模型的识别性能,在git bash输入具体命令(默认使用cuda)如下:

./scripts/evaluate.sh googlenet cifar100

同理,也可以使用参考指定配置命令:

python -m evaluate --model ${MODEL} --dataset ${DATASET} --device cuda \--pretrained_pth ${PRETRAINED_MODEL_PTH} \--compressed_pth ${COMPRESSED_MODEL_PTH} \--log_dir ${LOG}

测试结果如下图所示:
在这里插入图片描述

四、总结

本文提出了RL-Pruner,一种结构化剪枝方法,能够学习各层之间的最优稀疏性分布,并支持无模型特定修改的一般剪枝。希望所提方法能够认识到每一层对模型(model)性能的重要性不同,这将影响未来在神经网络压缩领域的工作,包括无结构剪枝和量化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SpringCloud详细教程】-02-微服务环境搭建

精品专题: 01.《C语言从不挂科到高绩点》课程详细笔记 https://blog.csdn.net/yueyehuguang/category_12753294.html?spm1001.2014.3001.5482 02. 《SpringBoot详细教程》课程详细笔记 https://blog.csdn.net/yueyehuguang/category_12789841.html?spm1001.20…

BOM的详细讲解

BOM概述 BOM简介 BOM(browser Object)即浏览器对象模型,它提供了独立于内容而与浏览器窗口进行交互的对象,其核心对象是window。 BOM由一系列的对象构成,并且每个对象都提供了很多方法与属性 BOM缺乏标准&#xff…

使用脚本判断网络连接状态,并且添加对应路由

这个脚本通过不断检测有线网络和4G网络的连通性来动态调整默认路由。如果两个网络都可用,则优先使用4G网络。如果只有一个网络可用,则使用该网络。如果两个网络都不可用,则每秒钟检测一次,连续30次检测失败后重启设备。 #!/bin/b…

Jenkins下载安装、构建部署到linux远程启动运行

Jenkins详细教程 Winodws下载安装Jenkins一、Jenkins配置Plugins插件管理1、汉化插件2、Maven插件3、重启Jenkins:Restart Safely插件4、文件传输:Publish Over SSH5、gitee插件6、清理插件:workspace cleanup system系统配置1、Gitee配置2、…

Pandas-1:初识Pandas

第1章:初识Pandas 本章将带领读者初步了解Pandas库,介绍其基本概念、功能特点和安装方法,同时学习Pandas的核心数据结构:Series和DataFrame。通过本章的学习,您将为后续章节的深入学习打下坚实的基础。 1.1 什么是Pan…

android9-sdk-28源码替换为-Lineageos-9源码-android-studio-4.2调试LineageOS-16.0的view绘制流程

整体想法: 替换sdk-28源码中每一个x.java文件为指向软连接LineageOS-16.0对应的x.java 调试前奏(准备) android-studio-4.2并不像老版本android-studio那样容易替换api源文件路径 android-studio-4.2: 在Project Structure不能设置api(如28)的源码路径在x.class的反编译窗口…

数据分析——Python绘制实时的动态折线图

最近在做视觉应用开发,有个需求需要实时获取当前识别到的位姿点位是否有突变,从而确认是否是视觉算法的问题,发现Python的Matplotlib进行绘制比较方便。 目录 1.数据绘制2.绘制实时的动态折线图3.保存实时数据到CSV文件中 import matplotlib.…

Unity 使用 ExcelDataReader 读取Excel表

文章目录 1.下载NuGet包2.通过NuGet包获取dll3.将dll复制unity Plugins文件夹下4.代码获取Excel表内容 1.下载NuGet包 通过NuGet下载: ExcelDataReaderExcelDataReader.DataSet离线下载方法 2.通过NuGet包获取dll 根据编译时程序集找到dll位置,找到与…

【vmware+ubuntu16.04】ROS学习_博物馆仿真克隆ROS-Academy-for-Beginners软件包处理依赖报错问题

首先安装git 进入终端,输入sudo apt-get install git 安装后,创建一个工作空间名为tutorial_ws, 输入 mkdir tutorial_ws#创建工作空间 cd tutorial_ws#进入 mkdir src cd src git clone https://github.com/DroidAITech/ROS-Academy-for-Be…

九、FOC原理详解

1、FOC简介 FOC(field-oriented control)为磁场定向控制,又称为矢量控制(vectorcontrol),是目前无刷直流电机(BLDC)和永磁同步电机(PMSM)高效控制的最佳选择…

【MySQL】MySQL中的函数之JSON_KEYS

在 MySQL 中,JSON_KEYS() 函数用于获取 JSON 对象中的所有键名。这个函数非常有用,特别是在你需要知道 JSON 对象中包含哪些键时。下面是一些关于如何使用 JSON_KEYS() 的详细说明和示例。 基本语法 JSON_KEYS(json_doc [, path])json_doc: 要从中提取…

Linux的指令(三)

1.grep指令 功能: 在文件中搜索字符串,将找到的行打印出来 -i:忽略大小写的不同,所以大小写视为一样 -n:顺便输出行号 -v:反向选择,就是显示出没有你输入要搜索内容的内容 代码示例: roo…

2025蓝桥杯(单片机)备赛--扩展外设之DS1302的使用(九)

1.DS1302数据手册的使用 a. DS1302 features: 工作电压:2V-5.5V 通信协议:3线接口(CE、IO、SCLK) 计时:秒、分、小时、月日期、月、星期、年(闰年补偿器期至2100年) b.原理图接线说明&#xff…

Leetcode(滑动窗口习题思路总结,持续更新。。。)

讲解题目:长度最小的子数组 给定一个含有 n 个正整数的数组和一个正整数 target ,找出该数组中满足其和 ≥ target 的长度最小的连续子数组。如果不存在符合条件的连续子数组,返回 0。示例: 输入: target 7, nums [2,3,1,2,4,3] 输出: 2 解…

在CentOS中,通过nginx访问php

其实是nginx反向代理到php-fpm,就像nginx反向代理到tomcat。 1、安装PHP-FPM 1.1 安装 yum install php yum install php-fpm php-common 这里只安装了php-fpm,根据需要安装php模块,比如需要访问mysql则添加安装 php-mysqlnd。 1.2 启动…

Photino:通过.NET Core构建跨平台桌面应用程序,.net国产系统

一、Photino.NET简介: 最近发现了一个不错的框架 Photino.Net 一份代码运行,三个平台 windows max linux ,其中windows10,windows11,ubuntu 18.04,ubuntu 20.04 已测试均可以。mac 因为没有相关电脑没有测试。 github:https://github.com/t…

深度学习:神经网络的搭建

深度学习:神经网络的搭建 神经网络的搭建涉及多个步骤,从选择合适的网络架构到定义网络层、设置超参数以及最终的模型训练。下面我将详细介绍这些步骤,并提供一个具体的示例来展示如何使用PyTorch框架构建一个卷积神经网络(CNN&a…

编辑器vim 命令的学习

1.编辑器Vim 1.vim是一个专注的编辑器 2.是一个支持多模式的编辑器 1.1见一见: vim 的本质也是一条命令 退出来:-> Shift:q 先创建一个文件 再打开这个文件 进入后先按 I 然后就可以输入了 输入完后,保存退出 按Esc --> 来到最后一…

Ubuntu22.04LTS 部署前后端分离项目

一、安装mysql8.0 1. 安装mysql8.0 # 更新安装包管理工具 sudo apt-get update # 安装 mysql数据库,过程中的选项选择 y sudo apt-get install mysql-server # 启动mysql命令如下 (停止mysql的命令为:sudo service mysql stop&#xff0…

Python爬虫:如何从1688阿里巴巴获取公司信息

在当今的数字化时代,数据已成为企业决策和市场分析的重要资产。对于市场研究人员和企业分析师来说,能够快速获取和分析大量数据至关重要。阿里巴巴的1688.com作为中国最大的B2B电子商务平台之一,拥有海量的企业档案和产品信息。本文将介绍如何…