使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录
demo
下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread()
    默认读取为三通道BGR,需要进行B/R通道交换,这里采用
    cv::cvtColor()
    实现。
  2. 图像尺寸需要调整到

224

×

224

224\times 224

2

2

4

×

2

2

4

,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。

// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}

Step 6:构建运行验证

我们先来写一下
CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录
demo
下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/50859.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【odoo17】后端py方法触发右上角提示组件

概要 在前面文章中&#xff0c;有介绍过前端触发的通知服务。 【odoo】右上角的提示&#xff08;通知服务&#xff09; 此文章则介绍后端触发方法。 内容 直接上代码&#xff1a;但是前提一定是按钮触发&#xff01;&#xff01;&#xff01;&#xff01;&#xff01; def bu…

【css】实现扫光特效

对于要重点突出的元素&#xff0c;我们经常可以看到它上面打了一个从左到右的斜向扫光&#xff0c;显得元素亮闪闪的&#xff01;类似于下图的亮光动效 关键步骤 伪元素设置position :absolute【也可以不用伪元素&#xff0c;直接创建一个absolute元素盖在上面】设置渐变line…

Mike21粒子追踪模型particle tracking如何展示粒子轨迹

前言&#xff1a; 随着模型的推广&#xff0c;模型的很多模块也问的多了起来&#xff0c;PT粒子追踪模块最近群友也在问&#xff0c;结果算了出来&#xff0c;却实现不了展示运动轨迹。今天就写段简单的PT后处理的方法吧。 注意&#xff1a;MIKE21输出模块中不但输出了关于水…

Axure怎么样?全面功能评测与用户体验分析!

软件 Axure 曾经成为产品经理必备的原型设计工具&#xff0c;被认为是专门为产品经理设计的工具。但事实上&#xff0c;软件 Axure 的使用场景并不局限于产品经理构建产品原型。UI/UX 设计师还可以使用 Axure 软件构件应用程序 APP 原型&#xff0c;网站设计师也可以使用 Axure…

如何系统的学习C++和自动驾驶算法

给大家分享一下我的学习C和自动驾驶算法视频&#xff0c;收藏订阅都很高。打开下面的链接&#xff0c;就可以看到所有的合集了&#xff0c;订阅一下&#xff0c;下次就能找到了。 【C面试100问】第七十四问&#xff1a;STL中既然有了vector为什么还需要array STL中既然有了vec…

QSqlQuery增删改查

本文记录使用QSqlQuery实现增删改查的过程。 目录 1. 构建表格数据 声明变量 表格、数据模型、选择模型三板斧设置 列表执行查询 列表的水平表头设置 2. 新增一行 构建一个空行 通过dialog返回的修改行数据&#xff0c;update更新 3. 更新一行 获取到需要更新的行 通…

Spring Bean - xml 配置文件创建对象

类型&#xff1a; 1、值类型 2、null &#xff08;标签&#xff09; 3、特殊符号 &#xff08;< -> < &#xff09; 4、CDATA <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/bea…

信息安全技术解析

在信息爆炸的今天&#xff0c;信息技术安全已成为社会发展的重要基石。随着网络攻击的日益复杂和隐蔽&#xff0c;保障数据安全、提升防御能力成为信息技术安全领域的核心任务。本文将从加密解密技术、安全行为分析技术和网络安全态势感知技术三个方面进行深入探讨&#xff0c;…

WPF启动失败报System.Windows.Automation.Peers.AutomationPeer.Initialize()错误解决

问题描述 win10系统上WPF程序启动后就崩溃&#xff0c;通过查看崩溃日志如下&#xff1a; 应用程序: xxx.exe Framework 版本: v4.0.30319 说明: 由于未经处理的异常&#xff0c;进程终止。 异常信息: System.TypeLoadException 在 System.Windows.Automation.Peers.Automatio…

leetcode-105. 从前序与中序遍历序列构造二叉树

