模型部署笔记--Pytorch-FX量化

目录

1--Pytorch-FX量化

2--校准模型

3--代码实例

3-1--主函数

3-2--prepare_dataloader函数

3-3--训练和测试函数


1--Pytorch-FX量化

        Pytorch在torch.quantization.quantize_fx中提供了两个API,即prepare_fx和convert_fx。

        prepare_fx的作用是准备量化,其在输入模型里按照设定的规则qconfig_dict来插入观察节点,进行的工作包括:

1. 将nn.Module转换为GraphModule。
2. 合并算子,例如将Conv、BN和Relu算子进行合并(通过打印模型可以查看合并的算子)。
3. 在Conv和Linear等OP前后插入Observer, 用于观测激活值Feature map的特征(权重的最大最小值),计算scale和zero_point。

        convert_fx的作用是根据scale和zero_point来将模型进行量化。

2--校准模型

        完整项目代码参考:ljf69/Model-Deployment-Notes

        在对原始模型model调用prepare_fx()后得到prepare_model,一般需要对模型进行校准,校准后再调用convert_fx()进行模型的量化。

3--代码实例

3-1--主函数

import os
import copyimport torch
import torch.nn as nn
from torchvision.models.resnet import resnet18
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx.graph_module import ObservedGraphModulefrom dataloader import prepare_dataloader
from train_val import train_model, evaluate_model# 量化模型
def quant_fx(model):# 使用Pytorch中的FX模式对模型进行量化model.eval()qconfig = get_default_qconfig("fbgemm")  # 默认是静态量化qconfig_dict = {"": qconfig,}model_to_quantize = copy.deepcopy(model)# 通过调用prepare_fx和convert_fx直接量化模型prepared_model = prepare_fx(model_to_quantize, qconfig_dict)# print("prepared model: ", prepared_model) # 打印模型quantized_model = convert_fx(prepared_model)# print("quantized model: ", quantized_model) # 打印模型# 保存量化后的模型torch.save(quantized_model.state_dict(), "r18_quant.pth")# 校准函数
def calib_quant_model(model, calib_dataloader):# 判断model一定是ObservedGraphModule,即一定是量化模型,而不是原始模型nn.moduleassert isinstance(model, ObservedGraphModule), "model must be a perpared fx ObservedGraphModule."model.eval()with torch.inference_mode():for inputs, labels in calib_dataloader:model(inputs)print("calib done.")# 比较校准前后的差异
def quant_calib_and_eval(model, test_loader):model.to(torch.device("cpu"))model.eval()qconfig = get_default_qconfig("fbgemm")qconfig_dict = {"": qconfig,}# 原始模型(未量化前的结果)print("model:")evaluate_model(model, test_loader)# 量化模型(未经过校准的结果)model2 = copy.deepcopy(model)model_prepared = prepare_fx(model2, qconfig_dict)model_int8 = convert_fx(model_prepared)print("Not calibration model_int8:")evaluate_model(model_int8, test_loader)# 通过原始模型转换为量化模型model3 = copy.deepcopy(model)model_prepared = prepare_fx(model3, qconfig_dict) # 将模型准备为量化模型,即插入观察节点calib_quant_model(model_prepared, test_loader)  # 使用数据对模型进行校准model_int8 = convert_fx(model_prepared) # 调用convert_fx将模型设置为量化模型torch.save(model_int8.state_dict(), "r18_quant_calib.pth") # 保存校准后的模型# 量化模型(已经过校准的结果)print("Do calibration model_int8:")evaluate_model(model_int8, test_loader)if __name__ == "__main__":# 准备训练数据和测试数据train_loader, test_loader = prepare_dataloader()# 定义模型model = resnet18(pretrained=True)model.fc = nn.Linear(512, 10)# 训练模型(如果事先没有训练)if os.path.exists("r18_row.pth"): # 之前训练过就直接加载权重model.load_state_dict(torch.load("r18_row.pth", map_location="cpu"))else:train_model(model, train_loader, test_loader, torch.device("cuda"))print("train finished.")torch.save(model.state_dict(), "r18_row.pth")# 量化模型quant_fx(model)# 对比是否进行校准的影响quant_calib_and_eval(model, test_loader)

3-2--prepare_dataloader函数

