【TVM 教程】使用 TVM 部署框架预量化模型

本文介绍如何将深度学习框架量化的模型加载到 TVM。预量化模型的导入是 TVM 中支持的量化之一。有关 TVM 中量化的更多信息,参阅 此处。

这里演示了如何加载和运行由 PyTorch、MXNet 和 TFLite 量化的模型。加载后,可以在任何 TVM 支持的硬件上运行编译后的量化模型。

首先,导入必要的包:

from PIL import Image
import numpy as np
import torch
from torchvision.models.quantization import mobilenet as qmobilenetimport tvm
from tvm import relay
from tvm.contrib.download import download_testdata

定义运行 demo 的辅助函数:

def get_transform():import torchvision.transforms as transformsnormalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])return transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize,])def get_real_image(im_height, im_width):img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"img_path = download_testdata(img_url, "cat.png", module="data")return Image.open(img_path).resize((im_height, im_width))def get_imagenet_input():im = get_real_image(224, 224)preprocess = get_transform()pt_tensor = preprocess(im)return np.expand_dims(pt_tensor.numpy(), 0)def get_synset():synset_url = "".join(["https://gist.githubusercontent.com/zhreshold/","4d0b62f3d01426887599d4f7ede23ee5/raw/","596b27d23537e5a1b5751d2b0481ef172f58b539/","imagenet1000_clsid_to_human.txt",])synset_name = "imagenet1000_clsid_to_human.txt"synset_path = download_testdata(synset_url, synset_name, module="data")with open(synset_path) as f:return eval(f.read())def run_tvm_model(mod, params, input_name, inp, target="llvm"):with tvm.transform.PassContext(opt_level=3):lib = relay.build(mod, target=target, params=params)runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))runtime.set_input(input_name, inp)runtime.run()return runtime.get_output(0).numpy(), runtime

从标签到类名的映射,验证模型的输出是否合理:

synset = get_synset()

用猫的图像进行演示:

inp = get_imagenet_input()

部署量化的 PyTorch 模型

首先演示如何用 PyTorch 前端加载由 PyTorch 量化的深度学习模型。

参考 PyTorch 静态量化教程,了解量化的工作流程。

用下面的函数来量化 PyTorch 模型。此函数采用浮点模型,并将其转换为 uint8。这个模型是按通道量化的。

def quantize_model(model, inp):model.fuse_model()model.qconfig = torch.quantization.get_default_qconfig("fbgemm")torch.quantization.prepare(model, inplace=True)# Dummy calibrationmodel(inp)torch.quantization.convert(model, inplace=True)

从 torchvision 加载预量化、预训练的 Mobilenet v2 模型

之所以选择 mobilenet v2,是因为该模型接受了量化感知训练,而其他模型则需要完整的训练后校准。

qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()

输出结果:

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /workspace/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth0%|          | 0.00/13.6M [00:00<?, ?B/s]44%|####4     | 6.03M/13.6M [00:00<00:00, 63.2MB/s]89%|########8 | 12.1M/13.6M [00:00<00:00, 61.4MB/s]
100%|##########| 13.6M/13.6M [00:00<00:00, 66.0MB/s]

量化、跟踪和运行 PyTorch Mobilenet v2 模型

量化和 jit 的详细信息可参考 PyTorch 网站上的教程。

pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()with torch.no_grad():pt_result = script_module(pt_inp).numpy()

输出结果:

/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:179: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.reduce_range will be deprecated in a future release of PyTorch."
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:1126: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero pointReturning default scale and zero point "

使用 PyTorch 前端将量化的 Mobilenet v2 转换为 Relay-QNN

PyTorch 前端支持将量化的 PyTorch 模型,转换为具有量化感知算子的等效 Relay 模块。将此表示称为 Relay QNN dialect。

若要查看量化模型是如何表示的,可以从前端打印输出。

可以看到特定于量化的算子,例如 qnn.quantize、qnn.dequantize、qnn.requantize 和 qnn.conv2d 等。

input_name = "input"  # 对于 PyTorch 前端,输入名称可以是任意的。
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod) # 打印查看 QNN IR 转储

编译并运行 Relay 模块

获得量化的 Relay 模块后,剩下的工作流程与运行浮点模型相同。详细信息请参阅其他教程。

在底层,量化特定的算子在编译之前,会被降级为一系列标准 Relay 算子。

target = "llvm"
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target=target)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead."target_host parameter is going to be deprecated. "

比较输出标签

可看到打印出相同的标签。

pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])

输出结果:

PyTorch top3 labels: ['tiger cat', 'Egyptian cat', 'tabby, tabby cat']
TVM top3 labels: ['tiger cat', 'Egyptian cat', 'tabby, tabby cat']

但由于数字的差异,通常原始浮点输出不应该是相同的。下面打印 mobilenet v2 的 1000 个输出中,有多少个浮点输出值是相同的。