题目描述 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示例 1: 输入: preorder [3,9,20,15,7], inorder [9,3,15,20,7] 输出: [3,9,20,null,nu…

重塑生态体系 深挖应用场景 萤石诠释AI时代智慧生活新图景

7月24日&#xff0c;“智动新生&#xff0c;尽在掌控”2024萤石夏季新品发布会在杭州举办。来自全国各地的萤石合作伙伴、行业从业者及相关媒体&#xff0c;共聚杭州&#xff0c;共同见证拥抱AI的萤石&#xff0c;将如何全新升级&#xff0c;AI加持下的智慧生活又有何不同。 发…

【WinDbg读取蓝屏的dmp日志】iaStorAC.sys 蓝屏解决

读取蓝屏日志&#xff1a; Window偶尔一次蓝屏不用管。 经常蓝屏重置或重装系统。 想要知道为什么蓝屏&#xff0c;通过WinDbg查看蓝屏日志。 蓝屏日志查找和配置 1&#xff0c;蓝屏那一刻拍照蓝屏的界面&#xff0c;即可知道基本的蓝屏信息。 2&#xff0c;蓝屏日志的配置…

从0开始搭建vue + flask 旅游景点数据分析系统(一):创建前端项目

根据前面的爬虫课程&#xff0c;我们重新开一个坑&#xff0c;就是基于爬取到的数据&#xff0c;搭建一个vueflask的前后端分离的数据分析系统 1 通过这个系列教程可以学习到什么&#xff1f; 从0开始搭建一个 vue flask 的数据分析系统&#xff1b;了解系统的整体架构&…

通信类IEEE会议——第四届通信技术与信息科技国际学术会议(ICCTIT 2024)

[IEEE 独立出版&#xff0c;中山大学主办&#xff0c;往届均已见刊检索] 第四届通信技术与信息科技国际学术会议&#xff08;ICCTIT 2024&#xff09; 2024 4th International Conference on Communication Technology and Information Technology 重要信息 大会官网&#xf…

Visual Studio调试Web项目

一、编译运行调试&#xff08;VS快捷键&#xff1a;CtrlF5&#xff09; 缺点&#xff1a;编译运行项目太慢&#xff0c;整体程序有些编译报错运行不了 二、附加到进程调试&#xff08;VS快捷键&#xff1a;CtrlAltP&#xff0c;选择w3wp.exe&#xff09; 无需编译&#xff0c;速…

数据结构之栈详解

1. 栈的概念以及结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&#xff09;的原则。 压栈…

7.23模拟赛总结 [数据结构优化dp] + [神奇建图]

目录 复盘题解T2T4 复盘 浅复盘下吧… 7:40 开题 看 T1 &#xff0c;起初以为和以前某道题有点像&#xff0c;子序列划分&#xff0c;注意到状态数很少&#xff0c;搜出来所有状态然后 dp&#xff0c;然后发现这个 T1 和那个毛关系没有 浏览了一下&#xff0c;感觉 T2 题面…

并发编程--volatile

1.什么是volatile volatile是 轻 量 级 的 synchronized&#xff0c;它在多 处 理器开 发 中保 证 了共享 变 量的 “ 可 见 性 ” 。可 见 性的意思是当一个 线 程 修改一个共享变 量 时 &#xff0c;另外一个 线 程能 读 到 这 个修改的 值 。如果 volatile 变 量修 饰 符使用…

day08:订单状态定时处理、来单提醒和客户催单

文章目录 Spring Task介绍cron表达式入门案例 订单状态定时处理需求分析代码开发扩展 WebSocket介绍入门案例特点 来单提醒需求分析和设计代码实现 客户催单需求分析和设计代码实现 Spring Task 介绍 Spring Task 是Spring框架提供的任务调度工具&#xff0c;可以按照约定的时…

谁说只有车载HMI界面?现在工业类的HMI界面UI也崛起了

谁说只有车载HMI界面&#xff1f;现在工业类的HMI界面UI也崛起了 引言 艾斯视觉作为行业ui设计和前端开发领域的从业者&#xff0c;其观点始终认为&#xff1a;工业自动化和智能化水平不断提高&#xff0c;人机界面&#xff08;Human-Machine Interface&#xff0c;简称HMI&a…