# 准备训练数据和测试数据
def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform)train_sampler = torch.utils.data.RandomSampler(train_set)test_sampler = torch.utils.data.SequentialSampler(test_set)train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=train_batch_size,sampler=train_sampler,num_workers=num_workers,)test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=eval_batch_size,sampler=test_sampler,num_workers=num_workers,)return train_loader, test_loader

3-3--训练和测试函数

# 训练模型,用于后面的量化
def train_model(model, train_loader, test_loader, device):learning_rate = 1e-2num_epochs = 20criterion = nn.CrossEntropyLoss()model.to(device)optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)for epoch in range(num_epochs):# Trainingmodel.train()running_loss = 0running_corrects = 0for inputs, labels in train_loader:inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)train_loss = running_loss / len(train_loader.dataset)train_accuracy = running_corrects / len(train_loader.dataset)# Evaluationmodel.eval()eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)print("Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))return modeldef evaluate_model(model, test_loader, device=torch.device("cpu"), criterion=None):t0 = time.time()model.eval()model.to(device)running_loss = 0running_corrects = 0for inputs, labels in test_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)if criterion is not None:loss = criterion(outputs, labels).item()else:loss = 0# statisticsrunning_loss += loss * inputs.size(0)running_corrects += torch.sum(preds == labels.data)eval_loss = running_loss / len(test_loader.dataset)eval_accuracy = running_corrects / len(test_loader.dataset)t1 = time.time()print(f"eval loss: {eval_loss}, eval acc: {eval_accuracy}, cost: {t1 - t0}")return eval_loss, eval_accuracy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115732.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

supervisor的使用

一、supervisor简介 Supervisor是用Python开发的一套通用的进程管理程序,能将一个普通的命令行进程变为后台daemon,并监控进程状态,异常退出时能自动重启。它是通过fork/exec的方式把这些被管理的进程当作supervisor的子进程来启动&#xff…

Leetcode 2909. Minimum Sum of Mountain Triplets II

Leetcode 2909. Minimum Sum of Mountain Triplets II 1. 解题思路2. 代码实现 题目链接:2909. Minimum Sum of Mountain Triplets II 1. 解题思路 这一题思路上就是一个累积数组的思路。 我们要找一个山峰结构,使得其和最小,那么我们只需…

