pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubclass QuantAwareNet(nn.Module):def __init__(self):super(QuantAwareNet, self).__init__()self.quant = QuantStub() # 新插入内容self.dequant = DeQuantStub() # 新插入内容self.fc1 = nn.Linear(784, 256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.quant(x) # 新插入内容x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x) # 新插入内容return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 ‘fbgemm’ 或 ‘qnnpack’:

  • ‘fbgemm’: 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • ‘qnnpack’: 适用于移动设备,也支持INT8量化,优化了ARM架构。
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

from torch.quantization import QConfig, default_observer, default_per_channel_weight_observercustom_qconfig = QConfig(activation=default_observer.with_args(dtype=torch.qint8),weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert# 实例化模型
model = MyQuantizedModel() # 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')# 准备量化感知训练,
model = prepare_qat(model)# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环
for epoch in range(num_epochs):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()# 转换模型为完全量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)

4. 量化推理测试

import torch
from torch.quantization import convertdef test_quantized_model(model, dataloader, device='cpu'):model = convert(model.eval(), inplace=True)model.to(device)  # 确保模型在正确的设备上correct = 0total = 0with torch.no_grad():  # 关闭梯度计算,因为我们只做推理for data, targets in dataloader:data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备outputs = model(data)  # 前向推理_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = 100 * correct / totalprint(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoaderimport torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubimport torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')class QuantizedCNN(nn.Module):def __init__(self):super(QuantizedCNN, self).__init__()self.quant = QuantStub()self.conv1 = nn.Conv2d(3, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.fc1 = nn.Linear(32 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.dequant = DeQuantStub()def forward(self, x):# x = self.quant(x)x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)x = self.dequant(x)return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0# 切换到评估模式进行测试model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))# 在最后一个epoch后完成量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)print("Model quantization completed.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/60599.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Java绘制图片边框,解决微信小程序map组件中marker与label层级关系问题,label增加外边框后显示不能置与marker上面

今天上线的时候发现系统不同显示好像不一样,苹果手机打开的时候是正常的,但是一旦用安卓手机打开就会出现label不置顶的情况。尝试了很多种办法,也在官方查看了map相关的文档,发现并没有给label设置zIndex的属性,只看到…

Redisson的可重入锁

初始状态: 表示系统或资源在没有线程持有锁的情况下的状态,任何线程都可以尝试获取锁。 线程 1 获得锁: 线程 1 首次获取了锁并进入受保护的代码区域。 线程 1 再次请求锁: 在持有锁的情况下,线程 1 再次请求锁&a…

三周精通FastAPI:37 包含 WSGI - Flask,Django,Pyramid 以及其它

官方文档:https://fastapi.tiangolo.com/zh/advanced/wsgi/ 包含 WSGI - Flask,Django,其它 您可以挂载多个 WSGI 应用,正如您在 Sub Applications - Mounts, Behind a Proxy 中所看到的那样。 为此, 您可以使用 WSGIMiddlewar…

Swagger UI

Swagger UI 是一个开源工具,用于可视化、构建和交互式地探索 RESTful API。 它是 Swagger 生态系统的一部分,Swagger 是一套用于描述、生成、调用和可视化 RESTful Web 服务的工具和规范。 Swagger UI 可以自动生成 API 文档,并提供一个交互…

thinkphp6 --数据库操作 增删改查

一、数据库连接配置 如果是本地测试,它会优先读取 .env 配置,然后再读取 database.php 的配置; 如果禁用了 .env 配置,则会读取数据库连接的默认配置: # .env文件,部署服务器,请禁用我 我们可以…

【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-最大的数

CL13 最大的数(20 分) 输入一个有 n 个无重复元素的整数数组 a&#xff0c;输出数组中最大的数。提示&#xff1a;如使用排序库函数 sort()&#xff0c;需要包含头文件#include 。输入&#xff1a; 第一行是一个正整数 n(2<n<20)&#xff1b; 第二行包含 n 个不重复的整…

让Git走代理

有时候idea提交代码或者从github拉取代码&#xff0c;一直报错超时或者:Recv failure: Connection was reset,下面记录一下怎么让git走代理从而访问到github。 1.打开梯子 2.打开网络和Internet设置 3.设置代理 记住这个地址和端口 4.打开git bash终端 输入以下内容 git c…

vivo 游戏中心包体积优化方案与实践

作者&#xff1a;来自 vivo 互联网大前端团队- Ke Jie 介绍 App 包体积优化的必要性&#xff0c;游戏中心 App 在实际优化过程中的有效措施&#xff0c;包括一些优化建议以及优化思路。 一、包体积优化的必要性 安装包大小与下载转化率的关系大致是成反比的&#xff0c;即安装…

Struts扫盲

Struts扫盲 这里的struts是struts1。以本文记录我的那些复习JavaEE的痛苦并快乐的晚上 Struts是什么 框架的概念想必大家都清楚&#xff0c;框架即“半成品代码”&#xff0c;是为了简化开发而设计的。一个项目有许多分层&#xff0c;拿一个MVC架构的Web应用来说&#xff0c;有…

【AiPPT-注册/登录安全分析报告-无验证方式导致安全隐患】

前言 由于网站注册入口容易被机器执行自动化程序攻击&#xff0c;存在如下风险&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露&#xff0c;不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 &#xff0c;造成用户无法登陆、注册&#xff0c;大量收到垃圾短信的…

自动驾驶系列—从数据采集到存储:解密自动驾驶传感器数据采集盒子的关键技术

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…

【月之暗面kimi-注册/登录安全分析报告】

前言 由于网站注册入口容易被机器执行自动化程序攻击&#xff0c;存在如下风险&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露&#xff0c;不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 &#xff0c;造成用户无法登陆、注册&#xff0c;大量收到垃圾短信的…

时序预测 | 改进图卷积+informer时间序列预测,pytorch架构

时序预测 | 改进图卷积informer时间序列预测&#xff0c;pytorch架构 目录 时序预测 | 改进图卷积informer时间序列预测&#xff0c;pytorch架构预测效果基本介绍参考资料 预测效果 基本介绍 改进图卷积informer时间序列预测代码 CTR-GC卷积,informer&#xff0c;CTR-GC 图卷积…

从入门到精通:一文掌握 Dockerfile 的用法!(多阶段构建与缓存优化)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 Dockerfile基础用法 📒📝 什么是 Dockerfile?📝 Dockerfile 的常见指令🔖 构建指令🔖 命令指令🎈 完整示例:构建一个 Python Flask 应用🔖 1. 项目结构🔖 2. 编写 Dockerfile🔖 3. 构建和运行 Docker 镜像�…

Go语言开发基于SQLite数据库实现用户表修改接口(四)

背景 上一章 Go语言开发基于SQLite数据库实现用户表查询详情接口(三) 这一章我们实现用户表的修改接口 代码实现 mapper层 type UserMapper interface {UpdateById(user *model.User, id uint64) error}type userMapper struct { }func (m *userMapper) UpdateById(user *m…

【C++学习(35)】在Linux中基于ucontext实现C++实现协程(Coroutine),基于C++20的co_await 协程的关键字实现协程

文章目录 为什么使用协程协程的理解协程优势协程的原语操作yield 与 resume 是一个switch操作&#xff08;三种实现方式&#xff09;&#xff1a; 基于 ucontext 的协程基于 XFiber 库的操作1 包装上下文2 XFiber 上下文调度器2.1 CreateFiber2.2 Dispatch 基于C20的co_return …

844.比较含退格的字符串

java用 O&#xff08;1&#xff09;空间这个方法&#xff0c;容易挺多bug的… O&#xff08;1&#xff09;空间 #&#xff1a;删除前一个字符 》 从后面开始判断&#xff08;这样可以用跳过的思想&#xff09;不能使用两次 i- - 来处理 # 的操作&#xff0c;会造成误删了前面…

大数据实训室建设的必要性

一、大数据发展的背景 大数据作为当今信息技术领域的核心驱动力&#xff0c;正在深刻地改变着社会的各个方面。它不仅仅是指数据量庞大&#xff0c;更重要的是指数据的多样性、实时性和复杂性。随着云计算、物联网等技术的迅猛发展&#xff0c;大数据已成为推动经济社会发展的…

MyBatis——增删查改(XML 方式)

1. 查询 1.1. 简单查询 使用注解的方式主要是完成一些简单的增删查改功能&#xff0c;如果要实现复杂的 SQL 功能&#xff0c;还是建议使用 XML 来配置映射语句&#xff0c;将 SQL 语句写在 XML 配置文件中 如果要操作数据库&#xff0c;需要做以下的配置&#xff0c;与注解…

K8S如何基于Istio实现全链路HTTPS

K8S如何基于Istio实现全链路HTTPS Istio 简介Istio 是什么?为什么选择 Istio?Istio 的核心概念Service Mesh(服务网格)Data Plane(数据平面)Sidecar Mode(边车模式)Ambient Mode(环境模式)Control Plane(控制平面)Istio 的架构与组件Envoy ProxyIstiod其他组件Istio 的流量管…