print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0]))

输出结果:

154 in 1000 raw floating outputs identical.

测试性能

以下举例说明如何测试 TVM 编译模型的性能。

n_repeat = 100  # 为使测试更准确,应选取更大的数值
dev = tvm.cpu(0)
print(rt_mod.benchmark(dev, number=1, repeat=n_repeat))

输出结果:

Execution time summary:mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)90.3752      90.2667      94.6845      90.0629       0.6087

备注

推荐这种方法的原因如下:

  • 测试是在 C++ 中完成的,因此没有 Python 开销大
  • 包括几个准备工作
  • 可用相同的方法在远程设备(Android 等)上进行分析。

备注

如果硬件对 INT8 整数的指令没有特殊支持,量化模型与 FP32 模型速度相近。如果没有 INT8 整数的指令,TVM 会以 16 位进行量化卷积,即使模型本身是 8 位。

对于 x86,在具有 AVX512 指令集的 CPU 上可实现最佳性能。这种情况 TVM 对给定 target 使用最快的可用 8 位指令,包括对 VNNI 8 位点积指令(CascadeLake 或更新版本)的支持。

此外,以下一般技巧对 CPU 性能的提升同样适用:

  • 将环境变量 TVM_NUM_THREADS 设置为物理 core 的数量
  • 为硬件选择最佳 target,例如 “llvm -mcpu=skylake-avx512” 或 “llvm -mcpu=cascadelake”(未来会有更多支持 AVX512 的 CPU)

部署量化的 MXNet 模型

待更新

部署量化的 TFLite 模型

待更新

脚本总运行时长: (1 分 7.374 秒)

下载 Python 源代码:deploy_prequantized.py

下载 Jupyter Notebook:deploy_prequantized.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/44330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】常见指令收官权限理解

tar指令 上一篇博客已经介绍了zip/unzip指令&#xff0c;接下来我们来看一下另一个关于压缩和解压的指令&#xff1a;tar指令tar指令&#xff1a;打包/解包&#xff0c;不打开它&#xff0c;直接看内容 关于tar的指令有太多了&#xff1a; tar [-cxtzjvf] 文件与目录 ...…

C++运行时类型识别

目录 C运行时类型识别A.What&#xff08;什么是运行时类型识别RTTI&#xff09;B.Why&#xff08;为什么需要RTTI&#xff09;C.dynamic_cast运算符Why&#xff08;dynamic_cast运算符的作用&#xff09;How&#xff08;如何使用dynamic_cast运算符&#xff09; D.typeid运算符…

【Scrapy】 Scrapy 爬虫框架

准我快乐地重饰演某段美丽故事主人 饰演你旧年共寻梦的恋人 再去做没流着情泪的伊人 假装再有从前演过的戏份 重饰演某段美丽故事主人 饰演你旧年共寻梦的恋人 你纵是未明白仍夜深一人 穿起你那无言毛衣当跟你接近 &#x1f3b5; 陈慧娴《傻女》 Scrapy 是…

各地户外分散视频监控点位,如何实现远程集中实时监看?

公司业务涉及视频监控项目承包搭建&#xff0c;此前某个项目需求是为某林业公司提供视频监控解决方案&#xff0c;需要实现各地视频摄像头的集中实时监看&#xff0c;以防止国家储备林的盗砍、盗伐行为。 公司原计划采用运营商专线连接各个视频监控点位&#xff0c;实现远程视…

跟着李沐学AI:线性回归

引入 买房出价需要对房价进行预测。 假设1&#xff1a;影响房价的关键因素是卧室个数、卫生间个数和居住面积&#xff0c;记为x1、x2、x3。 假设2&#xff1a;成交价是关键因素的加权和 。权重和偏差的实际值在后面决定。 拓展至一般线性模型&#xff1a; 给定n维输入&…

MySQL 9.0 正式发行Innovation创新版已支持向量

从 MySQL 8.1 开始&#xff0c;官方启用了新的版本模型&#xff1a;MySQL 创新版 (Innovation) 和长期支持版 (LTS)。 根据介绍&#xff0c;两者的质量都已达到可用于生产环境级别。区别在于&#xff1a; 如果希望尝试最新的功能和改进&#xff0c;并喜欢与最新技术保持同步&am…

怎样在 C 语言中实现栈?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&#xff0c;看过的人都说好。 文章目…

动手学深度学习(Pytorch版)代码实践 -循环神经网络-55循环神经网络的从零开始实现和简洁实现

55循环神经网络的实现 1.从零开始实现 import math import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l import matplotlib.pyplot as plt import liliPytorch as lp# 读取H.G.Wells的时光机器数据集 batch_size, num_ste…

开发个人Ollama-Chat--7 服务部署

开发个人Ollama-Chat–7 服务部署 服务部署 go-ChatGPT项目涉及的中间件服务较多&#xff0c;以下部署文件目录&#xff1a; |-- chat-api | |-- etc | | -- config.yaml | -- logs |-- chat-rpc | |-- etc | | -- config.yaml | -- logs |-- docker-compos…

ElasticSearch第一天

学习目标&#xff1a; 能够理解ElasticSearch的作用能够安装ElasticSearch服务能够理解ElasticSearch的相关概念能够使用Postman发送Restful请求操作ElasticSearch能够理解分词器的作用能够使用ElasticSearch集成IK分词器能够完成es集群搭建 第一章 ElasticSearch简介 1.1 什么…

windows 中的 Nsight Systems 通过ssh 链接分析 Linux 中的cuda程序性能

1&#xff0c;Linux 环境 安装 ssh-server $ sudo apt install openssh-server 安装较新版本的 cuda sdk 下载cuda-samples github repo 编辑修改 ssh 配置&#xff1a; $ sudo vim /etc/ssh/sshd_config 删除相关注释&#xff0c;修改后如下&#xff1a; Port 22 Addres…

只会vue的前端开发工程师是不是不能活了?最近被一个flutter叼了

**Vue与Flutter&#xff1a;前端开发的新篇章** 在前端开发的世界里&#xff0c;Vue.js和Flutter无疑是两颗璀璨的明星。Vue以其轻量级、易上手的特点吸引了大量前端开发者的青睐&#xff0c;而Flutter则以其跨平台、高性能的优势迅速崛起。那么&#xff0c;对于只会Vue的前端…

【深度学习基础】环境搭建 linux系统下安装pytorch

目录 一、anaconda 安装二、创建pytorch1. 创建pytorch环境&#xff1a;2. 激活环境3. 下载安装pytorch包4. 检查是否安装成功 一、anaconda 安装 具体的安装说明可以参考我的另外一篇文章【环境搭建】Linux报错bash: conda: command not found… 二、创建pytorch 1. 创建py…

OceanBase:引领下一代分布式数据库技术的前沿

OceanBase的基本概念 定义和特点 OceanBase是一款由蚂蚁金服开发的分布式关系数据库系统&#xff0c;旨在提供高性能、高可用性和强一致性的数据库服务。它结合了关系数据库和分布式系统的优势&#xff0c;适用于大规模数据处理和高并发业务场景。其核心特点包括&#xff1a; …

【考研数学】25张宇强化36讲测评及强化阶段注意事项

张宇新版36讲创新真的很大&#x1f979; 引入了很多张宇老师认为对大家解题帮助很大的技巧和知识点&#xff0c;但是也有人认为是多余的。 张宇老师新版36讲第一讲就讲了整整8个小时&#xff01;&#x1f62d; 大家想想&#xff0c;自己有那个时间去吃透36讲吗&#xff1f;如果…

python调用阿里云汇率接口

整体请求流程 介绍&#xff1a; 本次解析通过阿里云云市场的云服务来实现程序中对货币汇率实时监控&#xff0c;首先需要准备选择一家可以提供汇率查询的商品。 https://market.aliyun.com/apimarket/detail/cmapi00065831#skuyuncode5983100001 步骤1: 选择商品 如图点击…

debian 12 Install

debian 前言 Debian是一个基于Linux内核的自由和开放源代码操作系统&#xff0c;由全球志愿者组成的Debian项目维护和开发。该项目始于1993年&#xff0c;由Ian Murdock发起&#xff0c;旨在创建一个完整的、基于Linux的自由软件操作系统。 debian download debian 百度网盘…

分布式应用系统设计:即时消息系统

即时消息(IM)系统&#xff0c;涉及&#xff1a;站内消息系统 组件如下&#xff1b; 客户端&#xff1a; WEB页面&#xff0c;IM桌面客户端。通过WebSocket 跟ChatService后端服务连接 Chat Service&#xff1a; 提供WebSocket接口&#xff0c;并保持跟“客户端”状态的维护。…

会声会影分割音频怎么不能用 会声会影分割音频方法 会声会影视频制作教程 会声会影下载免费中文版2023

将素材中的音频分割出来&#xff0c;对声音部分进行单独编辑&#xff0c;是剪辑过程中的常用操作。会声会影视频剪辑软件在分割音频后&#xff0c;还可以对声音素材进行混音编辑、音频调节、添加音频滤镜等操作。有关会声会影分割音频怎么不能用&#xff0c;会声会影分割音频方…

如何快速制作您的数据可视化大屏?

数据大屏可视化主要就是借助图形&#xff0c;利用生动、直观的形式展示出数据信息的具体数值&#xff0c;使得使用者短时间内更加直观的接受到大量信息。数据大屏以直观、高度视觉冲击力的方式向受众揭示数据背后隐藏的规律&#xff0c;传达数据价值。其以图形化的形式呈现数据…