win11下 “pytorch导出模型“ 以及 “C++使用onnxruntime部署”

部分一:PyTorch导出模型

在Win11下,PyTorch是一个强大的深度学习框架,它提供了丰富的工具来训练和导出模型。在这一部分,我们将使用鸢尾花数据集,演示如何在PyTorch中训练一个简单的模型,并将其导出为ONNX格式。

1、引言

深度学习模型的导出对于模型在不同平台上的部署至关重要。PyTorch的灵活性使得导出过程变得相对简单,同时保持了模型的准确性。

2、数据准备和模型训练

在这一步,我们首先加载鸢尾花数据集,对数据进行预处理,然后训练一个简单的神经网络模型。以下是代码示例:

# 导入所需的库
import torch
import torch.onnx
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target# 数据预处理
X_scaled = X / 10# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 将数据转换为PyTorch的Tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)# 定义简单的神经网络模型
class IrisModel(nn.Module):def __init__(self):super(IrisModel, self).__init__()self.fc1 = nn.Linear(4, 8)self.fc2 = nn.Linear(8, 3)  # 输入特征为4,输出类别为3def forward(self, x):x = F.relu(self.fc1(x))x = F.log_softmax(self.fc2(x), dim=1)return x# 初始化模型、损失函数和优化器
model = IrisModel()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
for epoch in range(100):optimizer.zero_grad()output = model(X_train_tensor)loss = criterion(output, y_train_tensor)loss.backward()optimizer.step()

3、模型导出

在训练完成后,我们使用torch.onnx.export方法将模型导出为ONNX格式。导出的ONNX文件将在接下来的部署中使用,以下是代码示例:

# 将模型转换为ONNX格式
dummy_input = torch.randn(1, 4)  # 创建一个虚拟输入
onnx_path = 'iris_model.onnx'
torch.onnx.export(model, dummy_input, onnx_path,input_names=['input'],output_names=['output'],dynamic_axes = {'input':{0: 'batch_size'},'output':{0: 'batch_size'}})

4、总结

在这一部分,我们演示了如何使用PyTorch训练一个简单的神经网络模型,并将其导出为ONNX格式,为模型在不同平台上的部署做好了准备。

部分二:C++使用ONNX Runtime部署

ONNX Runtime是一个用于高性能推理的开源引擎,它支持在不同平台上运行ONNX格式的模型。在这一部分,我们将学习如何使用C++和ONNX Runtime加载并运行先前导出的鸢尾花分类模型。

1、引言

ONNX Runtime的强大之处在于其跨平台性能,使得模型能够在各种设备上进行高效推理。

2、环境配置和项目设置

在使用C++部署模型之前,我们需要确保系统中已经正确安装了ONNX Runtime,并且我们的C++项目设置正确。
本文采用v1.16.3版本,下载地址:https://github.com/microsoft/onnxruntime/releases

3. C++代码实现

以下是一个简单的C++代码示例,演示如何加载ONNX模型并进行推理:

#include <array>
#include <algorithm>
#include <iostream>
#include <onnxruntime_cxx_api.h>int main() {// ONNX模型文件路径const wchar_t* model_path = L"D:\\vs_project\\demo\\iris_model.onnx";// 创建ONNX运行环境和内存信息Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);// 配置会话选项Ort::SessionOptions session_option;session_option.SetIntraOpNumThreads(5); // 设置并行线程数session_option.SetGraphOptimizationLevel(ORT_ENABLE_ALL);// 定义模型输入和输出的名称const char* input_names[] = { "input" };const char* output_names[] = { "output" };// 定义样本数量和输入输出矩阵的大小const int num_samples = 2;std::array<float, num_samples * 4> input_matrix;std::array<float, num_samples * 3> output_matrix;// 定义输入输出矩阵的形状std::array<int64_t, 2> input_shape{ num_samples, 4 };std::array<int64_t, 2> output_shape{ num_samples, 3 };// 定义样本输入数据std::vector<std::vector<float>> sample_x = { {2.1, 3.5, 1.4, 0.2}, {5.1, 1.5, 2.4, 0.2} };int sample_y = 3;// 将样本数据复制到输入矩阵中for (int i = 0; i < num_samples; i++) {for (int j = 0; j < 4; j++) {input_matrix[i * 4 + j] = sample_x[i][j];}}// 创建输入和输出TensorOrt::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());try {// 创建ONNX会话并运行模型Ort::Session session(env, model_path, session_option);session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);}catch (const Ort::Exception& e) {// 处理ONNX Runtime异常std::cerr << "ONNX Runtime 异常: " << e.what() << std::endl;}catch (const std::exception& e) {// 处理标准异常std::cerr << "标准异常: " << e.what() << std::endl;}catch (...) {// 处理未知异常std::cerr << "未知异常." << std::endl;}// 输出预测结果std::cout << "--- 预测结果 ---" << std::endl;for (int i = 0; i < num_samples; i++) {std::cout << "输出矩阵: ";for (int j = 0; j < sample_y; j++) {std::cout << output_matrix[i * sample_y + j] << " ";}std::cout << std::endl;// 找到输出矩阵中的argmax值int argmax_value = std::distance(output_matrix.begin() + i * sample_y, std::max_element(output_matrix.begin() + i * sample_y, output_matrix.begin() + (i + 1) * sample_y));std::cout << "样本 " << i << " 的输出 argmax 值: " << argmax_value << std::endl;}// 等待用户按键结束程序getchar();return 0;
}

