PyTorch 模型转换为 ONNX 格式

PyTorch 模型转换为 ONNX 格式

在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH 和 ONNX 格式的区别,并介绍如何使用 Netron 可视化 ONNX 模型。

1. PTH 和 ONNX 的区别

PTH 格式

  • 定义:PTH 是 PyTorch 框架的专有格式,通常用于保存模型的状态字典(state_dict),包括模型的结构和训练好的参数。

  • 兼容性

    • PTH 文件只能在 PyTorch 中使用,无法直接在 C++ 环境中加载。虽然 PyTorch 提供了 C++ API(LibTorch),但 PTH 文件的加载和使用主要依赖于 Python 环境。
    • 在 C++ 中使用 PTH 文件需要将模型转换为 PyTorch 的 C++ 格式,这可能会增加复杂性和开发时间。
  • 用途

    • PTH 格式适合在 Python 环境中进行模型训练和调试,但在 C++ 中进行模型部署时,通常需要将模型转换为其他格式(如 ONNX)以便于跨平台使用。
    • 在 C++ 中,使用 PTH 文件的灵活性较低,尤其是在需要与其他框架或系统集成时。

ONNX 格式

  • 定义:ONNX(Open Neural Network Exchange)是一个开放的深度学习模型交换格式,旨在促进不同深度学习框架之间的互操作性。

  • 兼容性

    • ONNX 文件可以在多个深度学习框架中使用,包括 PyTorch、TensorFlow、Caffe2 等,这使得它在 C++ 环境中的兼容性更强。
    • ONNX 模型可以通过 ONNX Runtime、TensorRT、OpenVINO 等推理引擎在 C++ 中高效运行,支持多种硬件加速。
  • 用途

    • ONNX 格式非常适合模型的部署和推理,特别是在需要跨平台或跨框架使用时。它允许开发者在 C++ 中轻松加载和运行模型,而无需依赖于 Python 环境。
    • 在 C++ 中,使用 ONNX 模型可以简化工程化流程,便于与其他系统集成,提升模型的可移植性和可扩展性。

总结

在 C++ 进行深度学习模型的工程化时,选择 ONNX 格式通常更为合适,因为它提供了更好的跨平台兼容性和灵活性。PTH 格式虽然在 PyTorch 环境中非常方便,但在 C++ 中的使用受到限制,通常需要额外的转换步骤。ONNX 的开放性和广泛支持使其成为在多种环境中部署深度学习模型的首选格式。

2. 训练 MNIST 数据集的 CNN 模型

以下是使用 PyTorch 训练 MNIST 数据集的完整代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 1. 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 2. 定义 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入通道为1,输出通道为32self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入通道为32,输出通道为64self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 全连接层self.fc2 = nn.Linear(128, 10)  # 输出层def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + 激活 + 池化x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + 激活 + 池化x = x.view(x.size(0), -1)  # 展平输入x = torch.relu(self.fc1(x))  # 第一个全连接层x = self.fc2(x)  # 输出层return x# 3. 训练模型
model = SimpleCNN().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练过程
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备outputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')# 5. 转换为 ONNX 格式
onnx_file_path = 'mnist_cnn_model.onnx'
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # 示例输入,形状为 [batch_size, channels, height, width]
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True,opset_version=11, do_constant_folding=True,input_names=['input'], output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f'Model has been converted to ONNX format and saved as {onnx_file_path}.')

3. 使用 Netron 可视化 ONNX 模型

一旦您将模型转换为 ONNX 格式,您可以使用 Netron 来可视化模型结构。Netron 是一个开源的模型可视化工具,支持多种深度学习框架的模型文件格式,包括 ONNX。

使用步骤:
  1. 下载 Netron

    • 您可以访问 Netron 的官方网站 在线使用,或者下载桌面版本。
  2. 打开 ONNX 模型

    • 如果使用在线版本,直接将 mnist_cnn_model.onnx 文件拖放到浏览器窗口中。
    • 如果使用桌面版本,打开 Netron 应用,选择“File” > “Open Model”,然后选择您的 ONNX 文件。
  3. 查看模型结构

    • 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/62307.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue+Elementui el-tree树只能选择子节点并且支持检索

