模型部署 - onnx 的导出和分析 -(1) - PyTorch 导出 ONNX - 学习记录

onnx 的导出和分析

  • 一、PyTorch 导出 ONNX 的方法
    • 1.1、一个简单的例子 -- 将线性模型转成 onnx
    • 1.2、导出多个输出头的模型
    • 1.3、导出含有动态维度的模型
  • 二、pytorch 导出 onnx 不成功的时候如何解决
    • 2.1、修改 opset 的版本
    • 2.2、替换 pytorch 中的算子组合
    • 2.3、在 pytorch 登记( 注册 ) onnx 中某些算子
      • 2.3.1、注册方法一
      • 2.3.2、注册方法二
    • 2.4、直接修改 onnx,创建 plugin

一、PyTorch 导出 ONNX 的方法

1.1、一个简单的例子 – 将线性模型转成 onnx

首先我们用 pytorch 定义一个线性模型,nn.Linear : 线性层执行的操作是 y = x * W^T + b,其中 x 是输入,W 是权重,b 是偏置。(实际上就是一个矩阵乘法)

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return x

然后我们再定义一个函数,用于导出 onnx

def export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")

可以看到,这里面的关键在函数 torch.onnx.export(),这是 pytorch 导出 onnx 的基本方式,这个函数的参数有很多,但只要一些基本的参数即可导出模型,下面是一些基本参数的定义:

  • model (torch.nn.Module): 需要导出的PyTorch模型
  • args (tuple or Tensor): 一个元组,其中包含传递给模型的输入张量
  • f (str): 要保存导出模型的文件路径。
  • input_names (list of str): 输入节点的名字的列表
  • output_names (list of str): 输出节点的名字的列表
  • opset_version (int): 用于导出模型的 ONNX 操作集版本

最后我们完整的运行一下代码:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

导出模型后,我们用 netron 查看模型,在终端输入

netron model.onnx

在这里插入图片描述

1.2、导出多个输出头的模型

第一步:定义一个多输出的模型:

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2

第二步:编写导出 onnx 的函数

def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")

可以看到,和例 1.1 不一样的地方是 torch.onnx.export 的 output_names
例1.1:output_names = [“output0”]
例1.2:output_names = [“output0”, “output1”]

运行一下完整代码:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

用 netron 查看模型,结果如下,模型多出了一个输出结果
在这里插入图片描述

1.3、导出含有动态维度的模型

完整运行代码如下:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],dynamic_axes  = {'input0':  {0: 'batch'},'output0': {0: 'batch'}},opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

可以看到,比例 1.1 多了一行 torch.onnx.export 的 dynamic_axes 。我们可以用 dynamic_axes 来指定动态维度,其中 'input0': {0: 'batch'} 中的 0 表示在第 0 维度上的元素是动态的,这里取名为 ‘batch’

用 netron 查看模型:
在这里插入图片描述
可以看到相对于例1.1,他的维度 0 变成了动态的,并且名为 ‘batch’

二、pytorch 导出 onnx 不成功的时候如何解决

上面是 onnx 可以直接被导出的情况,是因为对应的 pytorch 和 onnx 版本都有相应支持的算子在里面。但是有些时候,我们不能顺利的导出 onnx,下面记录一下常见的解决思路 。

2.1、修改 opset 的版本

这是首先应该考虑的思路,因为有可能只是版本过低然后有些算子还不支持,所以考虑提高 opset 的版本

比如下面的这个报错,提示当前 onnx 的 opset 版本不支持这个算子,那我们可以去官方手册搜索一下是否在高的版本支持了这个算子
在这里插入图片描述

官方手册地址:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在这里插入图片描述
又比如说 Acosh 这个算子,在 since version 9 才开始支持,那我们用 7 的时候就是不合适的,升级 opset 版本即可

2.2、替换 pytorch 中的算子组合

有些时候 pytorch 中的一些算子操作在 onnx 中并没有,那我们可以把这些算子替换成 onnx 支持的算子

2.3、在 pytorch 登记( 注册 ) onnx 中某些算子

有些算子在 onnx 中是有的,但是在 pytorch 中没被登记,则需要注册一下
比如下面这个案例,我们想要导出 asinh 这个算子的模型

import torch
import torch.onnxclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 9)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()

但是报错,提示 opset_version = 9 不支持这个算子
在这里插入图片描述

但是我们打开官方手册去搜索发现 asinh 在 version 9 又是支持的
在这里插入图片描述
这里的问题是 PyTorch 与 onnx 之间没有建立 asinh 的映射 (没有搭建桥梁),所以我们编写一个注册代码,来手动注册一下这个算子

2.3.1、注册方法一

完整代码如下:

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定义完onnx以后必须要进行一下验证validate_onnx()

