【TVM 教程】编译 PyTorch 目标检测模型

本文介绍如何用 Relay VM 部署 PyTorch 目标检测模型。

首先应安装 PyTorch。此外,还应安装 TorchVision,并将其作为模型合集(model zoo)。

可通过 pip 快速安装:

pip install torch
pip install torchvision

或参考官网:https://pytorch.org/get-started/locally/

PyTorch 版本应该和 TorchVision 版本兼容。

目前 TVM 支持 PyTorch 1.7 和 1.4,其他版本可能不稳定。

import tvm
from tvm import relay
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdataimport numpy as np
import cv2# PyTorch 导入
import torch
import torchvision

从 TorchVision 加载预训练的 MaskRCNN 并进行跟踪

in_size = 300
input_shape = (1, 3, in_size, in_size)def do_trace(model, inp):model_trace = torch.jit.trace(model, inp)model_trace.eval()return model_tracedef dict_to_tuple(out_dict):if "masks" in out_dict.keys():return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]return out_dict["boxes"], out_dict["scores"], out_dict["labels"]class TraceWrapper(torch.nn.Module):def __init__(self, model):super().__init__()self.model = modeldef forward(self, inp):out = self.model(inp)return dict_to_tuple(out[0])model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))with torch.no_grad():out = model(inp)script_module = do_trace(model, inp)

输出结果:

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /workspace/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth0%|          | 0.00/170M [00:00<?, ?B/s]9%|9         | 15.3M/170M [00:00<00:01, 160MB/s]19%|#8        | 32.1M/170M [00:00<00:00, 170MB/s]29%|##9       | 49.7M/170M [00:00<00:00, 176MB/s]40%|####      | 68.8M/170M [00:00<00:00, 185MB/s]51%|#####     | 86.4M/170M [00:00<00:00, 175MB/s]61%|######1   | 104M/170M [00:00<00:00, 178MB/s]71%|#######1  | 121M/170M [00:00<00:00, 169MB/s]86%|########6 | 147M/170M [00:00<00:00, 199MB/s]
100%|##########| 170M/170M [00:00<00:00, 193MB/s]
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3878: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).for i in range(dim)
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').for g in grid_sizes
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).for g in grid_sizes
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:73: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').A = Ax4 // 4
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py:74: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').C = AxC // A
/usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:156: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py:158: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/transform.py:293: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).for s, s_orig in zip(new_size, original_size)
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/roi_heads.py:387: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)

下载测试图像并进行预处理

img_url = ("/img/docs/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, "test_street_small.jpg", module="data")img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)

将计算图导入 Relay

input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)

输出结果:

/workspace/python/tvm/relay/build_module.py:411: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)DeprecationWarning,

使用 Relay VM 编译

注意:目前仅支持 CPU target。对于 x86 target,因为 TorchVision RCNN 模型中存在大型密集算子,为取得最佳性能,强烈推荐使用 Intel MKL 和 Intel OpenMP 来构建 TVM。

# 在 x86 target上添加“-libs=mkl”以获得最佳性能。
# 对于支持 AVX512 的 x86 机器,完整 target 是
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):vm_exec = relay.vm.compile(mod, target=target, params=params)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead."target_host parameter is going to be deprecated. "

使用 Relay VM 进行推理

dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()

获取 score 大于 0.9 的 box

score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
valid_boxes = []
for i, score in enumerate(tvm_res[1].numpy().tolist()):if score > score_threshold:valid_boxes.append(boxes[i])else:breakprint("Get {} valid boxes".format(len(valid_boxes)))

输出结果:

Get 9 valid boxes

脚本总运行时长: (2 分 57.278 秒)

下载 Python 源代码:deploy_object_detection_pytorch.py

