模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

  • 前言
  • 量化
    • Post-Training-Quantization(PTQ)
    • Quantization-Aware-Training(QAT)
  • 参考文献

前言

随着人工智能的不断发展,深度学习网络被广泛应用于图像处理、自然语言处理等实际场景,将其部署至多种不同设备的需求也日益增加。然而,常见的深度学习网络模型通常包含大量参数和数百万的浮点数运算(例如ResNet50具有95MB的参数以及38亿浮点数运算),实时地运行这些模型需要消耗大量内存和算力,这使得它们难以部署到资源受限且需要满足实时性、低功耗等要求的边缘设备。为了进一步推动深度学习网络模型在移动端或边缘设备中的快速部署,深度学习领域提出了一系列的模型压缩与加速方法:

  • 知识蒸馏(Knowledge distillation):使用教师-学生网络结构,让小型的学生网络模仿大型教师网络的行为,以使得准确率尽可能高的同时,能够获得一个轻量化的网络。
  • 剪枝(Parameter pruning):删除不必要的网络参数,以减少模型的规模和计算复杂度。
  • 低秩分解(Low-rank factorization):将模型的参数矩阵分解为较低秩的小矩阵,以减少模型的复杂度和计算成本。
  • 参数共享(Parameter sharing):将多个层共用一组参数,以减少模型的参数数量。
  • 量化(Quantization):将模型的参数和运算转化为更小的数据类型,以减少内存占用和计算时间。

量化

模型量化(Quantization)是一种将浮点计算转化为定点计算的技术,例如从FP32降低至INT8,主要用于减少模型的计算强度、参数大小以及内存消耗,以提高模型在设备上的推理计算效率,但是也有可能会带来一定的精度损失。

模型量化精度损失的主要原因为量化-反量化(Quantization-Dequantization)过程中取整引起的误差。这里简单介绍一下量化的计算方法,以FP32到INT8的量化为例,量化的核心思想就是将浮点数区间的参数映射到INT8的离散区间中。
量化公式:
q = r s + Z q = \frac{r}{s} + Z q=sr+Z反量化公式:
r = S ( q − Z ) r = S(q-Z) r=S(qZ)其中, r r r 为FP32的浮点数(real value), q q q 为INT8的量化值(quantization value),
S S S Z Z Z 分别为缩放因子(Scale-factor)和零点(Zero-Point)。

量化最重要的便是确定 S S S Z Z Z 的值, S S S Z Z Z 的计算公式如下:
S = r m a x − r m i n q m a x − q m i n S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} S=qmaxqminrmaxrmin Z = − r m i n S + q m i n Z = -\frac{r_{min}}{S} + q_{min} Z=Srmin+qmin其中, r m a x r_{max} rmax r m i n r_{min} rmin 分别为FP32网络参数最大、最小值, q m a x q_{max} qmax q m i n q_{min} qmin 分别为INT8网络参数最大、最小值。

为了减少量化所带来的精度损失,学者提出了Quantization-Aware-Training(QAT)方法,再介绍此之前,由于Post-Training-Quantization(PTQ)方法也经常在文献中出现,此篇博客将着重介绍这两个方法的含义与区别。
在这里插入图片描述

Post-Training-Quantization(PTQ)

Post-Training-Quantization(PTQ)是目前常用的模型量化方法之一。以INT8量化为例,PTQ方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的模型;
  2. 使用小部分数据对FP32模型进行采样(Calibration),主要是为了得到网络各层参数的数据分布特性(比如统计最大最小值);
  3. 根据步骤2中的数据分布特性,计算出网络各层 S 和 Z 量化参数;
  4. 使用步骤3中的量化参数对FP32模型进行量化得到INT8模型,并将其部署至推理框架进行推理。

PTQ方法会使用小部分数据集来估计网络各层参数的数据分布,找到合适的S和Z的取值,从而一定程度上降低模型精度损失。然而,论文中指出PTQ方式虽然在大模型上效果较好(例如ResNet101),但是在小模型上经常会有较大的精度损失(例如MobileNet) 不同通道的输出范围相差可能会非常大(大于100x), 对异常值较为敏感。

Quantization-Aware-Training(QAT)