【GWO-KELM预测】基于灰狼算法优化核极限学习机回归预测研究(matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Dapper中使用字符串作为动态参数查询时,结果不是预期的问题

1、如下图,c.industryId作为string类型当作参数传递,解析时会加单引号,即:”c.industryId“, 生成的查询语句就会变成 -- 这里把c.IndustryGroup 当成实际的值所以会查询不出数据 select b.Name,COUNT(c.Id) Num …

爬虫模拟用户登录

使用爬虫模拟用户登录过程一般包括以下几个步骤: 导入所需的库:一般需要导入requests和BeautifulSoup库来发送HTTP请求和解析HTML。 import requestsfrom bs4 import BeautifulSoup 发送GET请求获取登录页面:使用requests库发送GET请求&#…

Spring Boot OAuth 2.0整合详解

目录 一、Spring Boot 2.x 示例 1、初始化设置 2、设置重定向URI 3、配置 application.yml 4、启动应用程序 二、Spring Boot 2.x 属性映射 二、CommonOAuth2Provider 三、配置自定义提供者(Provider)属性 四、覆盖 Spring Boot 2.x 的自动配置…

数学建模——最大流问题(配合例子说明)

目录 一、最大流有关的概念 例1 1、容量网络的定义 2、符号设置 3、建立模型 3.1 每条边的容量限制 3.2 平衡条件 3.3 网络的总流量 4、网络最大流数学模型 5、计算 二、最小费用流 例2 【符号说明】 【建立模型】 (1)各条边的流量限制 &a…

Java赋值运算符(=)

赋值运算符是指为变量或常量指定数值的符号。赋值运算符的符号为“”,它是双目运算符,左边的操作数必须是变量,不能是常量或表达式。 其语法格式如下所示: 变量名称表达式内容 在 Java 语言中,“变量名称”和“表达式…

执行autoreconf -fi的过程报错

https://xie.infoq.cn/article/6bba9dd34fb49b7adacb4aacd https://github.com/curl/curl/blob/master/docs/HTTP3.md#quiche-version curl配置quiche的过程中报错, configure:7902: error: possibly undefined macro: AC_LIBTOOL_WIN32_DLLIf this token and ot…

Linux常用的调试工具

在开发和调试Linux的过程中,经常会遇到各种各样的问题,如程序崩溃、性能低下、内存泄漏等。这时候,调试就显得尤为重要。调试技巧和工具能够帮助开发人员快速定位问题并快速解决。在本文中,我们将介绍一些常用的Linux调试技巧和工…

acwing第 126 场周赛 (扩展字符串)

5281. 扩展字符串 一、题目要求 某字符串序列 s0,s1,s2,… 的生成规律如下: s0 DKER EPH VOS GOLNJ ER RKH HNG OI RKH UOPMGB CPH VOS FSQVB DLMM VOS QETH SQBsnDKER EPH VOS GOLNJ UKLMH QHNGLNJ Asn−1AB CPH VOS FSQVB DLMM VOS QHNG Asn−1AB,其…

MyBatis源码基础-常用类-Configuration

Configuration Configuration类a.java配置b.构建配置类 Configuration类 a.java配置 针对上述的xml配置,可以使用如下的java代码替换: Test public void testConfiguration() {Configuration configuration new Configuration();// 配置propertiesPr…

C++中作为类实例的对象

C中作为类实例的对象 类相当于蓝图,仅声明类并不会对程序的执行产生影响。在程序执行阶段,对象是类的化身。要使用类的功能,通常需要创建其实例—对象,并通过对象访问成员方法和属性。 在C中。类的对象就是该类的某一特定实体&a…

canvas绘制动态视频并且在视频上加上自定义logo

实现的效果&#xff1a;可以在画布上播放动态视频&#xff0c;并且加上自定义的图片logo放在视频的右下角 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthd…

找不到conda可执行文件:解决方法

1.在新版本的pycharm出现的问题如下&#xff1a; 2.解决方法: 2.1 将anaconda\Scripts\conda.exe选中 2.2选择自己的anconda自己的环境&#xff0c;之后就可以正常创建conda环境

自然语言处理---Transformer机制详解之BERT GPT ELMo模型的对比

1 BERT、GPT、ELMo的不同点 关于特征提取器: ELMo采用两部分双层双向LSTM进行特征提取, 然后再进行特征拼接来融合语义信息.GPT和BERT采用Transformer进行特征提取.很多NLP任务表明Transformer的特征提取能力强于LSTM, 对于ELMo而言, 采用1层静态token embedding 2层LSTM, 提取…

IT行业就业方向:探索未来的职业机会

引言&#xff1a; 随着信息技术的飞速发展&#xff0c;IT行业已经成为了当前最具活力和发展潜力的行业之一。在这个充满机遇和挑战的时代&#xff0c;选择一个好的就业方向对于个人的职业发展至关重要。本文将探讨IT行业中哪些方向具有更好的就业前景&#xff0c;并提供一些建…

云原生之深入解析如何使用Vcluster Kubernetes加速开发效率

一、背景 为什么一个已经在使用 Kubernetes 本身方面已经很挣扎的开发人员还要处理虚拟集群呢&#xff1f;答案可能会让您感到惊讶&#xff0c;但我相信虚拟集群实际上比单独的物理集群更容易处理&#xff0c;并且与本地 k3d、KinD 或 minikube 部署的集群相比具有相当多的优势…

node教程(二)

文章目录 1.模块化1.1模块化介绍1.2模块化初体验1.3模块暴露数据(&#x1f53a;)1.4引入&#xff08;导入&#xff09;模块 1.模块化 1.1模块化介绍 ⭐什么是模块化与模块&#xff1f; 将一个复杂的程序文件依据一定规则&#xff08;规范&#xff09;拆分成多个文件的过程称之…

python 之numpy 之随机生成数

文章目录 1. **生成均匀分布的随机浮点数**&#xff1a;2. **生成随机整数**&#xff1a;3. **生成标准正态分布随机数**&#xff1a;4. **生成正态分布随机数**&#xff1a;5. **生成均匀分布的随机浮点数**&#xff1a;6. **生成随机抽样**&#xff1a;7. **设置随机数种子**…