onnx 注册自定义算子
- 第一步:手写一个算子,然后注册一下
 - 第二步:将算子放进模型定义
 - 第三步:利用 torch.onnx.export() 编写onnx 导出函数
 
一般我们自定义算子的时候,有以下流程
- 编写算子并注册
 - 将算子放进模型定义
 - 利用 torch.onnx.export() 编写 onnx 导出函数
 
第一步:手写一个算子,然后注册一下
(注册就是在正常的 forward 之前加一个 symbolic 函数)
 (如何注册理解 symbolic 参考上个博客)
class CustomOp(torch.autograd.Function):@staticmethoddef symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:return g.op("custom_domain::customOp2", x)@staticmethoddef forward(ctx, x: torch.Tensor) -> torch.Tensor:ctx.save_for_backward(x)x = x.clamp(min=0)return x / (1 + torch.exp(-x))
 
第二步:将算子放进模型定义
customOp = CustomOp.apply
class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = customOp(x)return x
 
第三步:利用 torch.onnx.export() 编写onnx 导出函数
def export_norm_onnx():input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)model   = Model()model.eval()file    = "customOp.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")
 
完整代码:
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicOperatorExportTypes = torch._C._onnx.OperatorExportTypesclass CustomOp(torch.autograd.Function):@staticmethoddef symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:return g.op("custom_domain::customOp2", x)@staticmethoddef forward(ctx, x: torch.Tensor) -> torch.Tensor:ctx.save_for_backward(x)x = x.clamp(min=0)return x / (1 + torch.exp(-x))customOp = CustomOp.apply
class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = customOp(x)return xdef export_norm_onnx():input   = torch.rand(1, 50).uniform_(-1, 1).reshape(1, 2, 5, 5)model   = Model()model.eval()file    = "customOp.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()