PyTorch导出JIT模型并用C++ API libtorch调用

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用 torch.jit.trace 的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录 demo 下运行上面的 export_jit_model.py ,会得到一个 JIT 模型件:resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录 demo 下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录 demo 下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在 CMakeLists.txt 文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

在这里插入图片描述

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看ImageNet 1K 类别名与索引的对应关系,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread() 默认读取为三通道BGR,需要进行B/R通道交换,这里采用 cv::cvtColor()实现。

  2. 图像尺寸需要调整到 224×224224\times 224224×224,通过 cv::resize() 实现。

  3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是np.transpose()操作,这里使用tensor.permut()实现,效果是一样的。

  4. 数据归一化,采用 tensor.div(255) 实现。

// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}

Step 6:构建运行验证

我们先来写一下 CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录 demo 下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件 resnet50 在工程目录 demo 下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532594.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sqli-lab--writeup(7~10)文件输出,时间布尔盲注

前置知识点&#xff1a; 1、outfile是将检索到的数据&#xff0c;保存到服务器的文件内&#xff1a; 格式&#xff1a;select * into outfile “文件地址” 示例&#xff1a; mysql> select * into outfile ‘f:/mysql/test/one’ from teacher_class; 2、文件是自动创建…

sqli-lab———writeup(11~17)

less11 用户名提交单引号显示sql语法错误&#xff0c;故存在sql注入 根据单引号报错&#xff0c;在用户名和密码任意行输入 万能密码&#xff1a;‘ or 11# 输入后username语句为&#xff1a;SELECT username, password FROM users WHERE username or 11; 双引号 password语…

深入理解Python中的全局解释锁GIL

深入理解Python中的全局解释锁GIL 转自&#xff1a;https://zhuanlan.zhihu.com/p/75780308 注&#xff1a;本文为蜗牛学院资深讲师卿淳俊老师原创&#xff0c;首发自公众号https://mp.weixin.qq.com/s/TBiqbSCsjIbNIk8ATky-tg&#xff0c;如需转载请私聊我处获得授权并注明出处…

sqli-lab————Writeup(18~20)各种头部注入

less18 基于错误的用户代理&#xff0c;头部POST注入 admin admin 登入成功&#xff08;进不去重置数据库&#xff09; 显示如下 有user agent参数&#xff0c;可能存在注入点 显示版本号&#xff1a; 爆库&#xff1a;User-Agent:and extractvalue(1,concat(0x7e,(select …

Python GIL

转自&#xff1a;https://blog.csdn.net/weixin_41594007/article/details/79485847 Python GIL 在进行GIL讲解之前&#xff0c;我们可以先回顾一下并行和并发的区别&#xff1a; 并行&#xff1a;多个CPU同时执行多个任务&#xff0c;就好像有两个程序&#xff0c;这两个程序…

sqli-lab——Writeup21~38(各种过滤绕过WAF和)

Less-21 Cookie Injection- Error Based- complex - string ( 基于错误的复杂的字符型Cookie注入) base64编码&#xff0c;单引号&#xff0c;报错型&#xff0c;cookie型注入。 本关和less-20相似&#xff0c;只是cookie的uname值经过base64编码了。 登录后页面&#xff1a;…

sqli-lab——Writeup(38~over)堆叠等......

知识点&#xff1a; 1.堆叠注入原理&#xff08;stacked injection&#xff09; 在SQL中&#xff0c;分号&#xff08;;&#xff09;是用来表示一条sql语句的结束。试想一下我们在 ; 结束一个sql语句后继续构造下一条语句&#xff0c;会不会一起执行&#xff1f;因此这个想法…

mysql常规使用(建立,增删改查,视图索引)

目录 1.数据库建立 2.增删改查 3.视图建立&#xff1a; 1.数据库建立 mysql> mysql> show databases; ----------------------------------- | Database | ----------------------------------- | information_schema | | ch…

Xctf练习sql注入--supersqli

三种方法 方法一 1 回显正常 1’回显不正常,报sql语法错误 1’ -- 回显正常,说明有sql注入点,应该是字符型注入(# 不能用) 1’ order by 3 -- 回显失败,说明有2个注入点 1’ union select 1,2 -- 回显显示过滤语句: 1’; show databases -- 爆数据库名 -1’; show tables …

深拷贝与浅拷贝、值语义与引用语义对象语义 ——以C++和Python为例

深拷贝与浅拷贝、值语义与引用语义/对象语义 ——以C和Python为例 值语义与引用语义&#xff08;对象语义&#xff09; 本小节参考自&#xff1a;https://www.cnblogs.com/Solstice/archive/2011/08/16/2141515.html 概念 在任何编程语言中&#xff0c;区分深浅拷贝的关键都…

一次打卡软件的实战渗透测试

直接打卡抓包, 发现有疑似企业网站,查ip直接显示以下页面 直接显示了后台安装界面…就很有意思 探针和phpinfo存在 尝试连接mysql失败 fofa扫描为阿里云服务器 找到公司官网使用nmap扫描,存在端口使用onethink 查询onethink OneThink是一个开源的内容管理框架&#xff0c;…

centos7ubuntu搭建Vulhub靶场(推荐Ubuntu)

这里写目录标题一.前言总结二.成功操作&#xff1a;三.出现报错&#xff1a;四.vulhub使用正文&#xff1a;一.前言总结二.成功操作&#xff1a;三.出现报错&#xff1a;四.vulhub使用看完点赞关注不迷路!!!! 后续继续更新优质安全内容!!!!!一.前言总结 二.成功操作&#xff1…

Yapi Mock 远程代码执行漏洞

跟风一波复现Yapi 漏洞描述&#xff1a; YApi接口管理平台远程代码执行0day漏洞&#xff0c;攻击者可通过平台注册用户添加接口&#xff0c;设置mock脚本从而执行任意代码。鉴于该漏洞目前处于0day漏洞利用状态&#xff0c;强烈建议客户尽快采取缓解措施以避免受此漏洞影响 …

CVE-2017-10271 WebLogic XMLDecoder反序列化漏洞

漏洞产生原因&#xff1a; CVE-2017-10271漏洞产生的原因大致是Weblogic的WLS Security组件对外提供webservice服务&#xff0c;其中使用了XMLDecoder来解析用户传入的XML数据&#xff0c;在解析的过程中出现反序列化漏洞&#xff0c;导致可执行任意命令。攻击者发送精心构造的…

树莓派摄像头 C++ OpenCV YoloV3 实现实时目标检测

树莓派摄像头 C OpenCV YoloV3 实现实时目标检测 本文将实现树莓派摄像头 C OpenCV YoloV3 实现实时目标检测&#xff0c;我们会先实现树莓派对视频文件的逐帧检测来验证算法流程&#xff0c;成功后&#xff0c;再接入摄像头进行实时目标检测。 先声明一下笔者的主要软硬件配…

【实战】记录一次服务器挖矿病毒处理

信息收集及kill&#xff1a; 查看监控显示长期CPU利用率超高&#xff0c;怀疑中了病毒 top 命令查看进程资源占用&#xff1a; netstat -lntupa 命令查看有无ip进行发包 netstat -antp 然而并没有找到对应的进程名 查看java进程和solr进程 ps aux &#xff1a;查看所有进程…

ag 搜索工具参数详解

ag 搜索工具参数详解 Ag 是类似ack&#xff0c; grep的工具&#xff0c;它来在文件中搜索相应关键字。 官方列出了几点选择它的理由&#xff1a; 它比ack还要快 &#xff08;和grep不在一个数量级上&#xff09;它会忽略.gitignore和.hgignore中的匹配文件如果有你想忽略的文…

CVE-2013-4547 文件名逻辑漏洞

搭建环境&#xff0c;访问 8080 端口 漏洞说明&#xff1a; Nginx&#xff1a; Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器&#xff0c;在BSD-like 协议下发行。其特点是占有内存少&#xff0c;并发能力强&#xf…

CVE-2017-7529Nginx越界读取缓存漏洞POC

漏洞影响 低危&#xff0c;造成信息泄露&#xff0c;暴露真实ip等 实验内容 漏洞原理 通过查看patch确定问题是由于对http header中range域处理不当造成&#xff0c;焦点在ngx_http_range_parse 函数中的循环&#xff1a; HTTP头部range域的内容大约为Range: bytes4096-81…

Linux命令行性能监控工具大全

Linux命令行性能监控工具大全 作者&#xff1a;Arnold Lu 原文&#xff1a;https://www.cnblogs.com/arnoldlu/p/9462221.html 关键词&#xff1a;top、perf、sar、ksar、mpstat、uptime、vmstat、pidstat、time、cpustat、munin、htop、glances、atop、nmon、pcp-gui、collect…