从 PyTorch 到 ONNX:深度学习模型导出全解析

在模型训练完毕后,我们通常希望将其部署到推理平台中,比如 TensorRT、ONNX Runtime 或移动端框架。而 ONNX(Open Neural Network Exchange)正是 PyTorch 与这些平台之间的桥梁。

本文将以一个图像去噪模型 SimpleDenoiser 为例,手把手带你完成 PyTorch 模型导出为 ONNX 格式的全过程,并解析每一行代码背后的逻辑。

准备工作

我们假设你已经训练好一个图像去噪模型并保存为 .pth 文件,模型结构自编码器实现如下(略):

class SimpleDenoiser(nn.Module):def __init__(self):super(SimpleDenoiser, self).__init__()self.encoder = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())self.decoder = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),nn.Conv2d(64, 3, 3, padding=1))def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x

导出代码分解

我们现在来看导出脚本的核心逻辑,并分块解释它的每一部分。

1. 导入模块 & 设置路径

//torch:核心框架//train.SimpleDenoiser:从训练脚本复用模型结构//os:用于创建输出目录import torch
from train import SimpleDenoiser  # 模型结构
import os

2. 导出函数定义

//这个函数接收三个参数://pth_path: 训练得到的模型参数文件路径//onnx_path: 导出的 ONNX 文件保存路径//input_size: 模拟推理输入的尺寸(默认 1×3×256×256)
def export_model_to_onnx(pth_path, onnx_path, input_size=(1, 3, 256, 256)):

3. 加载模型和权重

//自动检测 CUDA 可用性,加载模型到对应设备;//使用 load_state_dict() 加载训练好的参数;//model.eval() 让模型切换到推理模式(关闭 Dropout/BatchNorm 更新);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = SimpleDenoiser().to(device)
model.load_state_dict(torch.load(pth_path, map_location=device))
model.eval()

4. 构造假输入(Dummy Input)

//ONNX 导出需要一个具体的输入样本,我们这里用 torch.randn 生成一个形状为 (1, 3, 256, 256) 的随机图//像;//输入必须放在同一个设备上(GPU 或 CPU);
dummy_input = torch.randn(*input_size).to(device)

5. 导出为 ONNX

torch.onnx.export(model,  //要导出的模型dummy_input,  //示例输入张量onnx_path, //	导出路径export_params=True,  //是否导出权重opset_version=11,  //ONNX 的算子集版本,通常推荐 11 或 13do_constant_folding=True,  //优化常量表达式,减小模型体积input_names=['input'],  //自定义输入输出张量的名称output_names=['output'],  //声明哪些维度可以变动,比如 batch size、图像大小等(部署时更灵活)dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},'output': {0: 'batch_size', 2: 'height', 3: 'width'}}
)

6. 创建目录并调用函数

//确保输出文件夹存在,并调用导出函数生成最终模型。
if __name__ == "__main__":os.makedirs("onnx", exist_ok=True)export_model_to_onnx("weights/denoiser.pth", "onnx/denoiser.onnx")

导出后如何验证?

pip install onnxruntime
import onnxruntime
import numpy as npsess = onnxruntime.InferenceSession("onnx/denoiser.onnx")
input = np.random.randn(1, 3, 256, 256).astype(np.float32)
output = sess.run(None, {"input": input})
print("输出 shape:", output[0].shape)

 模型预览:

总结

导出 ONNX 模型的流程主要包括:

  1. 加载模型结构 + 权重

  2. 准备 dummy 输入张量

  3. 调用 torch.onnx.export() 进行导出

  4. 设置 dynamic_axes 可变尺寸以增强部署适配性

这套流程适用于大部分视觉模型(分类、去噪、分割等),也是后续进行 TensorRT 推理或移动端部署的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hadoop集群部署教程-P6

Hadoop集群部署教程-P6 Hadoop集群部署教程(续) 第二十一章:监控与告警系统集成 21.1 Prometheus监控体系搭建 Exporter部署: # 部署HDFS Exporter wget https://github.com/prometheus/hdfs_exporter/releases/download/v1.1.…

【Altium】AD-生成PDF文件图纸包含太多的空白怎么解决