下载 Jupyter Notebook:deploy_object_detection_pytorch.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/24181.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot+Vue在线文档管理系统(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 系统角色对应功能 员工管理员 系统功能截图

《精通ChatGPT:从入门到大师的Prompt指南》第1章:认识ChatGPT

第1章&#xff1a;认识ChatGPT 1.1 ChatGPT是什么 ChatGPT&#xff0c;全称为Chat Generative Pre-trained Transformer&#xff0c;是由OpenAI开发的一种先进的自然语言处理模型。它利用了深度学习中的一种技术——Transformer架构&#xff0c;来生成类人文本。ChatGPT通过对…

贪心算法-数组跳跃游戏(mid)

目录 一、问题描述 二、解题思路 1.回溯法 2.贪心算法 三、代码实现 1.回溯法实现 2.贪心算法实现 四、刷题链接 一、问题描述 二、解题思路 1.回溯法 使用递归的方式&#xff0c;找到所有可能的走步方式&#xff0c;并记录递归深度&#xff08;也就是走步次数&#x…

玩转ChatGPT:最全学术论文提示词分享【上】

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 在当今数字时代&#xff0c;人工智能&#xff08;AI&#xff09;技术正迅速改变各行各业的运作方式。特别是&#xff0c;OpenAI的ChatGPT等语言模型以其强大的文本生成能力&#xff0c;…

使用Ant-Design-Vue实现动态表头与数据填充的实战指南

好的&#xff0c;我将为你写一篇关于如何使用Ant-Design-Vue动态生成表头并填充数据的文章。这篇文章将包括一个基本的介绍&#xff0c;详细的步骤和示例代码&#xff0c;以帮助你实现这一功能。 --- # 使用Ant-Design-Vue动态生成表头并填充数据 在现代前端开发中&#xff…

Android ViewPager和ViewPager2的区别

一、实现方式 ViewPager内部是通过继承ViewGroup来实现的&#xff0c;ViewPager2内部是通过RecyclerView来实现的&#xff08;效率更高&#xff09; 二、支持方向 ViewPager只能横向滑动&#xff0c;ViewPager2可以横向以及竖向滑动 三、采用的适配器 ViewPager有两个适配…

算法:70. 爬楼梯

70. 爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示…

【CTF MISC】XCTF GFSJ0170 János-the-Ripper Writeup(文件提取+ZIP压缩包+暴力破解)

Jnos-the-Ripper 暂无 解法 用 winhex 打开&#xff0c;提到了 flag.txt。 用 binwalk 扫描&#xff0c;找到一些 zip 压缩包。 binwalk misc100用 foremost 提取文件。 foremost misc100 -o 100flag.txt 在压缩包里。 但是压缩包需要解压密码。 用 Ziperello 暴力破解。 不…

mac安装nigix且配置 vue/springboot项目(本地/服务器)

一、mac安装Nigix 1. 查看是否存在 nginx 执行brew search nginx 命令查询要安装的软件是否存在 brew search nginx 2. 安装nginx brew install nginx 3. 查看版本 nginx -v 4. 查看信息 查看ngxin下载的位置以及nginx配置文件存放路径等信息 brew info nginx 下载的存…

Angular 由一个bug说起之六:字体预加载

浏览器在加载一个页面时&#xff0c;会解析网页中的html和css&#xff0c;并开始加载字体文件。字体文件可以通过css中的font-face规则指定&#xff0c;并使用url()函数指定字体文件的路径。 比如下面这样: css font-face {font-family: MyFont;src: url(path/to/font.woff2…

【路径规划】三维深度矩阵寻路算法

在三维空间中寻路相较于二维空间更为复杂&#xff0c;因为需要处理额外的维度。下面是一个示例&#xff0c;展示了如何使用深度优先搜索&#xff08;DFS&#xff09;算法在三维矩阵中寻找路径。 首先&#xff0c;我们需要定义三维矩阵&#xff0c;并编写一个递归的DFS函数来寻…

IDEA 中设置 jdk 的版本

本文介绍一下 IDEA 中设置 jdk 版本的步骤。 一共有三处需要配置。 第一处 File --> Project Structure Project 和 Modules 下都需要指定一下。 第二处 File --> Settings 第三处 运行时的配置

Linux基础2-基本指令4(cp,mv,cat,tac)

上篇文章我们说到了rmdir,rm,man,echo.重定向等知识。 Linux基础1-基本指令3-CSDN博客 本文继续梳理其他基础指令 1.本章重点 1.使用cp命令拷贝文件 2.使用mv命令移动文件 3.使用cat&#xff0c;tac查看小文本文件 2.cp命令 在linux中使用cp命令来拷贝粘贴文件 cp src(原文…

解决Nginx出现An error occurred问题

每个人遇到Nginx的An error occurred情况可能都不一样&#xff08;见图1&#xff09;&#xff0c;Nginx造成该错误的原因&#xff1a; 1. 我在配置域名解析成IP时&#xff0c;没有把所有解析配置都修改&#xff0c;见图2&#xff1a;解析 *.hanxiaozhang.xyz 配置的是新IP地…

【名词解释】Unity的Inputfield组件及其使用示例

Unity的InputField组件是一个UI元素&#xff0c;它允许用户在游戏或应用程序中输入文本。InputField通常用于创建表单、登录界面或任何需要用户输入文本的场景。它提供了多种功能&#xff0c;比如文本验证、占位符显示、输入限制等。 功能特点&#xff1a; 文本输入&#xff…

如何发布自己的npm插件包

随着JavaScript在前端和后端的广泛应用,npm(Node Package Manager)已成为JavaScript开发者不可或缺的工具之一。通过npm,开发者可以轻松共享和使用各种功能模块,极大地提高了开发效率。那么,如何将自己开发的功能模块发布为npm插件包,与全球的开发者共享呢?本文将进行全…

Nvidia Jetson/Orin/算能 +FPGA+AI大算力边缘计算盒子:潍柴雷沃智慧农业无人驾驶

潍柴雷沃智慧农业科技股份有限公司&#xff0c;是潍柴集团重要的战略业务单元&#xff0c;旗下收获机械、拖拉机等业务连续多年保持行业领先&#xff0c;是国内少数可以为现代农业提供全程机械化整体解决方案的品牌之一。潍柴集团完成对潍柴雷沃智慧农业战略重组后&#xff0c;…

JSON5:JSON的现代化扩展

JavaScript Object Notation (JSON)&#xff0c;被广泛应用于网络传输、配置文件等许多地方&#xff0c;因其简洁、易读的特性而备受欢迎。然而&#xff0c;JSON不是无可指摘的&#xff0c;限制太过严格&#xff0c;使得一些场景变得棘手。这时候&#xff0c;JSON的超集——JSO…

轻松解决问题!教你文件怎么解除只读模式!

在日常使用电脑时&#xff0c;我们有时会遇到文件或文件夹被设定为只读模式的情况&#xff0c;这可能会限制我们对文件的修改和编辑。然而&#xff0c;解除只读模式并获得文件的完全控制是一个相对简单的过程&#xff0c;只需要掌握一些基本的技巧和方法。在本文中&#xff0c;…

数据分析面试常问问题(二)(SQL、统计学、业务方面等)

一、数据分析之业务指标高频面试题 1.关于视频app&#xff08;比如爱奇艺&#xff09;首页推荐的推荐顺序&#xff0c;你会考虑哪些指标&#xff1f; &#xff08;1&#xff09;用户行为数据&#xff1a;浏览、点击、播放、搜索、收藏、点赞、转发、滑动、在某个位置的停留时…