这段代码的关键在于 算子的注册:

1、定义 asinh_symbolic 函数

def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
  1. 函数必须是 asinh_symbolic 这个名字
  2. g: 就是 graph,计算图 (在计算图中添加onnx算子)
  3. input :symblic的参数需要与Pytorch的asinh接口函数的参数对齐
    (def asinh( input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: … )
  4. 符号函数内部调用 g.op, 为 onnx 计算图添加 Asinh 算子
  5. g.op中的第一个参数是onnx中的算子名字: Asinh

2、使用 register_custom_op_symbolic 函数

register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
  1. aten 是"a Tensor Library"的缩写,是一个实现张量运算的C++库
  2. asinh 是在名为 aten 的一个c++命名空间下进行实现的
  3. 将 asinh_symbolic 这个符号函数,与PyTorch的 asinh 算子绑定
  4. register_op 中的第一个参数是PyTorch中的算子名字: aten::asinh
  5. 最后一个参数表示从第几个 opset 开始支持(可自己设置)

3、自定义完 onnx 以后必须要进行一下验证,可使用 onnxruntime

2.3.2、注册方法二

import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh2.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh2.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定义完onnx以后必须要进行一下验证validate_onnx()

与上面例子不同的是,这个注册方式跟底层文件的写法是一样的(文件在虚拟环境中的 torch/onnx/symbolic_opset*.py )

通过torch._internal 中的 registration 来注册这个算子,让这个算子可以与底层C++实现的 aten::asinh 绑定

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)

2.4、直接修改 onnx,创建 plugin

直接手动创建一个 onnx (这是一个思路,会在后续博客进行总结记录)

参考链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/715799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode+remote突然无法连接服务器以及ssh连接出问题时的排错方法

文章目录 设备描述状况描述解决方法当ssh连接出问题时的排错方法 设备描述 主机:win11,使用vscode的remote-ssh插件 服务器:阿里云的2C2GUbuntu 22.04 UFIE 状况描述 之前一直使用的是vscode的remote服务,都是能够正常连接服务…

vue2结合electron开发桌面端应用

一、Electron是什么? Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面应用程序的框架。 嵌入 Chromium 和 Node.js 到 二进制的 Electron 。允许您保持一个 JavaScript 代码代码库并创建可在Windows、macOS和Linux上运行的跨平台应用 。 Electron 经常与 Ch…

scrapy 中间件

就是发送请求的时候,会经过,中间件。中间件会处理,你的请求 下面是代码: # Define here the models for your spider middleware # # See documentation in: # https://docs.scrapy.org/en/latest/topics/spider-middleware.html…

【快速上手ProtoBuf】基本使用

文章目录 1 :peach:初识 ProtoBuf:peach:1.1 :apple:序列化概念:apple:1.2 :apple:ProtoBuf 是什么:apple:1.3 :apple:ProtoBuf 的使用特点:apple: 2 :peach:创建 .proto ⽂件:peach:3 :peach:编译 .proto 文件:peach:3 :peach:序列化与反序列化的使用:peach: 1 🍑初…

LeetCode 刷题 [C++] 第45题.跳跃游戏 II

题目描述 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向前跳转的最大长度。换句话说&#xff0c;如果你在 nums[i] 处&#xff0c;你可以跳转到任意 nums[i j] 处: 0 < j < nums[i]i j < n 返回到达 nums[n …

【C语言】熟悉文件基础知识

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 文件 为了数据持久化保存&#xff0c;使用文件&#xff0c;否则数据存储在内存中&#xff0c;程序退出&#xff0c;内存回收&#xff0c;数据就会丢失。 程序设计中&…

微信小程序,h5端自适应登陆方式

微信小程序端只显示登陆(获取opid),h5端显示通过账户密码登陆 例如: 通过下面的变量控制: const isWeixin ref(false); // #ifdef MP-WEIXIN isWeixin.value true; // #endif

LIN基础:从LIN Frame开始

目录&#xff1a; 1、LIN的网络拓扑 2、LIN Frame 1&#xff09;Header 2&#xff09;Response 3、LIN的通信规则 1&#xff09;LIN的发送行为示例 2&#xff09;LIN的接收行为示例 虽然LIN总线的通信速率不高&#xff0c;工程中&#xff0c;最高的速率也就19200bps。…

StarRocks——Stream Load 事务接口实现原理

目录 前言 一、StarRocks 数据导入 二、StarRocks 事务写入原理 三、InLong 实时写入StarRocks原理 3.1 InLong概述 3.2 基本原理 3.3 详细流程 3.3.1 任务写入数据 3.3.2 任务保存检查点 3.3.3 任务如何确认保存点成功 3.3.4 任务如何初始化 3.4 Exactly Once 保证…