推理结果:

--- 预测结果 ---
输出矩阵: -0.0161782 -4.42303 -5.50902
样本 0 的输出 argmax 值: 0
输出矩阵: -2.30582 -0.799013 -0.797285
样本 1 的输出 argmax 值: 2

4. 总结

在这一部分,我们通过使用C++和ONNX Runtime,成功加载并运行了在PyTorch中训练并导出的鸢尾花分类模型。这为在不同C++支持的平台上进行模型推理提供了一个简单而强大的解决方案。

通过这两个部分,我们实现了从PyTorch训练模型到在C++环境中进行推理的全过程。这个流程可以在Win11下轻松实现,为模型的实际应用提供了一个完整的参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

改变this指针的三个方法?

要改变 this 的指向&#xff0c;JavaScript 提供了一系列的方法&#xff1a; call()&#xff1a; 使用 call() 可以直接改变 this 的指向。它接受两个参数&#xff1a;第一个是要调用的目标函数&#xff0c;第二个是将作为 this 的值的对象或对象引用的数组。例如&#xff0c…

protobuf-go pragma.go 文件介绍

pragma.go 文件 文件位于&#xff1a; https://github.com/protocolbuffers/protobuf-go/blob/master/internal/pragma/pragma.go 该文件核心思想&#xff1a; 利用 Golang 语法机制&#xff0c;扩展 Golang 语言特性 目前&#xff0c;该文件提供以下 4 个功能&#xff1a; …

C++STL模板库

类&#xff1a; pair: 头文件&#xff1a;<utility> 定义&#xff1a; 是一个标准库类型。可以看作是有两个成员变量first和second的结构体&#xff0c;并且重载了<运算符(先比较first大小&#xff0c;再比较second大小)当我们创建一个pair时&#xff0c;必须提供两…

SQLite 简介

什么是SQLite&#xff1f; SQLite是一个轻量级的嵌入式关系型数据库&#xff0c;它以一个小型的C语言库的形式存在。它的设计目标是嵌入式的&#xff0c;而且已经在很多嵌入式产品中使用了它&#xff0c;它占用资源非常的低&#xff0c;在嵌入式设备中&#xff0c;可能只需要几…

机器学习面试题总结60-99

目录 60、Python到底是什么样的语言? 61.Python是如何进行内存管理的? 引用计数和垃圾回收。

leetcode-存在重复元素

217. 存在重复元素 把列表转成集合&#xff0c;我们知道集合中是没有重复元素的&#xff0c;然后和原列表的长度做对比&#xff0c;不相等说明是有重复元素的 class Solution:def containsDuplicate(self, nums: List[int]) -> bool:if len(set(nums)) len(nums):return …

状态码400以及状态码415

首先检查前端传递的参数是放在header里边还是放在body里边。 此图前端传参post请求&#xff0c;定义为’Content-Type’&#xff1a;‘application/x-www-form-urlencoded’ 此刻他的参数在FormData中。看下图 后端接参数应为&#xff08;此刻参数前边什么都不加默认为requestP…

Qt QScrollArea 不显示滚动条 不滚动

使用QScrollArea时&#xff0c;发现添加的控件超出QScrollArea 并没有显示&#xff0c;且没有滚动条效果 原因是 scrollArea指的是scrollArea控件本身的大小&#xff0c;肉眼能看到的外形尺寸。 scrollAreaWidgetContents指的是scrollArea控件内部的显示区域&#xff0c;里面可…

