Pytorch量化之Post Train Static Quantization(训练后静态量化)

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super(M, self).__init__()# QuantStub converts tensors from floating point to quantizedself.quant = torch.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 1, 1)self.relu = torch.nn.ReLU()# DeQuantStub converts tensors from quantized to floating pointself.dequant = torch.quantization.DeQuantStub()def forward(self, x):# manually specify where tensors will be converted from floating# point to quantized in the quantized modelx = self.quant(x)x = self.conv(x)x = self.relu(x)# manually specify where tensors will be converted from quantized# to floating point in the quantized modelx = self.dequant(x)return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# preparemodel_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):data = data.to(torch.device("cpu:0"))outputs = model_prepared(data)print("calibrating {}".format(i))if i > 1000:break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()def speed_test(model, input):# warm upfor i in range(10):model(input)import timestart = time.time()for i in range(100):model(input)end = time.time()print("model time: ", (end-start)/100)time.sleep(10)# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/30107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nvm安装以及使用

注意事项: 安装前需要卸载原有的node,卸载干净后cmd输入node -v查看; 一,下载nvm 下载:https://github.com/coreybutler/nvm-windows/releases 选择第四个 “nvm-setup.zip”; 二,安装 1&…

IAR目标代码4字节对齐

向工程添加文件 eof.c : // 文件头 #if defined(__CC_ARM) // MDK // uint32_t g_update_flag[2] __attribute__((zero_init, at(0x1000FFF0)));const unsigned long gc_eof __attribute__((used)) 0xFFFFFFFFul; #elif defined(__ICCARM__) // IAR__root const unsigned…

分布式 - 消息队列Kafka:Kafka生产者发送消息的分区策略

文章目录 1. PartitionInfo 分区源码2. Partitioner 分区器接口源码3. 自定义分区策略4. 轮询策略 RoundRobinPartitioner5. 黏性分区策略 UniformStickyPartitioner6. hash分区策略7. 默认分区策略 DefaultPartitioner 分区的作用就是提供负载均衡的能力,或者说对数…

ArcGIS Pro实践技术应用暨基础入门、制图、空间分析、影像分析、三维建模、空间统计分析与建模、python融合、案例应用

GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关的属性信息结合起来,达到对地理和属性信息的综合管理。GIS的…

数字化时代,如何做好用户体验与应用性能管理

引言 随着数字化时代的到来,各个行业的应用系统从传统私有化部署逐渐转向公有云、行业云、微服务,这种变迁给运维部门和应用部门均带来了较大的挑战。基于当前企业 IT 运维均为多部门负责,且使用多种运维工具,因此,当…

hacksudo3 通关详解

环境配置 一开始桥接错网卡了 搞了半天 改回来就行了 信息收集 漏洞发现 扫个目录 大概看了一眼没什么有用的信息 然后对着login.php跑了一下弱口令 sqlmap 都没跑出来 那么利用点应该不在这 考虑到之前有过dirsearch字典太小扫不到东西的经历 换个gobuster扫一下 先看看g…

Android界面设计与用户体验

Android界面设计与用户体验 1. 引言 在如今竞争激烈的移动应用市场,提供优秀的用户体验成为了应用开发的关键要素。无论应用功能多么强大,如果用户界面设计不合理,用户体验不佳,很可能会导致用户流失。因此,在Androi…

Flink源码之JobManager启动流程

从启动命令flink-daemon.sh中可以看出StandaloneSession入口类为org.apache.flink.runtime.entrypoint.StandaloneSessionClusterEntrypoint, 从该类的main方法会进入ClusterEntrypoint::runCluster中, 该方法中会创建出主要服务和组件。 StandaloneSessionClusterEntrypoint:…

博客项目(Spring Boot)

1.需求分析 注册功能(添加用户操纵)登录功能(查询操作)我的文章列表页(查询我的文章|文章修改|文章详情|文章删除)博客编辑页(添加文章操作)所有人博客列表(带分页功能)…

FPGA外部触发信号毛刺产生及滤波

1、背景 最近在某个项目中,遇到输入给FPGA管脚的外部触发信号因为有毛刺产生,导致FPGA接收到的外部触发信号数量多于实际值。比如:用某个信号源产生1000个外部触发信号(上升沿触发方式)给到FPGA输入IO,实际…

冠达管理:股票注册制通俗理解?

目前我国A股商场正在进行股票注册制变革,相较之前的发行准则,股票注册制在理念上更为商场化,这意味着公司发行股票的门槛将下降,公司数量将添加,而股票流通的方式也将有所改变。那么股票注册制指的是什么,它…

使用vscode远程登录以及本地使用的配置(插件推荐)

1、远程登陆ssh 1.1打开vscode插件商店,安装remote-ssh插件 远程ssh添加第三方插件:vscode下链接远程服务器安装插件失败、速度慢等解决方法_vscode远程安装不上扩展_Emphatic的博客-CSDN博客 转到定义,选中代码->鼠标右键->转到定义…

2023/08/05【网络课程总结】

1. 查看git拉取记录 git reflog --dateiso|grep pull2. TCP/IP和OSI七层参考模型 3. DNS域名解析 4. 预检请求OPTIONS 5. 渲染进程的回流(reflow)和重绘(repaint) 6. V8解析JavaScript 7. CDN负载均衡的简单理解 8. 重学Ajax 重学Ajax满神 9. 对于XML的理解 大白话叙述XML是…

JavaScript中的几种常用循环方式对比

JavaScript中的几种循环方式 1. for 循环1.1 使用方式1.2 不支持遍历对象(但不会报错) 2. for-of 循环2.1 使用方式2.2 for-of 和 for 循环比较(不允许修改原数组元素)2.2.1 相同点2.2.1.2 都可以遍历数组,但不允许遍历…

灰度非线性变换之c++实现(qt + 不调包)

本章介绍灰度非线性变换,具体内容包括:对数变换、幂次变换、指数变换。他们的共同特点是使用非线性变换关系式进行图像变换。 1.灰度对数变换 变换公式:y a log(1x) / b,其中,a控制曲线的垂直移量;b为正…

快速解决IDEA中类的图标变成J,不是C的情况

有时候导入新的项目后,会出现如下情况,类的图标变成J,如图: 直接上解决方法: 找到项目的pom.xml,右键,在靠近最下方的位置找到Add as Maven Project,点击即可。 此时,一般类的图标就…

Wlan——射频和天线基础知识

目录 射频的介绍 射频和Wifi 射频的相关基础概念 射频的传输 信号功率的单位 射频信号传输行为 天线的介绍 天线的分类 天线的基本原理 天线的参数 射频的介绍 射频和Wifi 什么是射频 从射频发射器产生一个变化的电流(交流电),通过…

【信号生成器】从 Excel 数据文件创建 Simulink 信号生成器块研究(Simulink)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Linux 文件基本属性

Linux 文件基本属性 Linux 系统是一种典型的多用户系统,不同的用户处于不同的地位,拥有不同的权限。 为了保护系统的安全性,Linux 系统对不同的用户访问同一文件(包括目录文件)的权限做了不同的规定。 在 Linux 中我…

【项目学习1】如何将java对象转化为XML字符串

如何将java对象转化为XML字符串 将java对象转化为XML字符串,可以使用Java的XML操作库JAXB,具体操作步骤如下: 主要分为以下几步: 1、创建JAXBContext对象,用于映射Java类和XML。 JAXBContext jaxbContext JAXBConte…