Leetcode - 周赛386

目录 一&#xff0c;3046. 分割数组 二&#xff0c;3047. 求交集区域内的最大正方形面积 三&#xff0c;3048. 标记所有下标的最早秒数 I 四&#xff0c;3049. 标记所有下标的最早秒数 II 一&#xff0c;3046. 分割数组 将题目给的数组nums分成两个数组&#xff0c;且这两个…

盲人出行:科技创造美好的未来

在繁忙的都市中&#xff0c;我每天都要面对许多挑战&#xff0c;盲人出行安全保障一直难以得到落实。我看不见这个世界&#xff0c;只能依靠触觉和听觉来感知周围的一切。然而&#xff0c;我从未放弃过对生活的热爱和对未来的憧憬。在一次机缘巧合下&#xff0c;我认识了一款名…

C3_W2_Collaborative_RecSys_Assignment_吴恩达_中英_Pytorch

Practice lab: Collaborative Filtering Recommender Systems(实践实验室:协同过滤推荐系统) In this exercise, you will implement collaborative filtering to build a recommender system for movies. 在本次实验中&#xff0c;你将实现协同过滤来构建一个电影推荐系统。 …

VLAN实验报告

实验要求&#xff1a; 实验参考图&#xff1a; 实验过程&#xff1a; r1: [r1]int g 0/0/0.1 [r1-GigabitEthernet0/0/0.1]ip address 192.168.1.1 24 [r1-GigabitEthernet0/0/0.1]dot1q termination vid 2 [r1-GigabitEthernet0/0/0.1]arp broadcast enable [r1]int g 0/0/…

Mysql学习之MVCC解决读写问题

多版本并发控制 什么是MVCC MVCC &#xff08;Multiversion Concurrency Control&#xff09;多版本并发控制。顾名思义&#xff0c;MVCC是通过数据行的多个版本管理来实现数据库的并发控制。这项技术使得在InnoDB的事务隔离级别下执行一致性读操作有了保证。换言之&#xff0…

django的模板渲染中的【高级定制】:按数据下标id来提取数据

需求&#xff1a; 1&#xff1a;在一个页面中显示一张数据表的数据 2&#xff1a;不能使用遍历的方式 3&#xff1a;页面中的数据允许通过admin后台来进行修改 4&#xff1a;把一张数据表的某些内容渲染到[xxx.html]页面 5&#xff1a;如公司的新商品页面&#xff0c;已有固定的…

《梦幻西游》本人收集的34个单机版游戏,有详细的视频架设教程,值得收藏

梦幻西游这款游戏&#xff0c;很多人玩&#xff0c;喜欢研究的赶快下载吧。精心收集的34个版本。不容易啊。里面有详细的视频架设教程&#xff0c;可以外网呢。 《梦幻西游》本人收集的34个单机版游戏&#xff0c;有详细的视频架设教程&#xff0c;值得收藏 下载地址&#xff1…

阶跃信号与冲击信号

奇异信号&#xff1a;信号与系统分析中&#xff0c;经常遇到函数本身有不连续点&#xff08;跳变电&#xff09;或其导函数与积分有不连续点的情况&#xff0c;这类函数称为奇异函数或奇异信号&#xff0c;也称之为突变信号。以下为一些常见奇异函数。 奇异信号 单位斜变信号 …

Ubuntu18.04安装RTX2060显卡驱动+CUDA+cuDNN

Ubuntu18.04安装RTX2060显卡驱动CUDAcuDNN 1 安装RTX2060显卡驱动1.1 查看当前显卡是否被识别1.2 安装驱动依赖1.3 安装桌面显示管理器1.4 下载显卡驱动1.5 禁用nouveau1.6 安装驱动1.7 查看驱动安装情况 2 安装CUDA2.1 查看当前显卡支持的CUDA版本2.2 下载CUDA Toolkit2.3 安装…

车灯修复UV胶的优缺点有哪些?

车灯修复UV胶的优点如下&#xff1a; 优点&#xff1a; 快速固化&#xff1a;通过紫外光照射&#xff0c;UV胶可以在5-15秒内迅速固化&#xff0c;提高了修复效率。高度透明&#xff1a;固化后透光率高&#xff0c;几乎与原始车灯材料无法区分&#xff0c;修复后车灯外观更加…

对缓冲区的初步认识——制作进度条小程序

对缓冲区的初步认识--进度条小程序 前言预备知识回车和换行的区别输出缓冲区/n 有清空输出缓冲区的作用stdout是什么&#xff1f;验证一切皆文件为什么是\n行刷新&#xff1f; 倒计时程序原理 代码实现为什么这里要强制刷新&#xff1f;没有会怎样&#xff1f;为什么是输出的是…