【Pytorch】(十五)模型部署:ONNX和ONNX Runtime

文章目录

  • (十五)模型部署:ONNX和ONNX Runtime
    • ONNX 和 ONNX Runtime的关系
    • 将PyTorch模型导出为ONNX格式
    • 使用Netron可视化ONNX模型图
    • 检查ONNX模型
    • 验证ONNX Runtime推理结果
    • 使用ONNX Runtime运行超分模型

(十五)模型部署:ONNX和ONNX Runtime

ONNX 和 ONNX Runtime的关系

ONNX(模型表示格式):Open Neural Network Exchange(ONNX)一种用于表示深度学习模型的标准格式。这个格式允许将模型从一个深度学习框架转移到另一个框架,以及在不同平台上进行推理。

ONNX Runtime(推理引擎):ONNX Runtime(ORT) 是一个用于运行和执行 ONNX 模型的推理引擎。ONNX Runtime 提供了高性能、低延迟的深度网络模型推理,并且是跨平台的,支持各种操作系统和设备。ONNX Runtime已被证明可以显著提高多个模型的推理性能。

想用ONNX和ONNX Runtime进行Pytorch模型部署,首先需要安装以下Python包:

pip install --upgrade onnx onnxscript onnxruntime

将PyTorch模型导出为ONNX格式

Pytorch中torch.onnx模块提供API来从PyTorch的 torch.nn.Module模块捕获计算图,并将其转换为ONNX格式。从PyTorch 2.1开始,ONNX Exporter有两个版本。

torch.onnx.dynamo_export是基于PyTorch 2.0发布的TorchDynamo技术的最新Exporter(仍处于测试版)

torch.onnx.export则基于TorchScript,自PyTorch 1.2.0以来一直可用

本文只介绍torch.onnx.export,关于torch.onnx.dynamo_export,可阅读:
https://pytorch.org/tutorials/beginner/onnx/intro_onnx.html
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

下面将以一个图像超分模型为例,介绍如何使用基于TorchScript 的torch.onnx.export将PyTorch中定义的模型转换为ONNX格式。

import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init# 1 搭建一个超分模型。
class SuperResolutionNet(nn.Module):def __init__(self, upscale_factor, inplace=False):super(SuperResolutionNet, self).__init__()self.relu = nn.ReLU(inplace=inplace)self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))self.pixel_shuffle = nn.PixelShuffle(upscale_factor)self._initialize_weights()def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))return xdef _initialize_weights(self):init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))init.orthogonal_(self.conv4.weight)# 创建一个模型实例
torch_model = SuperResolutionNet(upscale_factor=3)# 2 训练模型或者直接导入预训练的模型参数,这里采用后者:
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'map_location = lambda storage, loc: storage
if torch.cuda.is_available():map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# 3 将模型转换为推理模式。
# 这是必需的,因为像dropout或batchnorm这样的运算符在推理和训练模式中表现不同。
# set the model to inference mode
torch_model.eval()# 4 导出ONNX模型batch_size = 1    # just a random number
## 首先需要提供一个输入张量x。只要它是正确的类型和大小,其中的值就可以是随机的。
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)## 导出模型
torch.onnx.export(torch_model,               # 模型x,                         # 模型输入"super_resolution.onnx",   # onnx文件保存路径export_params=True,        # 将经过训练的参数权重存储在模型文件中opset_version=10,          # ONNX的版本do_constant_folding=True,  # 执行常量折叠(constant folding)进行优化input_names = ['input'],   # 模型输入的名字output_names = ['output'], #  模型输出的名字dynamic_axes={'input' : {0 : 'batch_size'}, # 将第一个维度指定为dynamic'output' : {0 : 'batch_size'}})
# 计算原始Pytorch模型的输出,用于验证导出的ONNX 模型是否能计算出相同的值。                     
torch_out = torch_model(x)  # 计算原始Pytorch模型的输出

请注意,除非在dynamic_axes指定,否则ONNX模型中输入和输出的尺寸大小都是固定的。在本例中,在torch.onnx.export()中的dynamic_axies参数中将第一个维度指定为dynamic。这使得导出的模型接受大小为 [batch_size, 1, 224, 224]的输入,其中batch_size是可变的。

使用Netron可视化ONNX模型图

Netron可以对ONNX模型图进行可视化。Netron除了可以安装在macos、Linux或Windows系统的计算机上,还可以在浏览器上运行:https://netron.app/

打开Netron后,我们可以将.onnx文件拖放到浏览器中,也可以在单击“打开模型”从文件目录选择它,进行可视化:

检查ONNX模型

在使用ONNX Runtime进行推理之前,我们先使用ONNX API检查ONNX模型。

import onnx
# 加载onnx模型
onnx_model = onnx.load("super_resolution.onnx")
# 验证ONNX模型的有效性,包括通过检查模型的版本、图的结构和节点及其输入和输出
onnx.checker.check_model(onnx_model)

验证ONNX Runtime推理结果

现在,让我们使用ONNX Runtime的Python API来进行推理。

这一部分通常是在另一个进程中或在另一台机器上完成。为了验证ONNX Runtime和PyTorch原始网络模型计算的值是否近似,我们在一个进程进行。

import onnxruntime
# 创建一个推理会话
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 使用ONNX Runtime进行推理
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)# ONNX Runtime和PyTorch原始网络模型输出的近似程度没有达到指定精度(rtol=1e-03和atol=1e-05),将抛出异常。
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")

