pytorch 转 onnx

ONNX 是目前模型部署中最重要的中间表示之一,在把 PyTorch 模型转换成 ONNX 模型时,使用的 torch 接口是 torch.onnx.export
这里记录了 pytorch 模型转 onnx 时的原理和注意事项,还包括部分 PyTorch 与 ONNX 的算子对应关系。

1 torch.onnx.export原理

1.1 导出计算图

TorchScript 是一种序列化和优化 PyTorch 模型的格式,在优化过程中,一个torch.nn.Module模型会被转换成 TorchScript 的 torch.jit.ScriptModule 模型。通常 TorchScript 也被当成一种中间表示来使用。

torch.onnx.export中需要的模型实际上是一个torch.jit.ScriptModule。而要把普通 PyTorch 模型转一个这样的 TorchScript 模型,有跟踪(trace)和记录(script)两种导出计算图的方法。

如果给torch.onnx.export传入了一个普通 PyTorch 模型(torch.nn.Module),那么这个模型会默认使用 trace 的方法导出:
t o r c h . n n . M o d u l e → t o r c h . o n n x . e x p o r t (默认使用 t o r c h . j i t . t r a c e ) o n n x 模型 \boxed{torch.nn.Module} \xrightarrow{torch.onnx.export(默认使用 torch.jit.trace)} \boxed{onnx模型} torch.nn.Moduletorch.onnx.export(默认使用torch.jit.trace onnx模型
t o r c h . n n . M o d u l e → t o r c h . j i t . s c r i p t s t o r c h . j i t . t r a c e t o r c h . j i t . S c r i p t M o d u l e → t o r c h . o n n x . e x p o r t o n n x 模型 \boxed{torch.nn.Module} \xrightarrow[torch.jit.scripts]{torch.jit.trace} \boxed{torch.jit.ScriptModule} \xrightarrow{torch.onnx.export} \boxed{onnx模型} torch.nn.Moduletorch.jit.trace torch.jit.scriptstorch.jit.ScriptModuletorch.onnx.export onnx模型

trace 方法只能通过实际运行一遍模型的方法导出模型的静态图,即无法识别出模型中的控制流(如循环);script 方法则能通过解析模型来正确记录所有的控制流

下面的代码段额可以用来对比 trace 和 script 两种方法获取 graph 的区别

    import torch class Model(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): for i in range(self.n): x = self.conv(x) return x models = [Model(2), Model(3)] model_names = ['model_2', 'model_3'] for model, model_name in zip(models, model_names): dummy_input = torch.rand(1, 3, 10, 10) dummy_output = model(dummy_input) model_trace = torch.jit.trace(model, dummy_input) model_script = torch.jit.script(model) # 跟踪法与直接 torch.onnx.export(model, ...)等价 torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output) # 记录法必须先调用 torch.jit.sciprt torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output) 

在这段代码里,定义了一个带循环的模型,模型通过参数n来控制输入张量被卷积的次数。之后,各创建了一个n=2和n=3的模型。把这两个模型分别用跟踪和记录的方法进行导出。

值得一提的是,由于这里的两个模型(model_trace, model_script)是 TorchScript 模型,export函数已经不需要再运行一遍模型了。(如果模型是用跟踪法得到的,那么在执行torch.jit.trace的时候就运行过一遍了;而用记录法导出时,模型不需要实际运行)参数中的dummy_input和dummy_output`仅仅是为了获取输入和输出张量的类型和形状。

trace 方法得到的 ONNX 模型结构,会把 for 循环展开,这样不同的 n,得到的 ONNX 模型 graph 是不一样的;而 scripts 方法得到的 ONNX 模型,用 Loop 节点来表示循环,这样对于不同的 n,得到的 ONNX 模型结构是一样的

实际上,推理引擎对静态图的支持更好,通常在模型部署时不需要显式地把 PyTorch 模型转成 TorchScript 模型,直接把 PyTorch 模型用 torch.onnx.export 借助 trace 方法导出即可。

1.2 torch.onnx.export 参数注解

这里主要记录对于模型部署比较重要的几个参数在模型部署中还如何设置,该函数的 API 文档:https://pytorch.org/docs/stable/onnx.html#functions

torch.onnx.exporttorch.onnx.__init__.py 文件中的定义如下:

    def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=True, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False): 

前三个必选参数分别为 torch 模型、模型输入(转 ONNX 的时候的 dummy input)、ONNX 模型的保存路径。

  • export_params
    模型中是否存储模型权重。IR 中一般包含两类信息,模型结构和模型权重,这两类信息可以在同一个文件里存储,也可以分文件存储。
    一般来说,如果转 onnx 是用来部署的,那么选择设置为 true,存放在同一个文件中;如果是用来在在不同框架间传递模型,则设为 false,分开存放

  • input_names, output_names
    设置输入输出张量的名称。如果不设置,会默认使用 tensor ID(数字) 作为 张量名称。ONNX 的张量名称一般都需要设置,因为大部分推理引擎在设置模型输入和获取输出数据的时候,都是以字典的形式进行访问处理,其中张量名称作为 key,数据作为 value

  • opset_version
    转换时参考哪个 ONNX 算子集版本,默认为 9

  • dynamic_axes
    指定 onnx 动态 shape 的动态维度。
    为了追求效率,ONNX 默认所有参与运算的张量都是静态的(张量的形状不发生改变)。但在实际应用中,我们又希望模型的输入张量是动态的,尤其是本来就没有形状限制的全卷积模型。因此,我们需要显式地指明输入输出张量的哪几个维度的大小是可变的。

      import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) return x model = Model() dummy_input = torch.rand(1, 3, 10, 10) model_names = ['model_static.onnx',  'model_dynamic_0.onnx',  'model_dynamic_23.onnx'] dynamic_axes_0 = { 'in' : {0'batch'}'out' : {0, 'batch'} } torch.onnx.export(model, dummy_input, model_names[0], input_names=['in'], output_names=['out']) torch.onnx.export(model, dummy_input, model_names[1], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0) 
    

    导出 2 个 ONNX 模型,分别为没有动态维度、第 0 维动态的模型。
    这里使用字典的方式来表示动态维度,因为 ONNX 要求每个动态维度都有一个名字,否则会有一堆 warning

2 torch 转 onnx 时候的额外操作

2.1 添加额外处理逻辑到 onnx 中

可以把一些后处理的逻辑放在模型里,来简化除运行模型之外的其他代码。torch.onnx.is_in_onnx_export() 可以达到这样的效果,这个函数只会在执行torch.onnx.export()的时候返回 true

    import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv(x) if torch.onnx.is_in_onnx_export(): x = torch.clip(x, 0, 1) return x 

这里,仅在模型导出 onnx 时把输出张量的数值限制在[0, 1]之间。使用 is_in_onnx_export确让我们方便地在代码中添加和模型部署相关的逻辑。但是,这些突兀的部署逻辑会降低代码整体的可读性。另外,is_in_onnx_export只能在每个需要添加部署逻辑的地方都“打补丁”,不方便进行统一的管理。

2.2 中断张量 trace

如果在 pytorch 的模型脚本中有一些比较离谱的操作,会把某些取决于输入的中间结果变成常量,从而使导出的 ONNX 模型和原来的模型不等价。
如下是一个 trace 中断的例子:

    class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): x = x * x[0].item() return x, torch.Tensor([i for i in x]) model = Model()       dummy_input = torch.rand(10) torch.onnx.export(model, dummy_input, 'a.onnx') 

在导出 ONNX 的时候,会有很多的 warning,并且提示转出来的 onnx 很可能不正确。
在这个模型里使用了.item()把 torch 中的张量转换成了普通的 Python 变量,还尝试遍历 torch 张量,并用一个列表新建一个 torch 张量。这些涉及张量与普通变量转换的逻辑都会导致最终的 ONNX 模型不太正确。
另一方面,也可以利用这个性质,在保证正确性的前提下令模型的中间结果变成常量。这个技巧常常用于模型的静态化上。

3 pytorch 对 onnx 算子的支持了解

在做 pytorch model 转换成 onnx model 的时候,PyTorch 一方面会用跟踪法执行前向推理,把遇到的算子整合成计算图;另一方面,PyTorch 还会把遇到的每个算子翻译成 ONNX 中定义的算子。 PyTorch 算子是向 ONNX 对齐的,这个过程中,可能会有这样的情况:

  • 该算子可以一对一地翻译成一个 ONNX 算子。
  • 该算子在 ONNX 中没有直接对应的算子,会翻译成一至多个 ONNX 算子。
  • 该算子没有定义翻译成 ONNX 的规则,报错。
3.1 onnx 算子文档

在onnx 官方算子文档中可以查看 onnx 算子的定义情况。
在算子文档中,第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是前面在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,可以知道当前算子集版本中该算子的定义规则。

3.2 pytorch 对 onnx 算子的映射

在 PyTorch 中,和 ONNX 有关的定义全部放在torch.onnx目录中。symbolic_opset{n}.py(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/592081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React使用动态标签名称

最近在一项目里(React antd)遇到一个需求,某项基础信息里有个图标配置(图标用的是antd的Icon组件),该项基础信息的图标信息修改后,存于后台数据库,后台数据库里存的是antd Icon组件…

用Redis实现实现全局唯一ID

全局唯一ID 如果使用数据库自增ID就存在一些问题: id的规律性太明显受表数据量的限制 全局ID生成器,是一种在分布式系统下用来生成全局唯一ID的工具,一般要满足下列特性: 唯一性高可用递增性安全性高性能 为了增加ID的安全性…

Django 后台与便签

1. 什么是后台管理 后台管理是网页管理员利用网页的后台程序管理和更新网站上网页的内容。各网站里网页内容更新就是通过网站管理员通过后台管理更新的。 2. 创建超级用户 1. python .\manage.py createsuperuser 2. 输入账号密码等信息 Username (leave blank to use syl…

在Android设备上设置和使用隧道代理HTTP

随着互联网的深入发展,网络信息的传递已经成为人们日常生活中不可或缺的一部分。对于我们中国人来说,由于某些特殊的原因,访问国外网站时常常会遇到限制。为了解决这个问题,使用代理服务器成为了许多人的选择。而在Android设备上设…

微服务智慧工地信息化解决方案(IOT云平台源码)

智慧工地是指应用智能技术和互联网手段对施工现场进行管理和监控的一种工地管理模式。它利用传感器、监控摄像头、人工智能、大数据等技术,实现对施工现场的实时监测、数据分析和智能决策,以提高工地的安全性、效率和质量。 智慧工地平台是一种智慧型、系…

opencv期末练习题(7)附带解析

打印图像各个点的颜色 import cv2 import numpy as np""" 分别获得左上角、右上角、左下角、右下角各自的颜色,并打印相关颜色的值 """ img cv2.imread(test.png)(x, y, z) img.shape print("当前图像的尺寸:", x, y, z…

Redis双写一致性

文章目录 Redis双写一致性1. 延迟双删(有脏数据风险)2. 异步通知(保证数据最终一致性)3. 分布式锁(数据的强一致,性能低) Redis双写一致性 当修改了数据库的数据也要同时更新缓存的数据&#xf…

Linux 系统拉取 Github项目

一、安装Git 在Linux上拉取GitHub项目可以使用Git命令。首先确保已经安装了Git。如果没有安装,可以通过包管理器(比如apt、yum)来进行安装。 sudo yum install git #查看安装版本 git -version二、关联GitHub 配置本地账户和邮箱 >>…

ThreeJS创建关键帧动画

之前有说过两种创建动画的形式,一个是很粗的方式,直接在requestAnimationFrame中修改模型的属性,因为threejs本身就会不断刷新画面,利用不断刷新的时候修改模型属性就实现了每次刷新后修改模型的一些属性,另一种方式是…

iOS实时查看App运行日志

目录 一、设备连接 二、使用克魔助手查看日志 三、过滤我们自己App的日志 📝 摘要: 本文介绍了如何在iOS iPhone设备上实时查看输出在console控制台的日志。通过克魔助手工具,我们可以连接手机并方便地筛选我们自己App的日志。 &#x1f4…

鸟瞰uml(下)

36.组件是系统中遵从一组接口且提供实现的一个物理部件,通常指开发和运行时类的物理实现 37.部件图用于对系统的静态实现视图建模,这种视图主要支持系统部件的配置管理,通常可以分为以下4种方式来完成: 对源代码进行建模&#x…

Rust 圣经 阅读 引用与借用

Rust 通过 借用(Borrowing) 在使用某个变量的指针或引用。 获取变量的引用,称之为 借用(borrowing) 。 引用与解引用 引用是为了解决在使用函数时,频繁地传递所有权。 引用只是获取了引用权,而…

魔改Stable Diffusion,开源创新“单目深度估计”模型

单目深度估计一直是计算机视觉领域的难点。仅凭一张 RGB 图像,想要还原出场景的三维结构,在几何结构上非常不确定,必须依赖复杂的场景理解能力。 即便使用更强大的深度学习模型来实现,也面临算力需求高、图像数据注释量大、泛化能力弱等缺点。 为了解决这些难题&a…

线性代数第一课+第二课总结

第一课 第一课是简单的行列式计算,主要就是要把左下角的数字全部转换为0,通过减去其他行的式子即可实现,最后把对角线的所有数字相乘,得到的结果是最后行列式的答案 第二课 例题1 硬算理论上其实也是可行的,但是使…

R语言——reshape2包、tidyr包、dplyr包(五)

目录 一、数据转换之reshape2包:melt与dcast函数 二、数据转换之tidyr包:gather与spread函数,separate与unite函数 三、据转换之dplyr包 四、参考 一、数据转换之reshape2包:melt与dcast函数 merge 函数 使用merge函数 x &l…

听GPT 讲Rust源代码--library/proc_macro

File: rust/library/proc_macro/src/bridge/rpc.rs 在Rust源代码中,rust/library/proc_macro/src/bridge/rpc.rs文件的作用是实现了Rust编程语言的编译过程中的远程过程调用(RPC)机制。 这个文件定义了与编译器的交互过程中使用的各种数据结构…

阿里云2核2G3M服务器能放几个网站?有限制吗?

阿里云2核2g3m服务器可以放几个网站?12个网站,阿里云服务器网的2核2G服务器上安装了12个网站,甚至还可以更多,具体放几个网站取决于网站的访客数量,像阿里云服务器网aliyunfuwuqi.com小编的网站日访问量都很少&#xf…

LeetCode 1758. 生成交替二进制字符串的最少操作数【字符串,模拟】1353

本文属于「征服LeetCode」系列文章之一,这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁,本系列将至少持续到刷完所有无锁题之日为止;由于LeetCode还在不断地创建新题,本系列的终止日期可能是永远。在这一系列刷题文章…

【数值分析】三次样条插值

三次样条插值 2023年11月5日 #analysis 文章目录 三次样条插值1. 样条函数1.1 截断多项式 2. 三次样条插值2.1 B样条为基底的三次样条插值函数2.1.1 第一种边界条件2.1.2 第二种边界条件2.1.3 第三种边界条件 2.2 三弯矩法求三次样条插值函数2.2.1 第一种边界条件2.2.2 第二种…

万界星空科技低代码平台基本模块与优势

低代码平台(Low-Code Development Platform,LCDP)就是使用低代码的方式进行开发,能快速设置和部署的平台。低代码平台旨在简化应用开发过程,降低开发难度,缩短开发周期,并使非专业程序员&#x…