pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubclass QuantAwareNet(nn.Module):def __init__(self):super(QuantAwareNet, self).__init__()self.quant = QuantStub() # 新插入内容self.dequant = DeQuantStub() # 新插入内容self.fc1 = nn.Linear(784, 256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.quant(x) # 新插入内容x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x) # 新插入内容return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 ‘fbgemm’ 或 ‘qnnpack’:

  • ‘fbgemm’: 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • ‘qnnpack’: 适用于移动设备,也支持INT8量化,优化了ARM架构。
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

from torch.quantization import QConfig, default_observer, default_per_channel_weight_observercustom_qconfig = QConfig(activation=default_observer.with_args(dtype=torch.qint8),weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert# 实例化模型
model = MyQuantizedModel() # 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')# 准备量化感知训练,
model = prepare_qat(model)# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环
for epoch in range(num_epochs):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()# 转换模型为完全量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)

4. 量化推理测试

import torch
from torch.quantization import convertdef test_quantized_model(model, dataloader, device='cpu'):model = convert(model.eval(), inplace=True)model.to(device)  # 确保模型在正确的设备上correct = 0total = 0with torch.no_grad():  # 关闭梯度计算,因为我们只做推理for data, targets in dataloader:data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备outputs = model(data)  # 前向推理_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = 100 * correct / totalprint(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoaderimport torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubimport torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')class QuantizedCNN(nn.Module):def __init__(self):super(QuantizedCNN, self).__init__()self.quant = QuantStub()self.conv1 = nn.Conv2d(3, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.fc1 = nn.Linear(32 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.dequant = DeQuantStub()def forward(self, x):# x = self.quant(x)x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)x = self.dequant(x)return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0# 切换到评估模式进行测试model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))# 在最后一个epoch后完成量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)print("Model quantization completed.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/60599.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker--工作目录迁移

前言 安装docker,默认的情况容器的默认存储路径会存储系统盘的 /var/lib/docker 目录下,系统盘一般默认 50G,容器输出的所有的日志,文件,镜像,都会存在这个地方,时间久了就会占满系统盘。 一、…

开发效率工具链全解析

🛠 开发效率工具链全解析:从入门到精通 在现代前端开发中,高效的工具链对于提升开发效率至关重要。本文将全方位剖析项目脚手架、包管理工具以及构建工具的深度集成与实战应用。 📑 内容导航 工具链概述项目脚手架包管理工具常见…

[ 网络安全介绍 3 ] 网络安全事件相关案例有哪些?

🍬 博主介绍 👨‍🎓 博主介绍:大家好,我是 _PowerShell ,很高兴认识大家~ ✨主攻领域:【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 🎉点赞➕评论➕收藏 养成习…

【Unity基础】Unity中碰撞及触发类物理交互应用场景说明

一、碰撞类回调方法 在Unity中,碰撞类回调方法是用于处理物体间碰撞的逻辑。这些方法常用于 MonoBehaviour 脚本中,以便在物体发生碰撞时进行响应。以下是最常用的三个碰撞类回调方法的详细说明: 1. OnCollisionEnter(Collision collision)…

【MySQL】MySQL中的函数之REGEXP_SUBSTR

在 MySQL 中,REGEXP_SUBSTR() 函数用于从字符串中提取与正则表达式匹配的子串。这个函数也是从 MySQL 8.0 开始引入的。下面是一些关于如何使用 REGEXP_SUBSTR() 的详细说明和示例。 基本语法 REGEXP_SUBSTR(str, pat [, position [, occurrence [, match_type ]]…

使用Java绘制图片边框,解决微信小程序map组件中marker与label层级关系问题,label增加外边框后显示不能置与marker上面

今天上线的时候发现系统不同显示好像不一样,苹果手机打开的时候是正常的,但是一旦用安卓手机打开就会出现label不置顶的情况。尝试了很多种办法,也在官方查看了map相关的文档,发现并没有给label设置zIndex的属性,只看到…

arm64架构的linux 配置vm_page_prot方式

在 ARM64 架构上,通过 vm_page_prot 属性可以修改 UIO 映射内存的访问权限及缓存策略,常见的有非缓存(Non-cached)、写合并(Write Combine)等。下面是 ARM64 常用的 vm_page_prot 设置及其对应的操作方式。…

Redisson的可重入锁

初始状态: 表示系统或资源在没有线程持有锁的情况下的状态,任何线程都可以尝试获取锁。 线程 1 获得锁: 线程 1 首次获取了锁并进入受保护的代码区域。 线程 1 再次请求锁: 在持有锁的情况下,线程 1 再次请求锁&a…

探秘Spring Boot中的@Conditional注解

文章目录 1. 什么是Conditional注解?2. 为什么需要Conditional注解?3. 如何使用Conditional注解?4. Conditional注解的高级用法5. 注意事项6. 结语推荐阅读文章 在Spring Boot的世界里,配置的灵活性和多样性是至关重要的。有时候&…

三周精通FastAPI:37 包含 WSGI - Flask,Django,Pyramid 以及其它

官方文档:https://fastapi.tiangolo.com/zh/advanced/wsgi/ 包含 WSGI - Flask,Django,其它 您可以挂载多个 WSGI 应用,正如您在 Sub Applications - Mounts, Behind a Proxy 中所看到的那样。 为此, 您可以使用 WSGIMiddlewar…

Swagger UI

Swagger UI 是一个开源工具,用于可视化、构建和交互式地探索 RESTful API。 它是 Swagger 生态系统的一部分,Swagger 是一套用于描述、生成、调用和可视化 RESTful Web 服务的工具和规范。 Swagger UI 可以自动生成 API 文档,并提供一个交互…

thinkphp6 --数据库操作 增删改查

一、数据库连接配置 如果是本地测试,它会优先读取 .env 配置,然后再读取 database.php 的配置; 如果禁用了 .env 配置,则会读取数据库连接的默认配置: # .env文件,部署服务器,请禁用我 我们可以…

WPF中MVVM工具包 CommunityToolkit.Mvvm

CommunityToolkit.Mvvm,也称为MVVM工具包,是Microsoft Community Toolkit的一部分。它是一个轻量级但功能强大的MVVM(Model-View-ViewModel)库,旨在帮助开发者更容易地实现MVVM设计模式。 特点 独立于平台和运行时&a…

【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-最大的数

CL13 最大的数(20 分) 输入一个有 n 个无重复元素的整数数组 a&#xff0c;输出数组中最大的数。提示&#xff1a;如使用排序库函数 sort()&#xff0c;需要包含头文件#include 。输入&#xff1a; 第一行是一个正整数 n(2<n<20)&#xff1b; 第二行包含 n 个不重复的整…

让Git走代理

有时候idea提交代码或者从github拉取代码&#xff0c;一直报错超时或者:Recv failure: Connection was reset,下面记录一下怎么让git走代理从而访问到github。 1.打开梯子 2.打开网络和Internet设置 3.设置代理 记住这个地址和端口 4.打开git bash终端 输入以下内容 git c…

vivo 游戏中心包体积优化方案与实践

作者&#xff1a;来自 vivo 互联网大前端团队- Ke Jie 介绍 App 包体积优化的必要性&#xff0c;游戏中心 App 在实际优化过程中的有效措施&#xff0c;包括一些优化建议以及优化思路。 一、包体积优化的必要性 安装包大小与下载转化率的关系大致是成反比的&#xff0c;即安装…

Leetcode 每日一题 125.验证回文串

问题定义 给定一个字符串s&#xff0c;我们需要判断它是否是一个回文串。但在此之前&#xff0c;我们需要将所有大写字符转换为小写字符&#xff0c;并移除所有非字母数字字符。只有经过这样处理后的字符串&#xff0c;我们才进行回文检测。 示例解析 以下是几个示例&#x…

Struts扫盲

Struts扫盲 这里的struts是struts1。以本文记录我的那些复习JavaEE的痛苦并快乐的晚上 Struts是什么 框架的概念想必大家都清楚&#xff0c;框架即“半成品代码”&#xff0c;是为了简化开发而设计的。一个项目有许多分层&#xff0c;拿一个MVC架构的Web应用来说&#xff0c;有…

【AiPPT-注册/登录安全分析报告-无验证方式导致安全隐患】

前言 由于网站注册入口容易被机器执行自动化程序攻击&#xff0c;存在如下风险&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露&#xff0c;不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 &#xff0c;造成用户无法登陆、注册&#xff0c;大量收到垃圾短信的…

自动驾驶系列—从数据采集到存储:解密自动驾驶传感器数据采集盒子的关键技术

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…