使用ONNX Runtime运行超分模型


import numpy as np
import onnxruntime
from PIL import Image
import torchvision.transforms as transformsdef to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 创建一个推理会话
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])# 加载图像与预处理
img = Image.open("cat.jpg")resize = transforms.Resize([224, 224])
img = resize(img)img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)# 在ONNX Runtime中运行超分辨率模型
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]# 从输出张量构造最终输出图像,并保存
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')final_img = Image.merge("YCbCr", [img_out_y,img_cb.resize(img_out_y.size, Image.BICUBIC),img_cr.resize(img_out_y.size, Image.BICUBIC),]).convert("RGB")  # Cr, Cb通道通过插值发大final_img.save("cat_superres_with_ort.jpg")

参考:
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3543.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十五届蓝桥杯省赛第二场C/C++B组D题【前缀总分】题解(AC)

暴力解法 O ( 26 n 5 ) O(26n^5) O(26n5) 枚举将第 i i i 个字符串的第 j j j 个字符改为 c c c 的所有方案,时间复杂度 O ( 26 n 2 ) O(26n^2) O(26n2),修改并计算总分, O ( n 3 ) O(n^3) O(n3)。 暴力优化 O ( 26 n 3 log ⁡ n ) O…

基于Python实现心脏病数据可视化DEA+预测【500010103.1】

一、数据说明 该心脏病数据集是通过组合 5 个已经独立可用但以前未合并的流行心脏病数据集来策划的。在这个数据集中,5 个心脏数据集结合了 11 个共同特征,使其成为迄今为止可用于研究目的的最大心脏病数据集。 该数据集由 1190 个实例和 11 个特征组成…

《第二行代码》第二版学习笔记(6)——内容提供器

文章目录 一 运行时权限2.权限分类3 运行时申请权限 二、内容提供器1、 ContentResolver的基本用法2、现有的内容提供器3、创建自己的内容提供器2.1 创建内容提供器的步骤2.2 跨程序数据共享 内容提供器(Content Provider)主要用于在不同的应用程序之间实…

Python | Leetcode Python题解之第52题N皇后II

题目: 题解: class Solution:def totalNQueens(self, n: int) -> int:def backtrack(row: int) -> int:if row n:return 1else:count 0for i in range(n):if i in columns or row - i in diagonal1 or row i in diagonal2:continuecolumns.add…

【VsCode】使用VsCode学习VUE+TS必备插件

目录标题 《Auto Close Tag》💕《Auto Rename Tag》💕《Path Intellisense》《Open in Browser》💕《IntelliCode》《Vue-Official》💕《Prettier - Code formatter》💕《ESLint》💕 《Auto Close Tag》&am…

Docker镜像的创建 和 Dockerfile

一. Docker 镜像的创建 创建镜像有三种方法,分别为基于已有镜像创建、基于本地模板创建以及基于 Dockerfile 创建。 1 基于现有镜像创建 (1)首先启动一个镜像,在容器里做修改docker run -it --name web3 centos:7 /bin/bash …

maven修改默认编码格式为UTF-8

执行mvn -version查看maven版本信息发现,maven使用的编码格式为GBK。 为什么想到要修改编码格式呢?因为idea中我将文件格式统一设置为UTF-8(如果不知道如何修改文件编码,可以参考文末),然后使用maven打包时…

docker 和 docker-compose的区别

