一、定义
- 如何保证pytorch 模型顺利转为onnx. 前言
- pytorch 算子是如何与onnx 算子对齐的?
- Asinh 算子出现于第 9 个 ONNX 算子集。PyTorch 在 9 号版本的符号表文件中是怎样支持这个算子的?
- BitShift 算子出现于第11个 ONNX 算子集。PyTorch 在 11 号版本的符号表文件中是怎样支持这个算子的?
- 算子在pytorch 中已经实现,onnx 算子也实现,缺少映射方法,自己注册,实现转换。
- 自定义onnx 算子。
- 构造onnx 模型,并测试。
- onnx提取子模型
二、实现
-
如何保证pytorch 模型顺利转为onnx. 前言, 参考:https://zhuanlan.zhihu.com/p/513387413
要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:
算子在 PyTorch 中有实现
有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
ONNX 有相应的算子
可在实际部署中,这三部分的内容都可能有所缺失。其中最坏的情况是:我们定义了一个全新的算子,它不仅缺少 PyTorch 实现,还缺少 PyTorch 到 ONNX 的映射关系。但所谓车到山前必有路,对于这三个环节,我们也分别都有以下的添加支持的方法:
PyTorch 算子
组合现有算子
添加 TorchScript 算子
添加普通 C++ 拓展算子
映射方法
为 ATen 算子添加符号函数
为 TorchScript 算子添加符号函数
封装成 torch.autograd.Function 并添加符号函数
ONNX 算子
使用现有 ONNX 算子
定义新 ONNX 算子 -
pytorch 算子是如何与onnx 算子对齐的?
onnx 算子文档:https://github.com/onnx/onnx/blob/main/docs/Operators.md
torch 对onnx算子映射:https://github.com/pytorch/pytorch/tree/main/torch/onnx
表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。
symbolic_opset{n}.py(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。判定是否存在映射方法。 -
Asinh 算子出现于第 9 个 ONNX 算子集。PyTorch 在 9 号版本的符号表文件中是怎样支持这个算子的?
Asinh 在第9版本onnx 中实现,检查symbolic_opset9.py 发现,但pytorch 中已经实现torch.asinh(), 即缺少映射方法。 -
BitShift 算子出现于第11个 ONNX 算子集。PyTorch 在 11 号版本的符号表文件中是怎样支持这个算子的?
通过在 torch.onnx.symbolic_opset11.py 搜索 BitShift,我们可以发现 PyTorch 在 _lshift 和 _rshift 里用到了ONNX的 BitShift 算子。当输入类型为 Byte 时,PyTorch会把算子直接翻译翻译
BitShift,以代替乘除 2 的次幂的操作。 -
算子在pytorch 中已经实现,onnx 算子也实现,缺少映射方法,自己注册,实现转换。
1. 获取 ATen 中算子接口定义
2. 添加符号函数 -
整合模型,导出onnx文件
-
测试算子
================================================================
import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):return torch.asinh(x)from torch.onnx.symbolic_registry import register_opdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)register_op('asinh', asinh_symbolic, '', 9)model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'asinh.onnx')
测试
import onnxruntime
import torch
import numpy as np class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.asinh(x) model = Model()
input = torch.rand(1, 3, 10, 10)
torch_output = model(input).detach().numpy() sess = onnxruntime.InferenceSession('asinh.onnx')
ort_output = sess.run(None, {'0': input.numpy()})[0] assert np.allclose(torch_output, ort_output)
- 自定义onnx 算子。
https://zhuanlan.zhihu.com/p/513387413 - 构造onnx 模型,并测试。
import onnx
from onnx import helper
from onnx import TensorProto# input and output
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])# Mul
mul = helper.make_node('Mul', ['a', 'x'], ['c'])# Add
add = helper.make_node('Add', ['c', 'b'], ['output'])# graph and model
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
model = helper.make_model(graph)# save model
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')
import onnxruntime
import numpy as np sess = onnxruntime.InferenceSession('linear_func.onnx')
a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(10, 10).astype(np.float32)
x = np.random.rand(10, 10).astype(np.float32) output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] assert np.allclose(output, a * x + b)
- onnx提取子模型
https://zhuanlan.zhihu.com/p/516920606
https://zhuanlan.zhihu.com/p/543973749
https://zhuanlan.zhihu.com/p/516920606