2.6、微调算法

前言

在 PPQ 中我们目前提供两种不同的算法帮助你微调网络
这些算法将使用 calibration dataset 中的数据,对网络权重展开重训练

    1. 经过训练的网络不保证中间结果与原来能够对齐,在进行误差分析时你需要注意这一点
    1. 在训练中使用 with ENABLE_CUDA_KERNEL(): 子句将显著加速训练过程
    1. 训练过程的缓存数据将被贮存在 gpu 上,这可能导致你显存溢出,你可以修改参数将缓存设备改为 cpu

code

from typing import Iterableimport torch
import torchvisionfrom ppq import (QuantizationSettingFactory, TargetPlatform,graphwise_error_analyse)
from ppq.api import QuantizationSettingFactory, quantize_torch_model
from ppq.api.interface import ENABLE_CUDA_KERNEL
from ppq.executor.torch import TorchExecutor# ------------------------------------------------------------
# 在 PPQ 中我们目前提供两种不同的算法帮助你微调网络
# 这些算法将使用 calibration dataset 中的数据,对网络权重展开重训练
# 1. 经过训练的网络不保证中间结果与原来能够对齐,在进行误差分析时你需要注意这一点
# 2. 在训练中使用 with ENABLE_CUDA_KERNEL(): 子句将显著加速训练过程
# 3. 训练过程的缓存数据将被贮存在 gpu 上,这可能导致你显存溢出,你可以修改参数将缓存设备改为 cpu
# ------------------------------------------------------------BATCHSIZE   = 32
INPUT_SHAPE = [BATCHSIZE, 3, 224, 224]
DEVICE      = 'cuda'
PLATFORM    = TargetPlatform.PPL_CUDA_INT8def load_calibration_dataset() -> Iterable:# ------------------------------------------------------------# 让我们从创建 calibration 数据开始做起, PPQ 需要你送入 32 ~ 1024 个样本数据作为校准数据集# 它们应该尽可能服从真实样本的分布,量化过程如同训练过程一样存在可能的过拟合问题# 你应当保证校准数据是经过正确预处理的、有代表性的数据,否则量化将会失败;校准数据不需要标签;数据集不能乱序# ------------------------------------------------------------return [torch.rand(size=INPUT_SHAPE) for _ in range(32)]
CALIBRATION = load_calibration_dataset()def collate_fn(batch: torch.Tensor) -> torch.Tensor:return batch.to(DEVICE)# ------------------------------------------------------------
# 我们使用 mobilenet v2 作为一个样例模型
# PPQ 将会使用 torch.onnx.export 函数 把 pytorch 的模型转换为 onnx 模型
# 对于复杂的 pytorch 模型而言,你或许需要自己完成 pytorch 模型到 onnx 的转换过程
# ------------------------------------------------------------
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)# ------------------------------------------------------------
# PPQ 提供基于 LSQ 的网络微调过程,这是推荐的做法
# 你将使用 Quant Setting 来调用微调过程,并调整微调参数
# ------------------------------------------------------------
QSetting = QuantizationSettingFactory.default_setting()
QSetting.lsq_optimization                            = True
QSetting.lsq_optimization_setting.block_size         = 4
QSetting.lsq_optimization_setting.lr                 = 1e-5
QSetting.lsq_optimization_setting.gamma              = 0
QSetting.lsq_optimization_setting.is_scale_trainable = True
QSetting.lsq_optimization_setting.collecting_device  = 'cuda'# ------------------------------------------------------------
# 如果你使用 ENABLE_CUDA_KERNEL 方法
# PPQ 将会尝试编译自定义的高性能量化算子,这一过程需要编译环境的支持
# 如果你在编译过程中发生错误,你可以删除此处对于 ENABLE_CUDA_KERNEL 方法的调用
# 这将显著降低 PPQ 的运算速度;但即使你无法编译这些算子,你仍然可以使用 pytorch 的 gpu 算子完成量化
# ------------------------------------------------------------
with ENABLE_CUDA_KERNEL():quantized = quantize_torch_model(model=model, calib_dataloader=CALIBRATION,calib_steps=32, input_shape=INPUT_SHAPE,setting=QSetting, collate_fn=collate_fn, platform=PLATFORM,onnx_export_file='./model.onnx', device=DEVICE, verbose=0)# ------------------------------------------------------------# 当我们完成训练后,我们将调用 graphwise_error_analyse 方法分析网络误差# 经过训练的中间层误差可能很大,但这不是我们所关心的 —— 训练方法只优化最终输出的误差# 一个量化良好的网络,最后输出层的误差不应大于 10%# ------------------------------------------------------------graphwise_error_analyse(graph=quantized, running_device=DEVICE, dataloader=CALIBRATION,collate_fn=collate_fn)# ------------------------------------------------------------
# 下面我们向你展示另一种 PPQ 中提供的优化方法
# 在 PPQ 0.6.5 之后,我们将这部分扩展性的方法移出了 QuantizationSetting
# 现在,扩展性方法需要手动调用
# ------------------------------------------------------------
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)QSetting = QuantizationSettingFactory.default_setting()
# ------------------------------------------------------------
# baking_parameter 将会在网络量化之后,将网络中所有参数静态量化
# 参数静态量化将会显著提高 PPQ 的运行速度,但是一旦参数被静态量化,则其将无法被修改
# 也无法参与后续的训练过程
# ------------------------------------------------------------
QSetting.quantize_parameter_setting.baking_parameter = Falsewith ENABLE_CUDA_KERNEL():quantized = quantize_torch_model(model=model, calib_dataloader=CALIBRATION,calib_steps=32, input_shape=INPUT_SHAPE,setting=QSetting, collate_fn=collate_fn, platform=PLATFORM,onnx_export_file='./model.onnx', device=DEVICE, verbose=0)# ------------------------------------------------------------# 让我们手动调用 AdaroundPass 优化过程# 这一过程需要训练更多步数,同时你应当注意,训练过程应该放在网络量化过程之后# 并且不允许使用 QSetting.quantize_parameter_setting.baking_parameter = True# ------------------------------------------------------------from ppq.quantization.optim import AdaroundPass, ParameterBakingPassexecutor = TorchExecutor(graph=quantized, device=DEVICE)AdaroundPass(steps=5000).optimize(graph=quantized, dataloader=CALIBRATION, executor=executor, collate_fn=collate_fn)ParameterBakingPass().optimize(graph=quantized, dataloader=CALIBRATION, executor=executor, collate_fn=collate_fn)graphwise_error_analyse(graph=quantized, running_device=DEVICE, dataloader=CALIBRATION,collate_fn=collate_fn)
  • PPQ 提供基于 LSQ 的网络微调过程,这是推荐的做法
    将使用 Quant Setting 来调用微调过程,并调整微调参数
  • 另一种 PPQ 中提供的优化方法
    在 PPQ 0.6.5 之后,我们将这部分扩展性的方法移出了 QuantizationSetting
    现在,扩展性方法需要手动调用
    baking_parameter 将会在网络量化之后,将网络中所有参数静态量化
    参数静态量化将会显著提高 PPQ 的运行速度,但是一旦参数被静态量化,则其将无法被修改,也无法参与后续的训练过程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/123842.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu学习---跟着绍发学linux课程记录(第4部分)