2024 高级前端面试题之 React 「精选篇」

该内容主要整理关于 React 模块的相关面试题&#xff0c;其他内容面试题请移步至 「最新最全的前端面试题集锦」 查看。 React模块精选篇 1. 如何理解React State不可变性的原则2. JSX本质3. React合成事件机制4. setState和batchUpdate机制5. 组件渲染和更新过程6. Diff算法相…

windows server 开启远程连接RDP连接

windows server 开启远程连接&#xff0c;RDP连接windows server 打开gpedit.msc, 找到计算机配置-管理模板-windows组件-远程桌面服务-远程桌面会话主机-授权 1 使用指定的远程桌面许可证服务器 2 设置远程桌面授权模式 3 重启windows server服务器生效 4使用mstsc命令连接…

未来每家公司都需要有自己的大模型- Hugging Face创始人分享

自ChatGPT发布以来&#xff0c;有人称其是统治性一切的模型。Hugging Face创始人兼首席执行官Clem Delangue介绍&#xff0c;Hugging Face平台已经有15000家公司分享了25万个开源模型&#xff0c;当然这些公司不会为了训练模型而训练模型&#xff0c;因为训练模型需要投入大量资…

Springboot自定义线程池实现多线程任务

1. 在启动类添加EnableAsync注解 2.自定义线程池 package com.bt.springboot.config;import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.scheduling.concurrent.ThreadPoolTask…

记录 | ubuntu nm命令的基本使用

什么是nm命令 nm命令是linux下针对某些特定文件的分析工具&#xff0c;能够列出库文件&#xff08;.a、.lib&#xff09;、目标文件&#xff08;*.o&#xff09;、可执行文件的符号表。 nm命令的常用参数 -A 或 -o 或 --print-file-name&#xff1a;打印出每个符号属于的文件…

webassembly003 TTS BARK.CPP

TTS task TTS&#xff08;Text-to-Speech&#xff09;任务是一种自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;其中模型的目标是将输入的文本转换为声音&#xff0c;实现自动语音合成。具体来说&#xff0c;模型需要理解输入的文本并生成对应的语音输出&#xff0…

Mysql 为表增加计算列

什么叫计算列呢&#xff1f;简单来说就是某一列的值是通过别的列计算得来的。 增加计算列的语法格式如下&#xff1a; col_name data_type [GENERATED ALWAYS] AS (expression) [VIRTUAL | STORED] [UNIQUE [KEY]] [COMMENT comment] [NOT NULL | NULL] [[PRIMARY] KEY]; 下…

c++学习记录 多态—案例2—电脑组装

#include<iostream> using namespace std;//抽象不同的零件//抽象的cpu类 class Cpu { public://抽象的计算函数virtual void calculate() 0; };//抽象的显卡类 class VideoCard { public://抽象的显示函数virtual void display() 0; };//抽象的内存条类 class Memory …

华为通用软件开发工程师24校招三轮面试详细记录

本文介绍2024届秋招中&#xff0c;华为技术有限公司的通用软件开发工程师岗位的3场面试基本情况、提问问题等。 7月投递了华为技术有限公司的通用软件开发工程师岗位&#xff0c;所在部门为海思半导体与器件业务部。目前完成了一面、二面与三面等全部流程&#xff0c;在这里记录…

K210 UART串口通信介绍与 STM32通信

目录 K210-UART串口通信相关函数&#xff1a; 使用K210串口的时候需要映射引脚&#xff1a; K210与STM32串口通信 发送单字节&#xff1a; K210端 STM32端 发送数据包 K210端 STM32端 K210的UART模块支持全双工通信&#xff0c;可以同时进行数据的发送和接收。在K21…

Nginx启用WebSocket支持

报错内容nginx.conf proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; 问题解决WebSocket跨域 add_header Access-Control-Allow-Origin *; add_header Access-Control-Allow-Credentials true;

常用芯片学习——AMS1117芯片

AMS1117 1A 低压差线性稳压器 使用说明 AMS1117 是一款低压差线性稳压电路&#xff0c;该电路输出电流能力为1A。该系列电路包含固定输出电压版本和可调输出电压版本&#xff0c;其输出电压精度为士1.5%。为了保证芯片和电源系统的稳定性&#xff0c;XBLWAMS1117 内置热保护和…