使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomAffineGrid(Function):@staticmethoddef forward(ctx, theta: torch.Tensor, size: torch.Tensor):grid = F.affine_grid(theta=theta, size=size.cpu().tolist())return grid@staticmethoddef symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):return g.op("AffineGrid", theta, size)class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):grid = CustomAffineGrid.apply(theta, size)x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")return xdef main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)theta = torch.randn(1, 2, 3)size = torch.as_tensor([1, 3, 512, 512])torch.onnx.export(model=custum_model,args=(x, theta, size),f="custom.onnx",input_names=["input0_x", "input1_theta", "input2_size"],output_names=["output"],dynamic_axes={"input0_x": {2: "h0", 3: "w0"},"output": {2: "h1", 3: "w1"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomRot90AndScale(Function):@staticmethoddef forward(ctx, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90x *= 1.2return x@staticmethoddef symbolic(g: torch.Graph, x: torch.Tensor):return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor):return CustomRot90AndScale.apply(x)def main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)torch.onnx.export(model=custum_model,args=(x,),f="custom_rot90.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {2: "h0", 3: "w0"},"output": {2: "w0", 3: "h0"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习中的线性代数

基础知识的的复习: 线性代数——深度学习花书第二章 - 知乎 矩阵分解 特征值分解。 PCA(Principal Component Analysis)分解,作用:降维、压缩。 SVD(Singular Value Decomposition)分解,也叫奇异值分解。 矩阵分解的主要应用是:降维、聚类分析、数据预处理、低维度特征学…

Keepalive 解决nginx 的高可用问题

一 说明 keepalived利用 VRRP Script 技术,可以调用外部的辅助脚本进行资源监控,并根据监控的结果实现优先动态调整,从而实现其它应用的高可用性功能 参考配置文件: /usr/share/doc/keepalived/keepalived.conf.vrrp.localche…

三八妇女节智慧花店/自动售花机远程视频智能监控解决方案

一、项目背景 国家统计局发布的2023年中国经济年报显示,全年社会消费品零售总额471495亿元,比上年增长7.2%。我国无人零售整体发展迅速,2014年市场规模约为17亿元。无人零售自助终端设备市场规模超过500亿元,年均复合增长率超50%。…

正则表达式在QT开发中的应用

一.正则表达式在QT开发中的使用: 1.模式匹配与验证:正则表达式最基本的作用就是进行模式匹配,它可以用来查找、识别或验证一个字符串是否符合某个特定的模式。例如,在表单验证中,可以使用正则表达式来检查用户输入的邮…

Agent——记忆模块

在基于大模型的 Agent架构设计方面,论文[1]提出了一个统一的框架,包括Profile模块、Memory模块、Planning模块和Action模块。其中长期记忆的状态维护至关重要,在 OpenAI AI 应用研究主管 Lilian Weng 的博客《基于大模型的 Agent 构成》[2]中,将记忆视为关键的组件之一,下…

Prompt Engineering、Finetune、RAG:OpenAI LLM 应用最佳实践

一、背景 本文介绍了 2023 年 11 月 OpenAI DevDay 中的一个演讲,演讲者为 John Allard 和 Colin Jarvis。演讲中,作者对 LLM 应用落地过程中遇到的问题和相关改进方案进行了总结。虽然其中用到的都是已知的技术,但是进行了很好的总结和串联…

羊大师分析羊奶滋养,女性魅力绽放

羊大师分析羊奶滋养,女性魅力绽放 羊奶,自古以来便是滋养身心的天然佳品。它富含多种营养成分,如蛋白质、脂肪、矿物质和维生素等,能够为女性提供全面而均衡的营养支持,帮助她们保持健康与活力。 女性是社会的半边天&…

单片机入门:LED数码管

LED数码管 LED数码管:由多个发光二极管封装在一起组成的“8”字型的器件。如下图所示: 数码管引脚定义 一位数码管 内部由八个LED组成。器件有十个引脚。 对于数码管内的8个LED有共阴和共阳两种连接方法。 共阴:将8个LED的阴极都连接到一…

Java项目:41 springboot大学生入学审核系统的设计与实现010

作者主页:源码空间codegym 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 本大学生入学审核系统管理员和学生。 管理员功能有个人中心,学生管理,学籍信息管理,入学办理管理等。 学…

wpf prism左侧抽屉式菜单

1.首先引入包MaterialDesignColors和MaterialDesignThemes 2.主页面布局 左侧菜单显示在窗体外&#xff0c;点击左上角菜单图标通过简单的动画呈现出来 3.左侧窗体外菜单 <Grid x:Name"GridMenu" Width"150" HorizontalAlignment"Left" Ma…

鸿蒙原生应用元服务开发-WebGL网页图形库开发概述

WebGL的全称为Web Graphic Library(网页图形库)&#xff0c;主要用于交互式渲染2D图形和3D图形。目前HarmonyOS中使用的WebGL是基于OpenGL裁剪的OpenGL ES&#xff0c;可以在HTML5的canvas元素对象中使用&#xff0c;无需使用插件&#xff0c;支持跨平台。WebGL程序是由JavaScr…

解读:DUSt3R: Geometric 3D Vision Made Easy

概述&#xff1a;给定一个无约束图像集&#xff0c;即一组具有未知相机姿态和内在特征的照片&#xff0c;我们提出的 DUSt3R 方法会输出一组相应的点阵图&#xff0c;从中我们可以直接恢复通常难以一次性估算的各种几何量&#xff0c;如相机参数、像素对应关系、深度图和完全一…

【PCIe】 PCIe 拓扑结构与分层结构

&#x1f525;博客主页&#xff1a;PannLZ 文章目录 PCIe拓扑结构PCIe分层结构 PCIe拓扑结构 计算机网络中的拓扑结构源于拓扑学(研究与大小、形状无关的点、线关系的方法)。 把网络中的计算机和通信设备抽象为一个点&#xff0c;把传输介质抽象为一条线&#xff0c;由点和线组…

【物联网】stm32芯片结构组成,固件库、启动过程、时钟系统、GPIO、NVIC、DMA、UART以及看门狗电路的全面详解

一、stm32的介绍 1、概述 stm32: ST&#xff1a;指意法半导体 M&#xff1a;指定微处理器 32&#xff1a;表示计算机处理器位数 与ARM关系:采用ARM推出cortex-A&#xff0c;R,M三系中的M系列&#xff0c;其架构主要基于ARMv7-M实现 ARM分成三个系列&#xff1a; Cortex-A&…

【排序算法】推排序算法解析:从原理到实现

目录 1. 引言 2. 推排序算法原理 3. 推排序的时间复杂度分析 4. 推排序的应用场景 5. 推排序的优缺点分析 5.1 优点&#xff1a; 5.2 缺点&#xff1a; 6. Java、JavaScript 和 Python 实现推排序算法 6.1 Java 实现&#xff1a; 6.2 JavaScript 实现&#xff1a; 6.…

K8S之实现业务的金丝雀发布

如何实现金丝雀发布 金丝雀发布简介优缺点在k8s中实现金丝雀发布 金丝雀发布简介 金丝雀发布的由来&#xff1a;17 世纪&#xff0c;英国矿井工人发现&#xff0c;金丝雀对瓦斯这种气体十分敏感。空气中哪怕有极其微量的瓦斯&#xff0c;金丝雀也会停止歌唱&#xff1b;当瓦斯…

【JS逆向学习】猿人学 第五题 js混淆 乱码

逆向目标 网址&#xff1a;https://match.yuanrenxue.cn/match/5接口&#xff1a;https://match.yuanrenxue.cn/api/match/5?page2&m1709806560791&f1709806560000参数&#xff1a; Cookie(m、RM4hZBv0dDon443M)payload(m、f) 逆向过程 老规矩&#xff0c;上来先分…

第24集《灵峰宗论导读》

请大家打开讲义第79面。 在第一段呢蕅益大师先讲到这个诸法实相的妙理&#xff0c;说从我们现前一念心性来观察诸法实相有两个角度&#xff1a; 第一个角度呢就是当我们摄用归体的时候&#xff0c;所谓万法唯识一心的时候&#xff0c;这个时候我们会发觉三世诸佛&#xff0c;…

vue实现文字手工动态打出效果

vue实现文字手工动态打出效果 问题背景 本文实现vue中&#xff0c;动态生成文字手动打出效果。 问题分析 话不多说&#xff0c;直接上代码&#xff1a; <template><main><button click"makeText"><p class"text">点击生成内容…

启动vue项目执行npm run serve报错 : error in ./src/element-variables.scss

error in ./src/element-variables.scss 问题原因 node-sass的版本问题 解决方式 我直接更新了一下node-sass&#xff0c;就好了 npm install node-sass 再次执行就可以执行成功了