使用tensorrt加速深度学习模型推断

使用tensorrt加速深度学习模型推断

  • 1.import以及数据加载、构建engine函数
  • 2.导入官方模型及CIFAR100数据集
  • 3.不采用tensort的推断时间
  • 4.采用tensort加速—使用tensorrt 库
    • 4.1 导出onnx模型
    • 4.2 生成tensorrt engine 文件
    • 4.3 deserialize
    • 4.4 推断
  • 5.采用tensort加速—使用torch2trt库
  • 参考文献

此博客介绍如何将resnet101模型在CIFAR100数据集的分类任务,使用tensorrt部署。

完整代码如下

1.import以及数据加载、构建engine函数

import argparse
import osimport torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as modelsimport timeimport numpy as np
import tensorrt as trt
import common
import torchvision.transforms as transformsTRT_LOGGER = trt.Logger()
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # 指定0号GPU可用# mean and std of cifar100 dataset
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401,0.2564384629170883, 0.27615047132568404)def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)cifar100_test_loader = DataLoader(cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)return cifar100_test_loaderdef ONNX_build_engine(onnx_file_path, trt_file):G_LOGGER = trt.Logger(trt.Logger.WARNING)explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)batch_size = 64  with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, \trt.OnnxParser(network, G_LOGGER) as parser:builder.max_batch_size = batch_sizeconfig = builder.create_builder_config()config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, common.GiB(1))config.set_flag(trt.BuilderFlag.FP16)print('Loading ONNX file from path {}...'.format(onnx_file_path))with open(onnx_file_path, 'rb') as model:print('Beginning ONNX file parsing')parser.parse(model.read())print('Completed parsing of ONNX file')print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))profile = builder.create_optimization_profile()profile.set_shape("input", (1, 3, 32, 32),(1, 3, 32, 32), (batch_size, 3, 32, 32))config.add_optimization_profile(profile)engine = builder.build_serialized_network(network, config)print("Completed creating Engine")with open(trt_file, "wb") as f:f.write(engine)return engine

2.导入官方模型及CIFAR100数据集


if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('-gpu', action='store_true',default=True, help='use gpu or not')parser.add_argument('-b', type=int, default=32,help='batch size for dataloader')args = parser.parse_args()print(args)cifar100_test_loader = get_test_dataloader(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD,num_workers=1,batch_size=args.b)device = "cuda" if args.gpu else "cpu"net = models.resnet101(pretrained=True)net = net.to(device)# # print(net)net.eval()

3.不采用tensort的推断时间

#%%t1 = time.time()for n_iter, (image, label) in enumerate(cifar100_test_loader):pred = net(image.to(device))# print(pred.shape)t2 = time.time()print(t2-t1)

耗时约为8~9s。

4.采用tensort加速—使用tensorrt 库

4.1 导出onnx模型

#%% save onnx input = torch.rand([1, 3, 32, 32]).to(device)onnx_file = "resnet101.onnx"if  os.path.exists(onnx_file):os.remove(onnx_file)torch.onnx.export(net, input, onnx_file,input_names=['input'],  # the model's input namesoutput_names=['output'],dynamic_axes={'input': {0: 'batch_size'},'output': {0: 'batch_size'}},# opset_version=12,)print("onnx file generated!")

4.2 生成tensorrt engine 文件

# %%generate tensorrt engine filetrt_file = "resnet101.trt"ONNX_build_engine(onnx_file, trt_file)print("trt file generated!")

4.3 deserialize

    trt_file = "resnet101.trt"runtime = trt.Runtime(TRT_LOGGER)with open(trt_file, 'rb') as f:engine = runtime.deserialize_cuda_engine(f.read())print("Completed creating Engine")context = engine.create_execution_context()context.set_binding_shape(0, (16, 3, 32, 32))inputs, outputs, bindings, stream = common.allocate_buffers(engine, 32)

4.4 推断

    t1 = time.time()label_ls = []pred_ls = []for n_iter, (image, label) in enumerate(cifar100_test_loader):# print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader)))# print(image)inputs[0].host = image.numpy()trt_outputs = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=32)label_ls.extend(label.numpy())pred_ls.extend(np.array(trt_outputs[0]).reshape([-1, 100]).argmax(1).tolist())# print((np.array(pred_ls)[:10000]==np.array(label_ls)[:10000]).sum())t2 = time.time()print(t2-t1)