第3部份内容记录在:Ubuntu学习—跟着绍发学linux课程记录(第3部分) 文章目录 14 ubuntu服务器上的java14.1 Java的安装14.2 运行java程序14.3 Java启动脚本 15 ubuntu服务器上的tomcat15.1 Tomcat服务器15.2 Tomcat的配置15.3 Tomcat启动日志…

加速度中标云尖信息「电子元器件商城」开发项目——加速度jsudo

深圳市加速度软件开发有限公司在电子元器件和工业品行业有着多年得商城开发经验,服务过半导体、元器件、工业品行业的多家上市公司或实力工厂。选择加速度合作的60%的客户,或多或少都有踩坑的经历,这一次他们在选择商城开发商的时候格外谨慎&…

Spring@Lazy是如何解决构造函数循环依赖问题

Spring实例化源码解析之循环依赖CircularReference这章的最后我们提了一个构造函数形成的循环依赖问题,本章就是讲解利用Lazy注解如何解决构造函数循环依赖和其原理。 准备工作 首先创建两个构造函数循环依赖的类,TestA和TestB,代码如下&am…

qt-C++笔记之在两个标签页中按行读取两个不同的文件并且滚动条自适应滚动范围高度

qt-C笔记之在两个标签页中按行读取两个不同的文件并且滚动条自适应滚动范围高度 code review! 文章目录 qt-C笔记之在两个标签页中按行读取两个不同的文件并且滚动条自适应滚动范围高度1.运行2.文件结构3.main.cc4.main.pro5.a.txt6.b.txt7.上述代码中QVBoxLayout&#xff0c…

世界前沿技术发展报告2023《世界航空技术发展报告》(四)无人机技术

(四)无人机技术 1.无人作战飞机1.1 美国空军披露可与下一代战斗机编组作战的协同式无人作战飞机项目1.2 俄罗斯无人作战飞机取得重要进展 2.支援保障无人机2.1 欧洲无人机项目通过首个里程碑2.2 美国海军继续开展MQ-25无人加油机测试工作 3.微小型无人机…

TextureView和SurfaceView

1、Surface Surface对应了一块屏幕的缓冲区,每一个window对应一个Surface,任何View都是画在Surface上的,传统的View共享一块屏幕缓冲区,所有的绘制都必须在UI线程上进行。 2、SurfaceView 顾名思义就是Surface的View,…

NSS刷题 js前端修改 os.path.join漏洞

