从零开始 TensorRT(3)Python 篇:解析 ONNX、PyTorch TensorRT 接口

前言

学习资料:
TensorRT 源码示例
官方文档:Working With TensorRT Using The Python API
官方文档:TensorRT Python
官方文档:CUDA Python
B站视频教程
视频配套代码 cookbook

示例:解析 ONNX 模型

参考源码:cookbook → 04-BuildEngineByONNXParser → pyTorch-ONNX-TensorRT

源码

  cookbook 中自定义了一个网络并在 MNIST 数据集上进行训练,然后保存为 ONNX 并用 TensorRT 读取生成引擎并进行推理。这里对代码进行了简化,直接用 Pytorch 提供的训练好的 ResNet18,存为 ONNX 然后推理。
  测试数据使用了 TensorRT 中自带的数据,./TensorRT-8.6.1.6/data/resnet50 下有4张图片以及标签 class_labels.txt。将图像复制到 ./data/images 下,标签复制到 ./data 下即可。

import os
import numpy as np
from PIL import Imageimport torch
import torchvision.models as models
import torchvision.transforms as transformsimport tensorrt as trt
from cuda import cudartbUseFP16Mode = False
bUseINT8Mode = Trueh, w = 224, 224
dataPath = 'data/images'
imgFiles = [os.path.join(dataPath, f) for f in os.listdir(dataPath)]
labelFile = 'data/class_labels.txt'
onnxFile = 'resnet18.onnx'
if bUseFP16Mode:trtFile = 'fp16.plan'
elif bUseINT8Mode:trtFile = 'int8.plan'
else:trtFile = 'fp32.plan'batch_size = len(imgFiles)
with open(labelFile, 'r') as f:label = np.array(f.readlines())# 准备数据
transform = transforms.Compose([transforms.Resize((h, w)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])def load_images(image_paths):images = [Image.open(image_path) for image_path in image_paths]tensors = [transform(image).unsqueeze(0) for image in images]res = torch.cat(tensors, dim=0)return resinput_tensor = load_images(imgFiles).cuda()weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights, progress=False).eval().cuda()result = model(input_tensor)print(f'PyTorch results: {label[result.argmax(dim=1).cpu()]}')# 导出 ONNX
torch.onnx.export(model,torch.randn(1, 3, h, w, device='cuda'),onnxFile,input_names=['x'],output_names=['y'],do_constant_folding=True,verbose=True,keep_initializers_as_inputs=True,opset_version=12,dynamic_axes={'x': {0: 'nBatchSize'}, 'y': {0: 'nBatchSize'}}
)# INT8 模式校准器
class MyCalibrator(trt.IInt8EntropyCalibrator2):def __init__(self, data_path, n_calibration, input_shape, cache_file):trt.IInt8EntropyCalibrator2.__init__(self)self.imageList = [os.path.join(data_path, f) for f in os.listdir(data_path)]self.nCalibration = n_calibrationself.shape = input_shapeself.bufferSize = trt.volume(input_shape) * trt.float32.itemsizeself.cacheFile = cache_file_, self.dIn = cudart.cudaMalloc(self.bufferSize)self.oneBatch = self.batch_generator()def __del__(self):cudart.cudaFree(self.dIn)def batch_generator(self):for i in range(self.nCalibration):print("> calibration %d" % i)sub_images = np.random.choice(self.imageList, self.shape[0], replace=False)yield np.ascontiguousarray(load_images(sub_images).numpy())def get_batch_size(self):  # necessary APIreturn self.shape[0]def get_batch(self, nameList=None, inputNodeName=None):  # necessary APItry:data = next(self.oneBatch)cudart.cudaMemcpy(self.dIn, data.ctypes.data, self.bufferSize, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)return [int(self.dIn)]except StopIteration:return Nonedef read_calibration_cache(self):  # necessary APIif os.path.exists(self.cacheFile):print("Succeed finding cache file: %s" % self.cacheFile)with open(self.cacheFile, "rb") as f:cache = f.read()return cacheelse:print("Failed finding int8 cache!")returndef write_calibration_cache(self, cache):  # necessary APIwith open(self.cacheFile, "wb") as f:f.write(cache)print("Succeed saving int8 cache!")return# 构建期
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)
if bUseFP16Mode:config.set_flag(trt.BuilderFlag.FP16)
if bUseINT8Mode:config.set_flag(trt.BuilderFlag.INT8)config.int8_calibrator = MyCalibrator(dataPath, 1, (4, 3, h, w), 'int8.cache')# 加载 ONNX
parser = trt.OnnxParser(network, logger)
with open(onnxFile, "rb") as model:if not parser.parse(model.read()):print("Failed parsing .onnx file!")for error in range(parser.num_errors):print(parser.get_error(error))exit()print("Succeeded parsing .onnx file!")inputTensor = network.get_input(0)
profile.set_shape(inputTensor.name, [1, 3, h, w], [4, 3, h, w], [8, 3, h, w])
config.add_optimization_profile(profile)
# 生成序列化网络
engineString = builder.build_serialized_network(network, config)
with open(trtFile, "wb") as f:f.write(engineString)# 运行期
engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
nIO = engine.num_io_tensors
lTensorName = [engine.get_tensor_name(i) for i in range(nIO)]
nInput = [engine.get_tensor_mode(lTensorName[i]) for i in range(nIO)].count(trt.TensorIOMode.INPUT)context = engine.create_execution_context()
context.set_input_shape(lTensorName[0], [batch_size, 3, h, w])inputHost = np.ascontiguousarray(input_tensor.cpu().numpy())
outputHost = np.empty(context.get_tensor_shape(lTensorName[1]),dtype=trt.nptype(engine.get_tensor_dtype(lTensorName[1]))
)_, inputDevice = cudart.cudaMalloc(inputHost.nbytes)
_, outputDevice = cudart.cudaMalloc(outputHost.nbytes)
context.set_tensor_address(lTensorName[0], inputDevice)
context.set_tensor_address(lTensorName[1], outputDevice)cudart.cudaMemcpy(inputDevice, inputHost.ctypes.data, inputHost.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)
context.execute_async_v3(0)
cudart.cudaMemcpy(outputHost.ctypes.data, outputDevice, outputHost.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)print(f'TensorRT results: {label[outputHost.argmax(axis=1)]}')cudart.cudaFree(inputDevice)
cudart.cudaFree(outputDevice)

代码解析

  代码整体流程就是加载 PyTorch 提供的预训练模型 ResNet-18 并在几张 ImageNet 数据上进行推理,然后把模型保存为 ONNX,最后用 TensorRT 中的 Parser 加载模型进行推理。
  PyTorch 模型的存储路径为 /home/xxx/.cache/torch/hub/checkpoints
  比较陌生的部分在于导出 ONNX 所使用的 API torch.onnx.export(),以及使用 INT8 模式时需要额外编写一个校准器。

(1)导出 ONNX 模型:官方文档
部分参数解释:
do_constant_folding=True 是否进行常量折叠
verbose=True 是否打印详细信息
keep_initializers_as_inputs=True 是否将模型参数作为输入,个人理解是当设为 True 时,模型的参数是不固定的,是输入的一部分,这样可以加载不同参数的模型。而设为 False 时,模型的参数被固定了,但会更有利于优化加速。
dynamic_axes 设定动态轴

(2)INT8 模式下的校准器
config.int8_calibrator = MyCalibrator(dataPath, 1, (4, 3, h, w), 'int8.cache')

  INT8 模式需要确定每个权重张量的量化范围,以便在量化时保持模型的精度。此处算是作弊了,用推理的数据作为校准数据,并且校准时也把 4 张图像作为一个 batch 输入,从而得到的结果相同。若设为 config.int8_calibrator = MyCalibrator(dataPath, 5, (1, 3, h, w), 'int8.cache') 会发现结果有所不同。

  input_shape 对应了输入形状和 batch size。校准器的作用就是从输入的数据集中随机挑选 batch size 个样本,循环校准 n_calibration 次。

踩坑

  最初仅使用 Numpy 做数据预处理:

image = Image.open(imgFile).resize((224, 224))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image = (np.array(image, dtype=np.float32) / 255 - mean) / std
image = np.expand_dims(image.transpose((2, 0, 1)), axis=0)
input_tensor = torch.from_numpy(image).cuda()

  报错:RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
  原因是第 4 行做标准化计算时,会把数据类型自动转变为 float64。若使用 torch.FloatTensor(image) 可解决报错,但是在 TensorRT 推理时使用 inputHost = np.ascontiguousarray(image) 会导致推理结果不同但没有报错,必须使用一样的 float32 数据。这里因为在 Pytorch 上推理了,后续用 input_tensor.cpu().numpy() 很安全,但实际使用会跳过这一步直接在 TensorRT 上推理,要小心数据类型问题。

示例:PyTorch 框架内 TensorRT 接口

参考源码:cookbook → 06-UseFrameworkTRT → Torch-TensorRT
PyTorch 官方示例

源码

  省略加载数据部分,与上个示例相同。测试时发现启用 TorchScript 会让推理速度加快很多。

import torch_tensorrtTorchScript = True
if TorchScript:model = torch.jit.trace(model, torch.randn(batch_size, 3, h, w, device="cuda"))optimized_model = torch_tensorrt.compile(model,inputs=[torch.randn((batch_size, 3, h, w)).float().cuda()],enabled_precisions=torch.float,debug=True,
)optimized_result = optimized_model(input_tensor)
print(f'Torch TensorRT results: {label[optimized_result.argmax(dim=1).cpu()]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/667622.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

get通过发送Body传参-工具类

1、调用方式 String url "http://ip/xxx/zh/xxxxx/xxxx/userCode"; //进行url中的对应的参数 url2 url2.replace("ip",bancirili); url2 url2.replace("zh",zh); url2 url2.replace("userCode",userCode);String dateTime xxxx; //组…

深度学习系列55:深度学习加速技术概述

总体有两个方向&#xff1a;模型优化 / 框架优化 1. 模型优化 1.1 量化 最常见的量化方法为线性量化&#xff0c;权重从float32量化为int8&#xff0c;将输入数据映射在[-128,127]的范围内。在 nvdia gpu&#xff0c;x86、arm 和 部分 AI 芯片平台上&#xff0c;均支持 8bit…

Python使用回调函数或async/await关键字、协程实现异步编程

异步编程是一种编程模式,它允许程序在执行某个任务时,能够同时执行其他任务而不需要等待当前任务完成。在传统的同步编程中,程序执行一个任务后必须等待该任务完成后才能继续执行下一个任务。而在异步编程中,程序可以发起一个任务后立即执行其他任务,当原先的任务完成后,…

LiveData 迁移到 Kotlin Flow详解

LiveData ,是Android 2017推出的一个东西,配合MVVM使用。观察者模式,的确简化了我们的工作方式,但 RxJava 等选项,对于当时的初学者来说实在是太复杂了。因此 Architecture Components 团队创建了 LiveData :这是个非常 “有主见的” 可观察数据持有者类,并且是专门为 A…

Go 函数 可变参数

有的时候&#xff0c;我们定义函数&#xff0c;都是形参固定&#xff0c;有时候我们并不知道参数有多少个&#xff0c;这里我们就可以用... 来使用可变参数 代码如下&#xff1a; package main import "fmt"//定义一个函数&#xff0c;函数的参数为&#xff1a;可变…

AI时代,人才的“重新定义”

参加一个讨论&#xff0c;梳理一下自己所理解到一些内容&#xff0c;有不对的请指正&#xff1a; 我们要好好借助AI工具成为一个π型人才&#xff0c;π型人才 是指至少拥有两种专业技能&#xff0c;并能将多门知识融会贯通的高级复合型人才。 π下面的两竖指两种专业技能&…

全自动网页生成系统重构版源码

全自动网页生成系统重构版源码分享&#xff0c;所有模板经过精心审核与修改&#xff0c;完美兼容小屏手机大屏手机&#xff0c;以及各种平板端、电脑端和360浏览器、谷歌浏览器、火狐浏览器等等各大浏览器显示。 为用户使用方便考虑&#xff0c;全自动网页制作系统无需繁琐的注…

PMP资料怎么学?PMP备考经验分享

PMP考试前大家大多都是提前备考个一两个月&#xff0c;但是有些朋友喜欢“不走寻常路”&#xff0c;并不打算去考PMP认证&#xff0c;想要单纯了解PMP&#xff0c;不管要不要考证&#xff0c;即使是仅仅学习了解一下我个人都非常支持&#xff0c;因为专业的基础的确能提高工作效…

基恩士 KV-8000 PLC通讯简单测试

1、KV-8000通讯协议 基恩士 KV-8000 PLC支持多种通讯方式&#xff0c;包括&#xff1a;OPC UA、Modbus、上位链路命令等。其中OPC UA需要对服务器和全局变量进行设置&#xff0c;Modbus需要调用功能块。默认支持的是上位链路命令&#xff0c;实际是一条条以回车换行结束的ASCII…

基于微信小程序的医保行政执法案件管理系统

本系统设计的是一个医保行政执法的网站&#xff0c;此网站使用户实现了不需出门就可以在手机或电脑前进行网上查询需求信息等。 用户在注册登陆后&#xff0c;在客户端可以实现&#xff1b;案件信息、结案归档、我的等。然而管理员则可以在服务端直接管理&#xff1b;个人中心、…

【已解决】Oracle 12541 TNS 无监听程序

目录 1、找到Oracle监听服务&#xff08;OracleOraDb10g_homeTNLListener&#xff09;&#xff0c;停止运行 2、首先查看监听文件是否超过4G 3、修改配置文件 连接oracle突然报错&#xff0c;提示Oracle 12541 TNS 无监听程序&#xff0c;可以按照以下步骤解决 1、找到Ora…

Redis-布隆过滤器解决穿透详解

本文已收录于专栏 《中间件合集》 目录 背景介绍概念说明原理说明解决穿透安装使用安装过程Redis为普通安装的配置方式Redis为Docker镜像安装的配置方式 具体使用控制台操作命令说明Spring Boot集成布隆过滤器 总结提升 背景介绍 布隆过滤器可以帮助我们解决Redis缓存雪崩的问题…

Fink CDC数据同步(四)Mysql数据同步到Kafka

依赖项 将下列依赖包放在flink/lib flink-sql-connector-kafka-1.16.2 创建映射表 创建MySQL映射表 CREATE TABLE if not exists mysql_user (id int,name STRING,birth STRING,gender STRING,PRIMARY KEY (id) NOT ENFORCED ) WITH (connector mysql-cdc,hostn…

算法学习打卡day47|单调栈系列题目

单调栈题目思路 通常是一维数组&#xff0c;要寻找任一个元素的右边或者左边第一个比自己大或者小的元素的位置&#xff0c;此时我们就要想到可以用单调栈了。时间复杂度为O(n)。单调栈的本质是空间换时间&#xff0c;因为在遍历的过程中需要用一个栈来记录右边第一个比当前元…

链式二叉树(三种遍历)

1.链式二叉树的遍历&#xff1a;前序&#xff08;根&#xff0c;左子树&#xff0c;右子树&#xff09;中序&#xff08;左子树&#xff0c;根&#xff0c;右子树&#xff09;后序&#xff08;左子树&#xff0c;右子树&#xff0c;根&#xff09;层序&#xff08;一层一层访问…

探索深度学习的边界:使用 TensorFlow 实现高效空洞卷积(Atrous Convolution)的全面指南

空洞卷积&#xff08;Atrous Convolution&#xff09;&#xff0c;在 TensorFlow 中通过 tf.nn.atrous_conv2d 函数实现&#xff0c;是一种强大的工具&#xff0c;用于增强卷积神经网络的功能&#xff0c;特别是在处理图像和视觉识别任务时。这种方法的核心在于它允许网络以更高…

Flask 入门6:模板继承

一个网站中&#xff0c;大部分网页的模块是重复的&#xff0c;比如顶部的导航栏&#xff0c;底部的备案信息。如果在每个页面中都重复的去写这些代码&#xff0c;会让项目变得臃肿&#xff0c;提高后期的维护成本。比较好的做法是&#xff0c;通过模板继承&#xff0c;把一些重…

电脑文件误删除怎么办?8个恢复软件解决电脑磁盘数据可能的误删

您是否刚刚发现您的电脑磁盘数据丢失了&#xff1f;不要绝望&#xff01;无论分区是否损坏、意外格式化或配置错误&#xff0c;存储在其上的文件都不一定会丢失到数字深渊。 我们已经卷起袖子&#xff0c;深入研究电脑分区恢复软件的广阔领域&#xff0c;为您带来一系列最有效…

合并排序算法

合并排序依赖于合并操作&#xff0c;即将两个已经排序的序列合并成一个序列&#xff0c;具体的过程如下&#xff1a; 1申请空间&#xff0c;使其大小为两个已经排序序列之和&#xff0c;然后将待排序数组复制到该数组中。 2设定两个指针&#xff0c;最初位置分别为两个已经排…

如何标准化地快速编辑文档

介绍个公文类的文档技巧吧&#xff0c;尤其在国企、机关、有ISO管理体系内控要求的会议记录、公文写作等&#xff0c;要求大同小异&#xff0c;一般都是中规中矩的【GB/T 9704—2012】&#xff0c;其实国标本身就是经过长期检验&#xff0c;证明是最规范合理&#xff0c;阅读效…