Pytorch量化之Post Train Static Quantization(训练后静态量化)

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super(M, self).__init__()# QuantStub converts tensors from floating point to quantizedself.quant = torch.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 1, 1)self.relu = torch.nn.ReLU()# DeQuantStub converts tensors from quantized to floating pointself.dequant = torch.quantization.DeQuantStub()def forward(self, x):# manually specify where tensors will be converted from floating# point to quantized in the quantized modelx = self.quant(x)x = self.conv(x)x = self.relu(x)# manually specify where tensors will be converted from quantized# to floating point in the quantized modelx = self.dequant(x)return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# preparemodel_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):data = data.to(torch.device("cpu:0"))outputs = model_prepared(data)print("calibrating {}".format(i))if i > 1000:break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()def speed_test(model, input):# warm upfor i in range(10):model(input)import timestart = time.time()for i in range(100):model(input)end = time.time()print("model time: ", (end-start)/100)time.sleep(10)# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/30107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM 垃圾回收

垃圾回收算法 标记-清除算法(Mark and Sweep) 标记-清除算法分为两个阶段。在标记阶段,垃圾收集器会标记所有活动对象;在清除阶段,垃圾收集器会清除所有未标记的对象。标记-清除算法存在的问题是会产生内存碎片&#…

vue 实现粒子特效(vue-particles)

vue 实现粒子特效(vue-particles) npm install vue-particles --savemain.js引入 import VueParticles from vue-particles Vue.use(VueParticles)登录页面使用 login.vue <template><div class"box"><vue-particlescolor"#409EFF":part…

nvm安装以及使用

注意事项&#xff1a; 安装前需要卸载原有的node&#xff0c;卸载干净后cmd输入node -v查看&#xff1b; 一&#xff0c;下载nvm 下载&#xff1a;https://github.com/coreybutler/nvm-windows/releases 选择第四个 “nvm-setup.zip”&#xff1b; 二&#xff0c;安装 1&…

IAR目标代码4字节对齐

向工程添加文件 eof.c : // 文件头 #if defined(__CC_ARM) // MDK // uint32_t g_update_flag[2] __attribute__((zero_init, at(0x1000FFF0)));const unsigned long gc_eof __attribute__((used)) 0xFFFFFFFFul; #elif defined(__ICCARM__) // IAR__root const unsigned…

CSS transform:rotate;无效问题

CSS设置旋转 transform:rotate无效。 今天遇到一个奇怪的问题&#xff0c;CSS给 icon图标设置一个hover 旋转180deg的效果&#xff0c;不生效。 一度任务样式被覆盖了&#xff0c;样式不生效没选中元素的class。但是设置hover改变大小是生效的。奇怪了&#xff1f; 为什么会无…

分布式 - 消息队列Kafka:Kafka生产者发送消息的分区策略

文章目录 1. PartitionInfo 分区源码2. Partitioner 分区器接口源码3. 自定义分区策略4. 轮询策略 RoundRobinPartitioner5. 黏性分区策略 UniformStickyPartitioner6. hash分区策略7. 默认分区策略 DefaultPartitioner 分区的作用就是提供负载均衡的能力&#xff0c;或者说对数…

ArcGIS Pro实践技术应用暨基础入门、制图、空间分析、影像分析、三维建模、空间统计分析与建模、python融合、案例应用

GIS是利用电子计算机及其外部设备&#xff0c;采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲&#xff0c;它是在一定的地域内&#xff0c;将地理空间信息和 一些与该地域地理信息相关的属性信息结合起来&#xff0c;达到对地理和属性信息的综合管理。GIS的…

【香瓜说职场】谈谈如何提高赚到钱的概率(2020.01.05)

一、钱从哪里来&#xff1f; 要想学会赚钱&#xff0c;首先要想明白钱从哪里来&#xff1f; 我们常常听到一句话——客户是上帝。 作为工程师的我们&#xff0c;觉得这句话听起来没毛病&#xff0c;但总隐隐约约觉得客户高不高兴跟我们一毛钱关系都没有&#xff0c;我们照样拿着…

数字化时代,如何做好用户体验与应用性能管理

引言 随着数字化时代的到来&#xff0c;各个行业的应用系统从传统私有化部署逐渐转向公有云、行业云、微服务&#xff0c;这种变迁给运维部门和应用部门均带来了较大的挑战。基于当前企业 IT 运维均为多部门负责&#xff0c;且使用多种运维工具&#xff0c;因此&#xff0c;当…