由上文可知PTQ方法中模型的训练和量化是分开的,而Quantization-Aware-Training(QAT)方法则是在模型训练时加入了伪量化节点,用于模拟模型量化时引起的误差,并通过微调使得模型在量化后尽可能减少精度损失。以INT8量化为例,QAT方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的FP32模型;
  2. 在FP32模型中插入伪量化节点,得到QAT模型,并且在数据集上对QAT模型进行微调(Fine-tuning);
  3. 同PTQ方法中的采样(Calibration),并计算量化参数 S 和 Z ;
  4. 使用步骤3中得到的量化参数对QAT模型进行量化得到INT8模型,并部署至推理框架中进行推理。

在PyTorch中,可以使用 torch.quantization.quantize_dynamic() 方法来执行 QAT。这是一个基本的 QAT 代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic, QuantStub, DeQuantStub# 定义简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.quant = QuantStub()self.dequant = DeQuantStub()self.fc1 = nn.Linear(784, 256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.quant(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x)return x# 数据加载
# 这里使用 MNIST 数据集作为示例
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True, transform=transform),batch_size=64, shuffle=False)# 定义损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 定义 QAT 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=5):model.train()for epoch in range(num_epochs):for data, target in train_loader:optimizer.zero_grad()output = model(data.view(data.shape[0], -1))loss = criterion(output, target)loss.backward()optimizer.step()# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=5)# 在训练完成后执行动态量化
quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)# 评估量化模型
def test(model, test_loader, criterion):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:output = model(data.view(data.shape[0], -1))_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = correct / totalprint(f'Accuracy of the network on the test images: {accuracy * 100:.2f}%')# 测试量化模型
test(quantized_model, test_loader, criterion)

上述代码示例中,我使用了一个简单的全连接神经网络,并在训练完成后使用torch.quantization.quantize_dynamic()对模型进行动态量化。在量化之前,我们通过QuantStub()DeQuantStub()添加了量化和反量化的辅助模块。这个示例使用了MNIST数据集,你可以根据你的实际需求替换成其他数据集和模型。

参考文献

量化感知训练(Quantization-aware-training)探索-从原理到实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/140043.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FFmpeg简介1

适逢FFmpeg6.1发布,准备深入学习下FFmpeg,将会写下系列学习记录。 在此列出主要学习资料,后续再不列,感谢这些大神的探路和分享,特别是雷神,致敬! 《FFmpeg从入门到精通》 《深入理解FFmpeg》 …

Git版本控制系统之分支与标签(版本)

目录 一、Git分支(Branch) 1.1 分支作用 1.2 四种分支管理策略 1.3 使用案例 1.3.1 指令 1.3.2 结合应用场景使用 二、Git标签(Tag) 2.1 标签作用 2.2 标签规范 2.3 使用案例 2.3.1 指令 2.3.2 使用示例 一、Git分支&…

分布式理论基础:CAP定理

什么是CAP CAP原则又称CAP定理,指的是在一个分布式系统中,Consistency(一致性)、 Availability(可用性)、Partition tolerance(分区容错性)这三个基本需求,最多只能同时…

Unity Mirror学习(二) Command特性使用

