onnx 注册自定义算子
- 第一步:手写一个算子,然后注册一下
- 第二步:将算子放进模型定义
- 第三步:利用 torch.onnx.export() 编写onnx 导出函数
一般我们自定义算子的时候,有以下流程
- 编写算子并注册
- 将算子放进模型定义
- 利用 torch.onnx.export() 编写 onnx 导出函数
第一步:手写一个算子,然后注册一下
(注册就是在正常的 forward 之前加一个 symbolic
函数)
(如何注册理解 symbolic
参考上个博客)
class CustomOp(torch.autograd.Function):@staticmethoddef symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:return g.op("custom_domain::customOp2", x)@staticmethoddef forward(ctx, x: torch.Tensor) -> torch.Tensor:ctx.save_for_backward(x)x = x.clamp(min=0)return x / (1 + torch.exp(-x))
第二步:将算子放进模型定义
customOp = CustomOp.apply
class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = customOp(x)return x
第三步:利用 torch.onnx.export() 编写onnx 导出函数
def export_norm_onnx():input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)model = Model()model.eval()file = "customOp.onnx"torch.onnx.export(model = model, args = (input,),f = file,input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished normal onnx export")
完整代码:
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicOperatorExportTypes = torch._C._onnx.OperatorExportTypesclass CustomOp(torch.autograd.Function):@staticmethoddef symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:return g.op("custom_domain::customOp2", x)@staticmethoddef forward(ctx, x: torch.Tensor) -> torch.Tensor:ctx.save_for_backward(x)x = x.clamp(min=0)return x / (1 + torch.exp(-x))customOp = CustomOp.apply
class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = customOp(x)return xdef export_norm_onnx():input = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)model = Model()model.eval()file = "customOp.onnx"torch.onnx.export(model = model, args = (input,),f = file,input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()