Docker 和 Docker Compose 是两个相关但具有不同功能的工具,它们在容器化应用的生命周期管理中扮演不同的角色: Docker: Docker 是一个开源的应用容器引擎,它允许开发者打包应用及其依赖包到一个可移植的容器中,这样…

ESP32与SD卡交互实现:文件读写实战与初始化详解及引脚定义

本代码实现ESP32与SD卡的交互,包括定义SPI引脚、创建自定义SPI类实例、编写WriteFile与ReadFile函数进行文件读写。setup函数初始化串口、SPI、SD卡,向“/test.txt”写入“myfirstmessage”,读取并打印其内容。loop函数留空待扩展。 1. 需要…

Lock-It for Mac(应用程序加密工具)

OSXBytes Lock-It for Mac是一款功能强大的应用程序加密工具,专为Mac用户设计。该软件具有多种功能,旨在保护用户的隐私和数据安全。 Lock-It for Mac v1.3.0激活版下载 首先,Lock-It for Mac能够完全隐藏应用程序,使其不易被他人…

PCV库之调用SIFT.py中process_image()执行错误的解决方案

背景介绍: windows10,python3.x,64位AMD R9处理器,ROG电脑,pycharm开发,已经安装好了PCV库 如何安装PCV库 请看我的博客Python之PCV库安装教程以及解说-CSDN博客文章浏览阅读111次,点赞5次,收藏3次。GitHub - Ultravioletrayss/PCVfile: 文档内含有pyt…

Nginx 配置 SSL(HTTPS)详解

Nginx作为一款高性能的HTTP和反向代理服务器,自然支持SSL/TLS加密通信。本文将详细介绍如何在Nginx中配置SSL,实现HTTPS的访问。 随着互联网安全性的日益重要,HTTPS协议逐渐成为网站加密通信的标配。Nginx作为一款高性能的HTTP和反向代理服务…

OpenCV 实现霍夫圆变换

返回:OpenCV系列文章目录(持续更新中......) 上一篇:OpenCV实现霍夫变换 下一篇:OpenCV 实现重新映射 目标 在本教程中,您将学习如何: 使用 OpenCV 函数 HoughCircles()检测图像中的圆圈。 理论 Hough 圆变换 H…

Cgicc搭建交叉编译环境(移植到arm)

Cgicc GUN Project官网连接:Cgicc- GNU Project - Free Software Foundation 1. 下载源码 Cgicc下载地址: [via http] Index of /gnu/cgicc [via FTP] ftp://ftp.gnu.org/gnu/cgicc/ 目前最新版:3.2.20 2. 源码构建原理 一般&#xff…

在VSCode中调试其他软件执行的python文件

在VSCode中调试其他软件执行的python文件 0. 实际场景 我有一段python代码想在Metashape中运行,但是又想在中间某一步停下来查看变量值。由于Metashape的python环境不容易在vscode中配置,所以直接用vscode调试单个文件的方式无法实现这个想法。还好&am…

42. UE5 RPG 实现火球术伤害

上一篇,我们解决了火球术于物体碰撞的问题,现在火球术能够正确的和攻击目标产生碰撞。接下来,我们要实现火球术的伤害功能,在火球术击中目标后,给目标造成伤害。 实现伤害功能的思路是给技能一个GameplayEffect&#x…

3DTiles生产流程与规范

一篇19年整理的比较老的笔记了。更多精彩内容尽在数字孪生平台。 瓦片切分 标准的四叉树切分对于均匀分布的地理数据切片非常有效,但是这样均等的切分不适用于随机分布、不均匀分布的地理数据,当地理数据稀疏分布的时候,均等的四叉树就不再高…

跟着Datawhale重学数据结构与算法(3)---排序算法

开源链接:【 教程地址 】【电子网站】 【写博客的目的是记录自己学习过程,方便自己复盘,专业课复习】 数组排序: #mermaid-svg-F3iLcKsVv8gcmqqC {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16p…

PHP项目搭建与启动

1、拉取项目 2、安装phpstudy 下载地址: Windows版phpstudy下载 - 小皮面板(phpstudy) (xp.cn) 软件安装: Apache2.4.39、Nginx1.15.11、MySQL8.0.12、 composer2.5.8 添加伪静态 将下面代码写入到伪静态配置文本域框内: location ~* (ru…

【Qt 学习笔记】Qt常用控件 | 输入类控件 | Text Edit的使用及说明

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ Qt常用控件 | 输入类控件 | Text Edit的使用及说明 文章编号&#xff…