打算刷一遍nssweb题(任重道远) 前面很简单 都是签到题 这里主要记录一下没想到的题目 [GDOUCTF 2023]hate eat snake 这里 是对js的处理 有弹窗 说明可能存在 alert 我们去看看js 这里进行了判断 如果 getScore>-0x1e9* 我们结合上面 我觉得是6…

数据结构 -- ArrayList与LinkedList的区别

一、二者的相同点 1,它们都是继承自List接口。 二、二者的区别 1,数据结构:ArrayList是(Array动态数组)的数据结构;而LinkedList是(Link双向链表)的数据结构。ArrayList 自由性较…

RabbitMQ基础

目录 RabbitMQ的可靠性投递 确保消息正确地发送至 RabbitMQ 确保消息接收方消费了消息 流程分析 1.生产者发送消息给Broker 2.交换机路由消息到队列 3.消息存储在队列 4.消费者订阅并消费消息 三个重要概念 RabbitMQ集群模式 RabbitMQ的可靠性投递 在 RabbitMQ 中&a…

BUUCTF qr 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 这是一个二维码,谁用谁知道! 密文: 下载附件,得到一张二维码图片。 解题思路: 1、这是一道签到题,扫描二维码得到flag。 flag:…

一文了解Elasticsearch

数据分类 数据按数据结构分类主要有三种:结构化数据、半结构化数据和非结构化数据。 结构化数据 结构化数据具有明确定义数据模型和格式的数据类型。 特点: 数据具有固定的结构和模式。 数据项明确定义数据类型和长度。 适合用于数据查询、过滤和分…

377. 组合总和 Ⅳ 70.魔改爬楼梯

377. 组合总和 Ⅳ 题目: 给一个正整数数组和一个正整数目标值,数组的每个元素可取无限次,求总额达到目标值的最大排列数。 dp[j]含义: dp[j]:达到目标值j的整数组合数为dp[j] 递推公式: 求装满背包有几…

在CARLA中手动开车,添加双目相机stereo camera,激光雷达Lidar

CARLA的使用逻辑: 首先创建客户端 设置如果2秒没有从服务器返回任何内容,则终止 client carla.Client("127.0.0.1", 2000) client.set_timeout(2.0) 从客户端中get world world client.get_world() 设置setting并应用 这里使用固定时…

【C++的OpenCV】第十四课-OpenCV基础强化(三):Mat元素的访问之data和step属性

🎉🎉🎉 欢迎来到小白 p i a o 的学习空间! \color{red}{欢迎来到小白piao的学习空间!} 欢迎来到小白piao的学习空间!🎉🎉🎉 💖 C\Python所有的入门技术皆在 我…

【年终特惠】基于最新导则下生态环评报告编制技术暨报告篇、制图篇、指数篇、综合应用篇系统性实践技能提升

根据生态环评内容庞杂、综合性强的特点,依据生态环评最新导则,将内容分为4大篇章(报告篇、制图篇、指数篇、综合篇)、10大专题(生态环评报告编制、土地利用图的制作、植被类型及植被覆盖度图的制作、物种适宜生境分布图的制作、生物多样性测定、生物量及…

前端Vue页面中如何展示本地图片

<el-table :data"tableData" stripe style"width: 100%"><el-table-column prop"imgUrl" label"图片"><template v-slot"scope"><img :src "http://localhost:8888/image/ scope.row.imgUrl&qu…

R-FCN: Object Detection via Region-based Fully Convolutional Networks(2016.6)

文章目录 AbstractIntroduction当前最先进目标检测存在的问题针对上述问题&#xff0c;我们提出... Our approachOverviewBackbone architecturePosition-sensitive score maps & Position-sensitive RoI pooling Related WorkExperimentsConclusion 原文链接 源代码 Abstr…

飞天使-mysql8.0远程连接允许

mysql -u root -p 查看身份验证类型 mysql> use mysql; Database changed mysql> SELECT Host, User, plugin from user; ------------------------------------------------- | Host | User | plugin | ------------------------------------------------- | % | root …

Sass、Less和Stylus之间有什么主要的区别?

Sass、Less和Stylus是三种常见的CSS预处理器&#xff0c;它们在功能和语法上有一些区别。以下是它们之间的主要区别&#xff1a; 1&#xff1a;语法差异&#xff1a; Sass使用缩进的语法&#xff0c;使用类似于Python的缩进来表示嵌套规则和块级作用域。Less和Stylus使用类似…

大数据之LibrA数据库系统告警处理(ALM-12002 HA资源异常)

告警解释 HA软件周期性检测Manager的WebService浮动IP地址和数据库。当HA软件检测到浮动IP地址或数据库异常时&#xff0c;产生该告警。 当HA检测到浮动IP地址或数据库正常后&#xff0c;告警恢复。 告警属性 告警参数 对系统的影响 如果Manager的WebService浮动IP地址异常…