效果&#xff1a; 只能选择子节点 添加配置添加检索代码 源码&#xff1a; <template><div><el-button size"small" type"primary" clearable :disabled"disabled" click"showSign">危险点评估</el-button>…

分析JHTDB数据库的Channel5200数据集的数据(SciServer服务器)

代码来自https://github.com/idies/pyJHTDB/blob/master/examples/channel.ipynb %matplotlib inline import numpy as np import math import random import pyJHTDB import matplotlib.pyplot as plt import time as ttN 3 T pyJHTDB.dbinfo.channel5200[time][-1] time …

《Vue零基础入门教程》第十二课:双向绑定指令

往期内容 《Vue零基础入门教程》第六课&#xff1a;基本选项 《Vue零基础入门教程》第八课&#xff1a;模板语法 《Vue零基础入门教程》第九课&#xff1a;插值语法细节 《Vue零基础入门教程》第十课&#xff1a;属性绑定指令 《Vue零基础入门教程》第十一课&#xff1a;事…

windows 应用 UI 自动化实战

UI 自动化技术架构选型 UI 自动化是软件测试过程中的重要一环&#xff0c;网络上也有很多 UI 自动化相关的知识或资料&#xff0c;具体到 windows 端的 UI 自动化&#xff0c;我们需要从以下几个方面考虑&#xff1a; 开发语言 毋庸置疑&#xff0c;在 UI 自动化测试领域&am…

linux部署Whisper 视频音频转文字

github链接&#xff1a;链接 我这里使用anaconda来部署&#xff0c;debian12系统&#xff0c;其他linux也同样 可以使用gpu或者cpu版本&#xff0c;建议使用n卡&#xff0c;rtx3060以上 一、前期准备 1.linux系统 链接&#xff1a;debian安装 链接&#xff1a;ubuntu安装 …

MySQL聚合查询分组查询联合查询

#对应代码练习 -- 创建考试成绩表 DROP TABLE IF EXISTS exam; CREATE TABLE exam ( id bigint, name VARCHAR(20), chinese DECIMAL(3,1), math DECIMAL(3,1), english DECIMAL(3,1) ); -- 插入测试数据 INSERT INTO exam (id,name, chinese, math, engli…

mini-spring源码分析

IOC模块 关键解释 beanFactory&#xff1a;beanFactory是一个hashMap, key为beanName, Value为 beanDefination beanDefination: BeanDefinitionRegistry&#xff0c;BeanDefinition注册表接口&#xff0c;定义注册BeanDefinition的方法 beanReference&#xff1a;增加Bean…

redis学习面试

1、数据类型 string 增删改查 set key valueget keydel kstrlen k 加减 incr articleincrby article 3decr articledecyby article 取v中特定位置数据 getrange name 0 -1getrange name 0 1setrange name 0 x 设置过期时间 setex pro 10 华为 等价于 set pro 华为expire pro…

详解MVC架构与三层架构以及DO、VO、DTO、BO、PO | SpringBoot基础概念

&#x1f64b;大家好&#xff01;我是毛毛张! &#x1f308;个人首页&#xff1a; 神马都会亿点点的毛毛张 今天毛毛张分享的是SpeingBoot框架学习中的一些基础概念性的东西&#xff1a;MVC结构、三层架构、POJO、Entity、PO、VO、DO、BO、DTO、DAO 文章目录 1.架构1.1 基本…

KST-3D01型胎儿超声仿真体模、吸声材料以及超声骨密度仪用定量试件介绍

一、KST-3D01型胎儿超声仿真体模 KST—3D01型胎儿超声体模&#xff0c;采用仿羊水环境中内置胎龄为7个月大仿胎儿设计。用于超声影像系统3D扫描演示装置表面轮廓呈现和3D重建。仿羊水超声影像呈暗回声&#xff08;无回波&#xff09;特性&#xff0c;仿胎儿超声影像呈对比明显…