耗时约为4.3s,是用我的笔记本 上的GPU RTX 3050可以实现两倍左右的加速。

5.采用tensort加速—使用torch2trt库

nvidia还有torch2trt Python包,可用于一键tensorrt加速。

其安装可参考https://github.com/NVIDIA-AI-IOT/torch2trt.

git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
python setup.py install

torch2trt的使用可参考github torch2trt:

    from torch2trt import torch2trtinputs = torch.rand([1, 3, 32, 32]).to(device)model_trt = torch2trt(net, [inputs], fp16_mode=True)t1 = time.time()label_ls = []pred_ls = []for n_iter, (image, label) in enumerate(cifar100_test_loader):output_trt = model_trt(image.to(device))t2 = time.time()print(t2-t1)

使用起来不要太easy!

完整代码可参考https://github.com/L0-zhang/tentorrt_demo/tree/main

参考文献

[1] csdn pytorch TensorRT 官方例子
[2] https://github.com/NVIDIA-AI-IOT/torch2trt
[3] https://github.com/L0-zhang/tentorrt_demo/tree/main

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/193389.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql(八)docker版Mysql8.x设置大小写忽略

Mysql 5.7设置大小写忽略可以登录到Docker内部&#xff0c;修改/etc/my.cnf添加lower_case_table_names1&#xff0c;并重启docker使之忽略大小写。但MySQL8.0后不允许这样&#xff0c;官方文档记录&#xff1a; lower_case_table_names can only be configured when initializ…

SpringBoot集成knife4j

1&#xff09;添加knife4j的依赖 <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-spring-boot-starter</artifactId><version>3.0.2</version> </dependency>SpringBoot的版本&#xff1a; <pa…

机器人与3D视觉 Robotics Toolbox Python 一 安装 Robotics Toolbox Python

一 安装python 库 前置条件需要 Python > 3.6&#xff0c;使用pip 安装 pip install roboticstoolbox-python测试安装是否成功 import roboticstoolbox as rtb print(rtb.__version__)输出结果 二 Robotics Toolbox Python样例程序 加载机器人模型 加载由URDF文件定义…

【python】python列表的用法记录

文章目录 序言1. 列表的创建2. 列表的访问3. 列表的更新4. 列表的删除5. 列表的元素查找6. 列表的脚本操作符7. 列表的函数/方法8. 列表的一些其他操作 序言 总结字典的常见用法&#xff0c;以备查阅 1. 列表的创建 列表是python中最常用的数据类型&#xff0c;其数据项不需要…

【算法每日一练]-图论(保姆级教程篇12 tarjan篇)#POJ3352道路建设 #POJ2553图的底部 #POJ1236校园网络 #缩点

目录 POJ3352&#xff1a;道路建设 思路&#xff1a; POJ2553&#xff1a;图的底部 思路&#xff1a; POJ1236校园网络 思路&#xff1a; 缩点&#xff1a; 思路&#xff1a; POJ3352&#xff1a;道路建设 由于道路要维修&#xff0c;维修时候来回都不能走&#xff0c;现要…

MDK提示:在多字节的目标代码中,没有此Unicode 字符可以映射到的字符

MDK警告提示在多字节的目标代码中&#xff0c;没有此Unicode 字符可以映射到的字符 警告提示&#xff1a; 在写MDK的工程代码时&#xff0c;发现代码中引入的头文件前方出现一些红色的叉叉&#xff0c;但是编译工程并不报错&#xff0c;功能也能正常执行的&#xff0c;只是提…

JS利用时间戳倒计时案例

我们在逛某宝&#xff0c;或者逛某东时&#xff0c;我们时常看到一个倒计时&#xff0c;时间一到就开抢&#xff0c;这个倒计时是如何做的呢&#xff1f;让我为大家介绍一下。 理性分析一下&#xff1a; 1.用将来时间减去现在时间就是剩余的时间 2.核心&#xff1a;使用将来的时…

工业机器视觉megauging(向光有光)使用说明书(十五,轻量级的visionpro)

程序&#xff08;软件&#xff09;的一些不足和建议&#xff1a;&#xff08;后续会跟进&#xff09; 不足&#xff1a;&#xff08;如果你发现了&#xff0c;谢谢及时提出来&#xff09; 1&#xff0c;找线工具有噪点抑制功能&#xff1b;blob跟随工具&#xff0c;匹配跟随工…

C指针介绍(1)