Command(命令)特性 1,修饰方法的,当在客户端调用此方法,它将在服务端运行(我的理解:客户端命令服务端做某事;或者说:客户端向服务端发消息,消息方法&#xff…

几种解决mfc140.dll文件缺失的方法,电脑提示mfc140.dll怎么办

电脑提示mfc140.dll缺失,如果你不去处理的话,那么你的程序游戏什么都是启动不了的,如果你想知道有什么方法可以解决那么可以参考这篇文章进行解决,今天给大家几种解决mfc140.dll文件缺失的方法。电脑提示mfc140.dll也不用担心解决…

Qt贝塞尔曲线

目录 引言核心代码基本表达绘制曲线使用QEasingCurve 完整代码 引言 贝塞尔曲线客户端开发中常见的过渡效果,如界面的淡入淡出、数值变化、颜色变化等等。为了能够更深的了解地理解贝塞尔曲线,本文通过Demo将贝塞尔曲线绘制出来,如下所示&am…

基于SSM的数据结构课程网络学习平台

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

Git系列之分支与标签的使用及应用场景模拟

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是君易--鑨,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的博客专栏《Git实战开发》。🎯🎯 &a…

Java学习_对象

对象在计算机中的执行原理 类和对象的一些注意事项 this关键字 构造器 构造器是一种特殊的方法 : 特殊之处在于,名字必须与所在类的名字一样,而且不能写返回值类型 封装 封装的设计规范:合理隐藏、合理暴露 实体类 成员变量和局部变量的区别 …

微信聊天,收到二维码图片就自动帮你提取出来的方法

10-3 如果你是二维码收集的重度用户,那我非常推荐你好好阅读本文,也许可以帮你解决你的问题,比如做网推的人,需要常年混迹在各种微信群,那如何在各个微信群中收集到群友分享出来的二维码,并且要立即保存出…

组件的设计原则

目录 插槽的基本概念 基础用法 具名插槽 使用场景 布局控制 嵌套组件 组件的灵活性 高级用法 作用域插槽 总结 前言 Vue 的 slot 是一项强大的特性,用于组件化开发中。它允许父组件向子组件传递内容,使得组件更加灵活和可复用。通过 slot&…

Python之函数进阶-nonlocal和LEGB

Python之函数进阶-nonlocal和LEGB nonlocal语句 nonlocal:将变量标记为不在本地作用域定义,而是在上级的某一级局部作用域中定义,但不能是全局作用域中定义。 函数的销毁 定义一个函数就是生成一个函数对象,函数名指向的就是函数对象。可…

华为云Ascend310服务器使用

使用华为云服务器 cpu: 16vCPUs Kunpeng 920 内存:16GiB gpu:4* HUAWEI Ascend 310 cann: 20.1.rc1 操作系统:Ubuntu aarch64目的 使用该服务器进行docker镜像编译,测试模型。 已知生产环境:mindx版本为3.0.rc3&a…

【机器学习】Kmeans聚类算法

一、聚类简介 Clustering (聚类)是常见的unsupervised learning (无监督学习)方法,简单地说就是把相似的数据样本分到一组(簇),聚类的过程,我们并不清楚某一类是什么(通常无标签信息)&#xff0…

通义千问, 文心一言, ChatGLM, GPT-4, Llama2, DevOps 能力评测

引言 “克隆 dev 环境到 test 环境,等所有服务运行正常之后,把访问地址告诉我”,“检查所有项目,告诉我有哪些服务不正常,给出异常原因和修复建议”,在过去的工程师生涯中,也曾幻想过能够通过这…

【FAQ】Gradle开发问题汇总

1. buildSrc依赖Spring Denpendency时报错 来自预编译脚本的插件请求不能包含版本号。请从有问题的请求中删除该版本,并确保包含所请求插件io.spring.dependency-management的模块是一个实现依赖项 解决方案 https://www.5axxw.com/questions/content/uqw0grhttps:/…

基于springboot实现桥牌计分管理系统项目【项目源码】计算机毕业设计

基于springboot实现桥牌计分管理系统演示 JAVA简介 JavaScript是一种网络脚本语言,广泛运用于web应用开发,可以用来添加网页的格式动态效果,该语言不用进行预编译就直接运行,可以直接嵌入HTML语言中,写成js语言&#…

MYSQL操作详解

一)计算机的基本结构 但是实际上,更多的是这种情况: 二)MYSQL中的数据类型: 一)数值类型: 数据类型内存大小(字节)说明bit(M)M指定位数,默认为1单个二进制位值,或者为0或者为1,主要用于开/关标志tinyint1字节1个字节的整数值,支持…

使用openvc进行人脸检测:Haar级联分类器

1 人脸检测介绍 1.1 什么是人脸检测 人脸检测的目标是确定图像或视频中是否存在人脸。如果存在多个面,则每个面都被一个边界框包围,因此我们知道这些面的位置 人脸检测算法的主要目标是准确有效地确定图像或视频中人脸的存在和位置。这些算法分析数据…

一文入门Springboot+actuator+Prometheus+Grafana

环境介绍 技术栈 springbootmybatis-plusmysqloracleactuatorPrometheusGrafana 软件 版本 mysql 8 IDEA IntelliJ IDEA 2022.2.1 JDK 1.8 Spring Boot 2.7.13 mybatis-plus 3.5.3.2 本地主机应用 192.168.1.9:8007 PrometheusGrafana安装在同一台主机 http://…