1、 文档目标 AD设计文件导出PDF时,图纸模板方向设置问题 2、 问题场景 AD使用Smart PDF导出PDF时,不管你怎么设置页面尺寸,只要从横向转为纵向输出,输出的始终是横向纸张(中间保留纵向图纸,两边大量留白…

大厂面试:六大排序

前言 本篇博客集中了冒泡,选择,二分插入,快排,归并,堆排,六大排序算法 如果觉得对你有帮助,可以点点关注,点点赞,谢谢你! 1.冒泡排序 //冒泡排序&#xff…

大模型开发:源码分析 Qwen 2.5-VL 视频抽帧模块(附加FFmpeg 性能对比测试)

目录 qwen 视频理解能力 messages 构建 demo qwen 抽帧代码分析 验证两个实际 case 官网介绍图 性能对比:ffmpeg 抽帧、decord 库抽帧 介绍 联系 对比 测试结果 测试明细 ffmpeg 100 qps 测试(CPU) decord 100 qps 测试&#x…

git的上传流程

好久没使用git 命令上传远程仓库了。。。。。温习了一遍; 几个注意点--单个文件大小不能超过100M~~~ 一步步运行下面的命令: 进入要上传的文件夹内,点击git bash 最终 hbu的小伙伴~有需要nndl实验的可以自形下载哦

驱动学习专栏--字符设备驱动篇--2_字符设备注册与注销

对于字符设备驱动而言,当驱动模块加载成功以后需要注册字符设备,同样,卸载驱动模 块的时候也需要注销掉字符设备。字符设备的注册和注销函数原型如下所示 : static inline int register_chrdev(unsigned int major, const char *name, const…

redis 放置序列化的对象,如果修改对象,需要修改版本号吗?

在 Redis 中存储序列化对象时,如果修改了对象的类结构(例如增删字段、修改字段类型或顺序),是否需要修改版本号取决于序列化协议的兼容性策略和业务场景的容错需求。以下是详细分析: 1. 为什么需要考虑版本号? 序列化兼容性问题: 当对象的类结构发生变化时,旧版本的序列…

WPF ObjectDataProvider

在 WPF(Windows Presentation Foundation)中,ObjectDataProvider 是一个非常有用的类,用于将非 UI 数据对象(如业务逻辑类或服务类)与 XAML 绑定集成。它允许在 XAML 中直接调用方法、访问属性或实例化对象,而无需编写额外的代码。以下是关于 ObjectDataProvider 的详细…

深度学习-损失函数 python opencv源码(史上最全)

目录 定义 种类 如何选择损失函数? 平方(均方)损失函数(Mean Squared Error, MSE) 均方根误差 交叉熵 对数损失 笔记回馈 逻辑回归中一些注意事项: 定义 损失函数又叫误差函数、成本函数、代价函数…

poll为什么使用poll_list链表结构而不是数组 - 深入内核源码分析

一:引言 在Linux内核中,poll机制是一个非常重要的I/O多路复用机制。它允许进程监视多个文件描述符,等待其中任何一个进入就绪状态。poll的内部实现使用了poll_list链表结构而不是数组,这个设计选择背后有其深层的技术考量。本文将从内核源码层面深入分析这个设计决…

使用 Azure AKS 保护 Kubernetes 部署的综合指南

企业不断寻求增强其软件开发和部署流程的方法。DevOps 一直是这一转型的基石,弥合了开发与运营之间的差距。然而,随着安全威胁日益复杂,将安全性集成到 DevOps 流水线(通常称为 DevSecOps)已变得势在必行。本指南深入探讨了如何使用 Azure Kubernetes 服务 (AKS) 来利用 D…

2025年常见渗透测试面试题-webshell免杀思路(题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 webshell免杀思路 PHP免杀原理 webshell免杀测试: webshell免杀绕过方法: 编…

访问不到服务器上启动的llamafactory-cli webui

采用SSH端口转发有效,在Windows上面进行访问 在服务器上启动 llamafactory-cli webui 后,访问方式需根据服务器类型和网络环境选择以下方案: 一、本地服务器(物理机/虚拟机) 1. 直接访问 若服务器与操作设备处于同一…

基于 LSTM 的多特征序列预测-SHAP可视化!

往期精彩内容: 单步预测-风速预测模型代码全家桶-CSDN博客 半天入门!锂电池剩余寿命预测(Python)-CSDN博客 超强预测模型:二次分解-组合预测-CSDN博客 VMD CEEMDAN 二次分解,BiLSTM-Attention预测模型…

C++ 编程指南35 - 为保持ABI稳定,应避免模板接口

一:概述 模板在 C 中是编译期展开的,不同模板参数会生成不同的代码,这使得模板类/函数天然不具备 ABI 稳定性。为了保持ABI稳定,接口不要直接用模板,先用普通类打个底,模板只是“外壳”,这样 AB…

【iOS】OC高级编程 iOS多线程与内存管理阅读笔记——自动引用计数(二)

自动引用计数 前言ARC规则所有权修饰符**__strong修饰符**__weak修饰符__unsafe_unretained修饰符__autoreleasing修饰符 规则属性数组 前言 上一篇我们主要学习了一些引用计数方法的内部实现,现在我们学习ARC规则。 ARC规则 所有权修饰符 OC中,为了处…

可信空间数据要素解决方案

可信空间数据要素解决方案 一、引言 随着数字经济的蓬勃发展,数据已成为重要的生产要素。可信空间数据要素解决方案旨在构建一个安全、可靠、高效的数据流通与应用环境,促进数据要素的合理配置和价值释放,推动各行业的数字化转型和创新发展…

mysql删除表后重建表报错Tablespace exists

版本 mysql:8.0.23 复现步骤 1、删除表 DROP TABLE IF EXISTS xxx_demo; 2、新建表 CREATE TABLE xxx_demo (id bigint NOT NULL AUTO_INCREMENT COMMENT 主键id,creator varchar(64) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NULL DEFAULT COMMENT 创建者,c…

【Leetcode-Hot100】缺失的第一个正数

题目 解答 有一处需要注意,我使用注释部分进行交换值,报错:超出时间限制。有人知道是为什么吗?难道是先给nums[i]赋值后,从而改变了后一项的索引? class Solution(object):def firstMissingPositive(sel…

从单模态到多模态:五大模型架构演进与技术介绍

前言 1. ResNet — 残差神经网络背景核心问题与解决方案原理模型架构ResNet 系列变体技术创新与影响 2. ViT — Vision Transformer背景核心思想发展历程Transformer的起源:ViT的出现:ViT的进一步发展: 模型架构技术创新与影响 3. Swin Trans…