深度学习模型部署(番外4)模型量化方案及实战

量化方案

根据量化的时机可以分为训练时量化QAT和训练后量化PTQ。二者的工作流程图如下:
QAT量化流程
QAT量化流程图

PTQ量化流程
PTQ量化流程图

其中的CLE为层间放缩平衡Cross-Layer Equalization,在使用per_tensor粒度进行量化时,同一个tensor中可能数据不平衡情况很严重,尤其是在深度可分离卷积中,非常影响量化性能,所以为了解决这个问题提出的,大致可以理解为对于前后两层,前面一层的参数放大s倍,后面一层的参数缩小s倍,那么最后的输出结果不会有太大变化。关于CLE后面在量化性能调优的blog里面讲

PTQ

其中PTQ又分为PTDQ和PTSQ两种方案,二者的主要区别是对于activation的处理不同。具体区别如下:

PTDQ

Post Training Dynamic Quantization,训练后动态量化,也就是在模型训练收敛后进行量化,weight进行量化,而activation则是在推理过程中进行动态计算scale和offset,进行动态量化。
这种方案是最简单的量化方案,主要用于模型参数加载较为费时的模型,如LSTM和transformer。
具体的参数类型如下:

# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32/
linear_weight_fp32# dynamically quantized model
# 只有weight变成了int8,activation还是fp32
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32/linear_weight_int8

pytorch demo代码如下:

import torch# define a floating point model
class M(torch.nn.Module):def __init__(self):super().__init__()self.fc = torch.nn.Linear(4, 4)def forward(self, x):x = self.fc(x)return xmodel_fp32 = M()
# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(model_fp32,  # the original model{torch.nn.Linear},  # 自己选择要进行量化的层dtype=torch.qint8)  # 量化后参数的存储类型# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)

PTSQ

Post Training Static Quantization,训练后静态量化。
与PTDQ的不同在于,PTSQ是把weight和activation的量化都提前做好,等到推理时不需要涉及到任何计算scale和offset的工作。对于activation的量化,PTSQ会尽量把它融合到前一层中,跟一层的layer一块儿量化。
这种方案适合读取参数和计算都比较耗时的模型,读取参数时间和计算时间都减少了。但是需要采用典型的数据进行校准,用于提前计算activation的scale和offset。
pytorch中PTSQ的方案需要自己定义量化器和解量化器,并且自己定义哪里开始量化,哪几层进行融合。

# original model
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32/linear_weight_fp32# statically quantized model
# 将linear和activation融合到了一块儿
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8/linear_weight_int8

pytorch代码如下:

import torch# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super().__init__()self.quant = torch.ao.quantization.QuantStub() # 量化器self.conv = torch.nn.Conv2d(1, 1, 1)self.relu = torch.nn.ReLU()self.dequant = torch.ao.quantization.DeQuantStub() # 解量化器def forward(self, x):x = self.quant(x) # 将x量化,由FP32变成int8x = self.conv(x)x = self.relu(x)x = self.dequant(x) # 解量化,由int8变回FP32return x# create a model instance
model_fp32 = M()model_fp32.eval()model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)input_fp32 = torch.randn(4, 1, 4, 4) //典型数据
model_fp32_prepared(input_fp32) //校准,计算scale和offsetmodel_int8 = torch.ao.quantization.convert(model_fp32_prepared) //模型转化res = model_int8(input_fp32) //推理

QAT

边量化边训练,训练过程还是浮点数运算,不过会进行量化模拟,模拟量化以查看效果,本来训练是使loss最小,QAT情况下的训练是:使量化模拟后的模型loss最小。

关于量化模拟在之前提到过,详见:tensor量化
加入量化器后出现的一个问题就是:反向传播的时候,量化模拟器怎么求导?量化模拟器求导的结果要么是0要么就是不存在,所以这里引入直通估计,也就是假设四舍五入函数 r o u n d ( x ) round(x) round(x)的导数为1,也就是不影响求导,最终的反向传播效果如下:在这里插入图片描述

量化前后数据流图:

# original model
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32/linear_weight_fp32# 量化模拟,在其中加入了很多量化器,为了模拟量化产生的精度损失,数据依旧使fp32类型
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32/linear_weight_fp32 -- fq# 真实量化过的模型,全部都是int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8/linear_weight_int8

demo代码:

import torch# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):def __init__(self):super().__init__()self.quant = torch.ao.quantization.QuantStub() //量化器self.conv = torch.nn.Conv2d(1, 1, 1)self.bn = torch.nn.BatchNorm2d(1)self.relu = torch.nn.ReLU()self.dequant = torch.ao.quantization.DeQuantStub() //解量化器def forward(self, x):x = self.quant(x)x = self.conv(x)x = self.bn(x)x = self.relu(x)x = self.dequant(x)return xdef training_loop(model):pass# create a model instance
model_fp32 = M()
model_fp32.eval()model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')# 算子融合,将多个算子融合为一个
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,[['conv', 'bn', 'relu']])# 准备量化
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())# 训练
training_loop(model_fp32_prepared)model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)res = model_int8(input_fp32)

pytorch对于量化的实现

量化算法支持

回归到最底层的tensor量化,有两种量化粒度:per_tensor和per_channel,两种量化方法:对称量化,非对称量化。一共2*2=4种选择,所以pytorch也提供了4种方法:

import torch
torch.per_tensor_affine # 非对称
torch.per_tensor_symmetric # 对称
torch.per_channel_affine
torch.per_channel_symmetric

量化不仅仅需要进行模型融合,加入量化算法支持,还需要对tensor进行修改,原本的tensor只需要存储数据,现在对于量化过的INT8 tensor,还需要存储scale和offset。所以pytorch针对量化后的tensor提供了新的类型:

import torch
torch.quint8
torch.qint8
torch.qint32
torch.float16

关于config,类型转化等,详见官方文档:官方文档

如果觉得有帮助,请点赞收藏+关注,thanks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/724864.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

upload-Labs靶场“11-15”关通关教程

君衍. 一、第十一关 %00截断GET上传1、源码分析2、%00截断GET上传 二、第十二关 %00截断POST上传1、源码分析2、%00截断POST上传 三、第十三关 文件头检测绕过1、源码分析2、文件头检测绕过 四、第十四关 图片检测绕过上传1、源码分析2、图片马绕过上传 五、第十五关 图片检测绕…

华中某科技大学校园网疑似dns劫持的解决方法

问题 在校园网ping xxx.ddns.net,域名解析失败 使用热点ping xxx.ddns.net,可以ping通 尝试设置windows dns首选dns为114.114.114.114,重新ping,仍然域名解析失败 猜测【校园网可能劫持dns请求】 解决方法 使用加密的dns请求…

[LeetBook]【学习日记】有效数字——状态机

题目 有效数字 有效数字(按顺序)可以分成以下几个部分: 若干空格一个小数或者整数(可选)一个’e’或’E’,后面跟着一个整数若干空格 小数(按顺序)可以分成以下几个部分&#xff1a…

【考研数学】基础660太难了?一个办法搞定660

觉得题目太难,大概率是题目超出了自己当前的水平 题型没见过,或者太复杂,属于跳级学习了,正确的思路就是回归到自己的水平线,题目略难即可。 这样做题的话,大部分题目涉及的点不会超出自己的能力范围&…

数据结构(七)——线性表的基本操作

🧑个人简介:大家好,我是尘觉,希望我的文章可以帮助到大家,您的满意是我的动力😉 在csdn获奖荣誉: 🏆csdn城市之星2名 ⁣⁣⁣⁣ ⁣⁣⁣⁣ ⁣⁣⁣⁣ ⁣⁣⁣⁣ ⁣⁣⁣⁣ ⁣⁣⁣⁣ ⁣⁣⁣⁣ …

vue3学习(续篇)

vue3学习(续篇) 默认有vue3基础并有一定python编程经验。 chrome浏览器安装vue.js devtools拓展。 文章目录 vue3学习(续篇)1. element-ui使用2. axios 网络请求1. 简介2. 操作 3. flask-cors解决跨域问题1. 简介2. 操作 4. 前端路由 vue-router1. 简单使用2. 配置路径别名和…

好利来做宠物蛋糕,为啥品牌争相入局宠物赛道?

2月20日,好利来宣布推出宠物烘焙品牌「Holiland Pet」,正式进军宠物烘焙市场。作为首个入局宠物烘焙领域的国内食品品牌,一经推出,就面临着各种争议,不管大众看法如何,好利来进军宠物市场,也让宠…

后台组件-语言包

<groupId>org.qlm</groupId><artifactId>qlm-language</artifactId><version>1.0-SNAPSHOT</version> 平台提供多语言支持&#xff0c;以上为语言包&#xff0c;提供后台多语言支持。首批实现&#xff1a; public class LanguageConstan…