hacksudo3 通关详解

环境配置 一开始桥接错网卡了 搞了半天 改回来就行了 信息收集 漏洞发现 扫个目录 大概看了一眼没什么有用的信息 然后对着login.php跑了一下弱口令 sqlmap 都没跑出来 那么利用点应该不在这 考虑到之前有过dirsearch字典太小扫不到东西的经历 换个gobuster扫一下 先看看g…

Android界面设计与用户体验

Android界面设计与用户体验 1. 引言 在如今竞争激烈的移动应用市场&#xff0c;提供优秀的用户体验成为了应用开发的关键要素。无论应用功能多么强大&#xff0c;如果用户界面设计不合理&#xff0c;用户体验不佳&#xff0c;很可能会导致用户流失。因此&#xff0c;在Androi…

移动技术相关基本概念

信息网络隔离装置 一种能够保障企业信息网络安全的高级网络设备&#xff0c;主要作用是隔离内外网&#xff0c;阻隔外界攻击&#xff0c;保护企业网络不遭受黑客攻击、木马病毒入侵、信息泄露等安全威胁。同时还能对企业内部的流量进行监视&#xff0c;保护企业敏感数据不被内…

Flink源码之JobManager启动流程

从启动命令flink-daemon.sh中可以看出StandaloneSession入口类为org.apache.flink.runtime.entrypoint.StandaloneSessionClusterEntrypoint, 从该类的main方法会进入ClusterEntrypoint::runCluster中, 该方法中会创建出主要服务和组件。 StandaloneSessionClusterEntrypoint:…

博客项目(Spring Boot)

1.需求分析 注册功能&#xff08;添加用户操纵&#xff09;登录功能&#xff08;查询操作)我的文章列表页&#xff08;查询我的文章|文章修改|文章详情|文章删除&#xff09;博客编辑页&#xff08;添加文章操作&#xff09;所有人博客列表&#xff08;带分页功能&#xff09;…

FPGA外部触发信号毛刺产生及滤波

1、背景 最近在某个项目中&#xff0c;遇到输入给FPGA管脚的外部触发信号因为有毛刺产生&#xff0c;导致FPGA接收到的外部触发信号数量多于实际值。比如&#xff1a;用某个信号源产生1000个外部触发信号&#xff08;上升沿触发方式&#xff09;给到FPGA输入IO&#xff0c;实际…

冠达管理:股票注册制通俗理解?

目前我国A股商场正在进行股票注册制变革&#xff0c;相较之前的发行准则&#xff0c;股票注册制在理念上更为商场化&#xff0c;这意味着公司发行股票的门槛将下降&#xff0c;公司数量将添加&#xff0c;而股票流通的方式也将有所改变。那么股票注册制指的是什么&#xff0c;它…

centos按用户保存历史执行命令

centos7 按用户记录历史命令的方法 在/etc/profile文件中添加以下代码。 添加完成后执行source /etc/profile 用户重新登录即可发现history被清空了。这时可以去看/usr/share/.history文件夹&#xff0c;该文件夹保存了所有用户每次登录所执行过的的操作记录。 文件路径为 /usr…

使用vscode远程登录以及本地使用的配置(插件推荐)

1、远程登陆ssh 1.1打开vscode插件商店&#xff0c;安装remote-ssh插件 远程ssh添加第三方插件&#xff1a;vscode下链接远程服务器安装插件失败、速度慢等解决方法_vscode远程安装不上扩展_Emphatic的博客-CSDN博客 转到定义&#xff0c;选中代码->鼠标右键->转到定义…

2023/08/05【网络课程总结】

1. 查看git拉取记录 git reflog --dateiso|grep pull2. TCP/IP和OSI七层参考模型 3. DNS域名解析 4. 预检请求OPTIONS 5. 渲染进程的回流(reflow)和重绘(repaint) 6. V8解析JavaScript 7. CDN负载均衡的简单理解 8. 重学Ajax 重学Ajax满神 9. 对于XML的理解 大白话叙述XML是…

JavaScript中的几种常用循环方式对比

JavaScript中的几种循环方式 1. for 循环1.1 使用方式1.2 不支持遍历对象&#xff08;但不会报错&#xff09; 2. for-of 循环2.1 使用方式2.2 for-of 和 for 循环比较&#xff08;不允许修改原数组元素&#xff09;2.2.1 相同点2.2.1.2 都可以遍历数组&#xff0c;但不允许遍历…