文章目录 每日一言指针的简单介绍内存和地址指针在内存中的存储指针的定义和声明泛型指针 指针的关系运算算数运算关系运算 结语 每日一言 ⭐「 一声梧叶一声秋&#xff0c;一点芭蕉一点愁&#xff0c;三更归梦三更后。 」–水仙子夜雨-徐再思 指针的简单介绍 C语言指针是C语…

人工智能轨道交通行业周刊-第67期(2023.11.27-12.3)

本期关键词&#xff1a;列车巡检机器人、城轨智慧管控、制动梁、断路器、AICC大会、Qwen-72B 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro…

python调用打印机并打印

【01获取打印机列表】 要获取Python中的打印机列表&#xff0c;可以使用win32print模块&#xff08;适用于Windows系统&#xff09;或cups模块&#xff08;适用于Linux和macOS系统&#xff09;。 以下是使用这两个模块分别获取打印机列表的示例代码&#xff1a; **在Windows…

算法工程师面试八股(搜广推方向)

文章目录 机器学习线性和逻辑回归模型逻辑回归二分类和多分类的损失函数二分类为什么用交叉熵损失而不用MSE损失&#xff1f;偏差与方差Layer Normalization 和 Batch NormalizationSVM数据不均衡特征选择排序模型树模型进行特征工程的原因GBDTLR和GBDTRF和GBDTXGBoost二阶泰勒…

React使报错不再白屏

如果代码中出现问题导致报错&#xff0c;通常会使页面报错&#xff0c;导致白屏 function Head() {// 此时模拟报错导致的白屏return <div>Head --- {content}</div> } export default () > {return (<><div>下面是标题</div><Head />…

若依框架分页

文章目录 一、分页功能解析1.前端代码分析2.后端代码分析3. LIMIT含义 二、自定义MyPage,多态获取total1.定义MyPage类和对应的调用方法 一、分页功能解析 1.前端代码分析 页面代码 封装的api请求 接口请求 2.后端代码分析 controller代码 - startPage() getDataTable(…

yolo.txt格式与voc格式互转,超详细易上手

众所周知,yolo训练所需的标签文件类型是.txt的,但我们平时使用标注软件(labelimage等)标注得到的标签文件是.xml类型的,故此xml2txt之间的转换就至关重要了,这点大家不可能想不到,但是网上的文章提供的代码大多数都是冗余,或者难看,难以上手,故此作者打算提供一个相对…

String StringBuilder

String 特点&#xff1a; 1.双引号的字符串是String类的对象&#xff0c;可以点方法 2.字符串一旦创建&#xff0c;其内容不可改变 3.字符串常量池可以共享 方法 equals()&#xff1a;比较字符串的内容toCharArray()&#xff1a;转为字符数组字符串.charAt(int index)&#x…

Sharding-Jdbc(3):Sharding-Jdbc分表

1 分表分库 LogicTable 数据分片的逻辑表&#xff0c;对于水平拆分的数据库(表)&#xff0c;同一类表的总称。 订单信息表拆分为2张表,分别是t_order_0、t_order_1&#xff0c;他们的逻辑表名为t_order。 ActualTable 在分片的数据库中真实存在的物理表。即上个示例中的t_…

怎样使用rtsp,rtmp摄像头低延时参于Web视频会议互动直播

业务系统中有大量的rtsp&#xff0c;rtmp等监控直播设备&#xff0c;原大部分都是单一业务监控直播之类&#xff0c;目前很多业务需要会议互动&#xff0c;需要监控参会&#xff0c;提出需摄像头拉流参会的需求&#xff0c;由于rtmp&#xff0c;rtsp原生不支持web播放&#xff…

UVa1583生成元(Digit Generator)

题目 如果x加上x的各个数字之和得到y&#xff0c;也就是说x是y的生成元。给出n(1<n<100000)&#xff0c;求最小生成元。无解则输出0。 输入输出样例 输入 3 216 121 2005输出 198 0 1979思路 要想解决这个题目&#xff0c;只需要对每一个输入的值从1开始遍历找到小于…

vue3-在自定义hooks使用useRouter 报错问题

文章目录 前言一、报错分析报错的Vue warn截图&#xff1a;查看文档 二、那么在hook要怎么引入路由呢&#xff1f; 前言 记录在vue3项目中&#xff0c;hook使用useRouter 报错问题 一、报错分析 报错的Vue warn截图&#xff1a; 警告 inject() can only be used inside setup…