Git快速上手二

对Git命令的深入理解快速上手Git&#xff08;包含提交至GitHub和Gitee&#xff09;-CSDN博客 1.5 分支操作 1.5.1 分支原理 系统上线后,又要修改bug,又要开发新的功能。 由于新功能没有开发完,所以需要建立分支,一边修改bug,一边开发新功能,最终合并. 1.5.2 分支实操 创建…

Java基于微信小程序的旅游出行必备小程序,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

第三百八十六回

文章目录 概念介绍使用方法示例代码 我们在上一章回中介绍了Snackbar Widget相关的内容,本章回中将介绍TimePickerDialog Widget.闲话休提&#xff0c;让我们一起Talk Flutter吧。 概念介绍 我们在这里说的TimePickerDialog是一种弹出窗口&#xff0c;只不过窗口的内容固定显示…

18.网络游戏逆向分析与漏洞攻防-网络通信数据包分析工具-数据分析工具数据与消息配置的实现

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 上一个内容&#xff1a;17.数据分析工具配置功能的实现 码云地址&#xff08;master 分支&#xff09;&#xff1a;https://gitee.com/dye_your_fingers/titan…

于建筑外窗遮阳系数测试的太阳光模拟器模拟太阳光照射房屋视频

太阳光模拟器是一种用于测试建筑外窗遮阳系数的高科技设备。它能够模拟太阳光照射房屋的情景&#xff0c;帮助建筑师和设计师更好地了解建筑外窗的遮阳性能&#xff0c;从而提高建筑的能源效率和舒适度。 这种模拟器的工作原理非常简单&#xff0c;它通过使用高亮度的光源和精密…

Positional Encoding 位置编码

Positional Encoding 位置编码 flyfish Transformer模型没有使用循环神经网络&#xff0c;无法从序列中学习到位置信息&#xff0c;并且它是并行结构&#xff0c;不是按位置来处理序列的&#xff0c;所以为输入序列加入了位置编码&#xff0c;将每个词的位置加入到了词向量中…

Netty之WebSocket协议开发

一、WebSocket产生背景 在传统的Web通信中&#xff0c;浏览器是基于请求--响应模式。这种方式的缺点是&#xff0c;浏览器必须始终主动发起请求才能获取更新的数据&#xff0c;而且每次请求都需要经过HTTP的握手和头部信息的传输&#xff0c;造成了较大的网络开销。如果客户端…

爆肝!Claude3与ChatGPT-4到底谁厉害,看完你就知道了!

前言&#xff1a; 相信大家在pyq都被这张图片刷屏了把~ 昨天&#xff0c;为大家介绍了一下什么是Claude&#xff0c;今天咱终于弄到号了&#xff08;再被ban了3个号之后终于是成功的登上去了&#xff0c;如果各位看官觉得咱文章写的不错&#xff0c;麻烦点个小小的关注~你们的…

【详识C语言】自定义类型之三:联合

本章重点 联合 联合类型的定义 联合的特点 联合大小的计算 联合&#xff08;共用体&#xff09; 联合类型的定义 联合也是一种特殊的自定义类型 这种类型定义的变量也包含一系列的成员&#xff0c;特征是这些成员公用同一块空间&#xff08;所以联合也叫共用体&#xff09;…

mysql 数据库查询 查询字段用逗号隔开 关联另一个表并显示

文章目录 问题描述解决方案 问题描述 如下如所示&#xff1a; 表一&#xff1a;wechat_dynamically_config表&#xff0c;重点字段&#xff1a;wechat_object 表二&#xff1a;wechat_object表&#xff0c;重点字段&#xff1a;wxid 需求&#xff1a;根据wechat_dynamically_…

模仿Gitee实现站外链接跳转时进行确认

概述 如Gitee等网站&#xff0c;在有外部链接的时候如果不是同域则会出现一个确认页面。本文就带你看看这个功能应该如何实现。 效果 实现 1. 实现思路 将打开链接作为参数传递给一个中间页面&#xff0c;在页面加载的时候判断链接的域名和当前网站是否同域&#xff0c;同域…

Redis线程模型解析

引言 Redis是一个高性能的键值对&#xff08;key-value&#xff09;内存数据库&#xff0c;以其卓越的读写速度和灵活的数据类型而广受欢迎。在Redis 6.0之前的版本中&#xff0c;它采用的是一种独特的单线程模型来处理客户端的请求。尽管单线程在概念上似乎限制了其扩展性和并…