【逐行注释】自适应Q和R的AUKF(自适应无迹卡尔曼滤波),附下载链接

文章目录 自适应Q的KF逐行注释的说明运行结果部分代码各模块解释 自适应Q的KF 自适应无迹卡尔曼滤波&#xff08;Adaptive Unscented Kalman Filter&#xff0c;AUKF&#xff09;是一种用于状态估计的滤波算法。它是基于无迹卡尔曼滤波&#xff08;Unscented Kalman Filter&am…

易速鲜花聊天客服机器人的开发(上)

“聊天机器人”项目说明 聊天机器人&#xff08;Chatbot&#xff09;是LLM和LangChain的核心用例之一&#xff0c;很多人学习大语言模型&#xff0c;学习LangChain&#xff0c;就是为了开发出更好的、更能理解用户意图的聊天机器人。聊天机器人的核心特征是&#xff0c;它们可…

ChatGPT/AI辅助网络安全运营之-数据解压缩

在网络安全的世界中&#xff0c;经常会遇到各种压缩的数据&#xff0c;比如zip压缩&#xff0c;比如bzip2压缩&#xff0c;gzip压缩&#xff0c;xz压缩&#xff0c;7z压缩等。网络安全运营中需要对这些不同的压缩数据进行解压缩&#xff0c;解读其本意&#xff0c;本文将探索一…

05_JavaScript注释与常见输出方式

JavaScript注释与常见输出方式 JavaScript注释 源码中注释是不被引擎所解释的&#xff0c;它的作用是对代码进行解释。lavascript 提供两种注释的写法:一种是单行注释&#xff0c;用//起头:另一种是多行注释&#xff0c;放在/*和*/之间。 //这是单行注释/* 这是 多行 注释 *…

网络原理(一):应用层自定义协议的信息组织格式 HTTP 前置知识

目录 1. 应用层 2. 自定义协议 2.1 根据需求 > 明确传输信息 2.2 约定好信息组织的格式 2.2.1 行文本 2.2.2 xml 2.2.3 json 2.2.4 protobuf 3. HTTP 协议 3.1 特点 4. 抓包工具 1. 应用层 在前面的博客中, 我们了解了 TCP/IP 五层协议模型: 应用层传输层网络层…

题目 3209: 蓝桥杯2024年第十五届省赛真题-好数

一个整数如果按从低位到高位的顺序&#xff0c;奇数位&#xff08;个位、百位、万位 &#xff09;上的数字是奇数&#xff0c;偶数位&#xff08;十位、千位、十万位 &#xff09;上的数字是偶数&#xff0c;我们就称之为“好数”。给定一个正整数 N&#xff0c;请计算从…

数据结构与算法学习笔记----KMP

数据结构与算法学习笔记----KMP author: 明月清了个风 last edited: 2024.11.24 Acwing 831. KMP字符串 给定一个字符串 S S S&#xff0c;以及一个模式串 P P P&#xff0c;所有字符串中只包含大小写英文字母以及阿拉伯数字。 模式串 P P P在字符串 S S S中多次作为子串出…

算法基础 - 二分迭代法求解非线性方程

文章目录 1. 基本思想2. 编程实现2.1. 非递归2.2. 递归方案 3. 总结 二分迭代法使用了二分算法思想求解非线性方程式。 下面要求使用二分迭代法求解&#xff1a; 2x3-5x-10 方程式&#xff0c;且要求误差不能大于10e-5。 二分迭代法也只是近似求解算法。 所谓求解&#xff…

家校通小程序实战教程03学生管理

目录 1 创建数据源2 搭建后台功能3 设置主列字段4 批量导入数据5 设置查询条件6 实现查询和重置总结 我们现在已经搭建了班级管理&#xff0c;并且录入了班级口令。之后就是加入班级的功能了。这里分为老师加入班级和学生家长加入班级。 如果是学生家长的话&#xff0c;在加入之…

springboot336社区物资交易互助平台pf(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 社区物资交易互助平台设计与实现 摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff…