如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

如何使用C++调用Pytorch模型进行推理测试:使用libtorch库

目录

      • 如何使用C++调用Pytorch模型进行推理测试:使用libtorch库
        • 一、环境准备
          • 1,linux:以ubuntu 22.04系统为例
            • 1. 准备CUDA和CUDNN
            • 2. 准备C++环境
            • 3, 下载libtorch文件
            • 4, 编写测试libtorch是否安装成功
          • 2, windows: 以win10系统为例
            • 1, 准备CUDA和CUDNN
            • 2,准备C++编译环境
            • 3,下载安装libtorch
            • 4. 注意事项
          • 二、C++代码封装Pytorch模型测试:以resnet-18分类为例
          • 1, 安装opencv用于读取图像
          • 2,用python导出训练好的pytorch模型
          • 3,编写C++代码测试

一、环境准备
1,linux:以ubuntu 22.04系统为例
1. 准备CUDA和CUDNN

有两种方式配置cuda和cudnn,一种是在系统环境安装,可以参考:深度学习环境配置——ubuntu安装CUDA与CUDNN

还有一种是在conda虚拟环境使用cudatoolkit-dev包,具体可以参考:Installing-and-Test-PyTorch-C-API-on-Ubuntu-with-GPU-enabled

我选择的方式是在系统环境安装cuda12.1和cudnn8.9.2。

可使用如下命令查看是否安装成功:

NVCC -V
cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2

image-20240625103837610

2. 准备C++环境

安装gcc, cmake和GLIBC,用apt install即可

可使用如下命令是否查看是否安装成功:

gcc --version
cmake --version
ldd --version

image-20240625103749911

3, 下载libtorch文件

去pytoch官网https://pytorch.org/下载即可:

image-20240625103946244

可使用如下命令下载并解压:

wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcu121.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.3.1+cu121.zip

将libtorch路径配置到path变量:

vim ~/.bashrc

最后一行加入:

export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意将/path/to/libtorch替换为实际的path,我这里是/mnt/data1/zq/libtorch

查看是否成功:

source ~/.bashrc
echo $LD_LIBRARY_PATH

image-20240625110447696

4, 编写测试libtorch是否安装成功

创建main.cpp文件,内容如下:

#include <torch/torch.h>
#include <iostream>int main() {if (torch::cuda::is_available()) {std::cout << "CUDA is available! Running on GPU." << std::endl;// 创建一个随机张量并将其移到GPU上torch::Tensor tensor_gpu = torch::rand({2, 3}).cuda();std::cout << "Tensor on GPU:\n" << tensor_gpu << std::endl;} else {std::cout << "CUDA not available! Running on CPU." << std::endl;// 创建一个随机张量并保持在CPU上torch::Tensor tensor_cpu = torch::rand({2, 3});std::cout << "Tensor on CPU:\n" << tensor_cpu << std::endl;}return 0;
}

编译和运行

创建CMakeLists.txt文件,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(test_project)# Setting the C++ standard to C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)# If additional compiler flags are needed
add_compile_options(-Wall -Wextra -pedantic)# Setting the location of LibTorch
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)# Specify the name of the executable and the corresponding source file
add_executable(test_project main.cpp)# Linking LibTorch libraries
target_link_libraries(test_project "${TORCH_LIBRARIES}")# Set the output directory for the executable
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)

/path/to/libtorch替换为实际的path

编译并测试:

mkdir build
cd build
cmake ..
make 

编译完成之后,应该会出现一个bin目录,其中有一个test_project文件,直接运行即可看到输出。

image-20240625111448917

出现CUDAFloatType说明,libtorch的GPU版本安装成功。

2, windows: 以win10系统为例
1, 准备CUDA和CUDNN

可参考:Windows10下CUDA与cuDNN的安装

2,准备C++编译环境

这一步需要配置cmake, mingw。可参考:Windows 配置 C/C++ 开发环境

建议直接安装Visual Studio这个IDE,可参考:Windows libtorch C++部署GPU版

3,下载安装libtorch

参考这个视频:

win10系统上LibTorch的安装和使用(cuda10.1版本)

一个很水的LibTorch教程(1)

4. 注意事项

windows环境我没有做测试,不保证一定可以成功。linux环境是亲自测试的,保证可以复现

二、C++代码封装Pytorch模型测试:以resnet-18分类为例
1, 安装opencv用于读取图像

需要使用opencv来读取图像数据,可通过如下命令安装:

sudo apt install libopencv-dev
dpkg -l | grep libopencv # 查看是否安装成功
2,用python导出训练好的pytorch模型

在将PyTorch模型应用于C++环境之前,需要将其转换为TorchScript。这可以通过两种方式实现:tracingscripting。可以通过如下代码导出训练好的ResNet-18模型:

import torch
import torchvision# 加载预训练的模型
model = torchvision.models.resnet18(pretrained=True)# 将模型设置为评估模式
model.eval()# 创建一个示例输入
example_input = torch.rand(1, 3, 224, 224)  # 模型输入的大小# 使用tracing导出模型
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet18.pt")
3,编写C++代码测试

创建main.cpp文件,内容如下:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <filesystem>// Function to transform image to tensor
torch::Tensor transform_image(const cv::Mat& image) {cv::Mat img_transformed;cv::cvtColor(image, img_transformed, cv::COLOR_BGR2RGB);cv::resize(img_transformed, img_transformed, cv::Size(224, 224));img_transformed.convertTo(img_transformed, CV_32FC3, 1.0/255);auto img_tensor = torch::from_blob(img_transformed.data, {img_transformed.rows, img_transformed.cols, 3}, torch::kFloat);img_tensor = img_tensor.permute({2, 0, 1});img_tensor = torch::data::transforms::Normalize<torch::Tensor>({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(img_tensor);img_tensor = img_tensor.unsqueeze(0);return img_tensor;
}// Load the model and classify an image
void classify_image(const std::string& model_path, const std::string& image_path) {// Load the modeltorch::jit::script::Module model = torch::jit::load(model_path);model.eval(); // Switch to evaluation mode// Load and transform the imagecv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);if (image.empty()) {std::cerr << "Could not read the image: " << image_path << std::endl;return;}torch::Tensor tensor_image = transform_image(image);// Perform inferencetorch::Tensor output = model.forward({tensor_image}).toTensor();int64_t pred = output.argmax(1).item<int64_t>();std::cout << "The image is classified as class index: " << pred << std::endl;
}int main(int argc, char* argv[]) {std::string model_path = "resnet18.pt"; // Default model pathstd::string image_path = "default_image.jpg"; // Default image path// 从命令行接受两个参数, 分别作为model_path和image_pathif (argc >= 3) {model_path = argv[1];image_path = argv[2];} else {std::cout << "Using default model and image paths." << std::endl;}classify_image(model_path, image_path);return 0;
}

创建CMakeLists.txt,内容如下:

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(ImageClassification)set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)# 设置LibTorch的位置, /path/to/libtorch替换为实际路径
set(Torch_DIR "/path/to/libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)find_package(OpenCV REQUIRED)add_executable(ImageClassification main.cpp)
target_link_libraries(ImageClassification "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

编译并运行:

mkdir build && cd build
cmake ..
make

在build目录下会出现ImageClassification这个可执行文件,直接运行传入model_path和image_path即可。

image-20240625114911739

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/40137.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

期权学习必看圣书:《3小时快学期权》要在哪里看?

今天带你了解期权学习必看圣书&#xff1a;《3小时快学期权》要在哪里看&#xff1f;《3小时快学期权》是一本关于股票期权基础知识的书籍。 它旨在通过简明、易懂的语言和实用的案例&#xff0c;让读者在短时间内掌握股票期权的基本概念、操作方法和投资策略。通过这本书&…

Linux系统(CentOS)安装Mysql5.7.x

安装准备&#xff1a; Linux系统(CentOS)添加防火墙、iptables的安装和配置 请访问地址&#xff1a;https://blog.csdn.net/esqabc/article/details/140209894 1&#xff0c;下载mysql安装文件&#xff08;mysql-5.7.44为例&#xff09; 选择Linux通用版本64位&#xff08;L…

算力互联网网络架构;SRV6;智享WAN

目录 算力互联网网络架构 SRV6 主要特点 应用场景 结论 G-SRV6 多层次网络切片 智享WAN 一、定义与背景 二、关键技术 三、应用场景与优势 四、发展现状与未来展望 智能算力网络成为智能经济时代代表性数字基础设施 算力互联网网络架构 为构建算力互联网这个前瞻性…

后端之路——阿里云OSS云存储

一、何为阿里云OSS 全名叫“阿里云对象存储OSS”&#xff0c;就是云存储&#xff0c;前端发文件到服务器&#xff0c;服务器不用再存到本地磁盘&#xff0c;可以直接传给“阿里云OSS”&#xff0c;存在网上。 二、怎么用 大体逻辑&#xff1a; 细分的话就是&#xff1a; 1、准…

JavaSE (Java基础):面向对象(下)

8.7 多态 什么是多态&#xff1f; 即同一方法可以根据发送对象的不同而采用多种不同的方式。 一个对象的实际类型是确定的&#xff0c;但可以指向对象的引用的类型有很多。在句话我是这样理解的&#xff1a; 在实例中使用方法都是根据他最开始将类实例化最左边的类型来定的&…

消息中间件ApacheKafka在windows简单安装

一.背景 之前公司需要API网关管理软件ApacheShenYu&#xff0c;我相信把调用的记录都存到一个数据库。他支持日志推送到kafka&#xff0c;所以&#xff0c;我准备尝试一下通过kafka接收调用的日志信息。第一步&#xff0c;当然是安装kafka了。 二.ApacheKafka的下载 打开下载…

【C++】 解决 C++ 语言报错:Memory Leak

文章目录 引言 内存泄漏&#xff08;Memory Leak&#xff09;是 C 编程中常见且严重的内存管理问题之一。当程序分配了内存而没有正确释放&#xff0c;导致内存无法被重新利用时&#xff0c;就会发生内存泄漏。这种错误会导致程序占用越来越多的内存&#xff0c;最终可能导致系…

论文学习——动态多目标优化的一种新的分位数引导的对偶预测策略

论文题目&#xff1a;A novel quantile-guided dual prediction strategies for dynamic multi-objective optimization 动态多目标优化的一种新的分位数引导的对偶预测策略&#xff08;Hao Sun a,b, Anran Cao a,b, Ziyu Hu a,b, Xiaxia Li a,b, Zhiwei Zhao c&#xff09;In…

“免费”的可视化大屏案例分享-智慧园区综合管理平台

一.智慧园区是什么&#xff1f; 智慧园区是一种融合了新一代信息与通信技术的先进园区发展理念。它通过迅捷信息采集、高速信息传输、高度集中计算、智能事务处理和无所不在的服务提供能力&#xff0c;实现了园区内及时、互动、整合的信息感知、传递和处理。这样的园区旨在提高…

正确使用Pytorch Geometric打开Cora(Planetoid)数据集

文章目录 关于报错&#xff08;"Cannot connect to host"&#xff09;解决方法 关于报错&#xff08;“Cannot connect to host”&#xff09; 我们在使用PyG调用Planetoid数据集的时候&#xff0c;常会碰到如下报错&#xff1a; 解决方法就是手动下载这个数据集。…

前端播放RTSP视频流,使用FLV请求RTSP视频流播放(Vue项目,在Vue中使用插件flv.js请求RTSP视频流播放)

简述&#xff1a;在浏览器中请求 RTSP 视频流并进行播放时&#xff0c;直接使用原生的浏览器 API 是行不通的&#xff0c;因为它们不支持 RTSP 协议。为了解决这个问题&#xff0c;开发者通常会选择使用像 flv.js 这样的库&#xff0c;它专为在浏览器中播放 FLV 和其他流媒体格…

MySQL 代理层:ProxySQL

文章目录 说明安装部署1.1 yum 安装1.2 启停管理1.3 查询版本1.4 Admin 管理接口 入门体验功能介绍3.1 多层次配置系统 读写分离将实例接入到代理服务定义主机组之间的复制关系配置路由规则事务读的配置延迟阈值和请求转发 ProxySQL 核心表mysql_usersmysql_serversmysql_repli…

Java实现日志全链路追踪.精确到一次请求的全部流程

广大程序员在排除线上问题时,会经常遇见各种BUG.处理这些BUG的时候日志就格外的重要.只有完善的日志才能快速有效的定位问题.为了提高BUG处理效率.我决定在日志上面优化.实现每次请求有统一的id.通过id能获取当前接口的全链路流程走向. 实现效果如下: 一次查询即可找到所有关…

自定义一个背景图片的高度,随着容器高度的变化而变化,小于图片的高度时裁剪,大于时拉伸100%展示

1、通过js创建<image?>标签来获取背景图片的宽高比&#xff1b; 2、当元素的高度大于原有比例计算出来的高度时&#xff0c;背景图片的高度拉伸自适应100%&#xff0c;否则高度为auto&#xff0c;会自动被裁减 3、背景图片容器高度变化时&#xff0c;自动计算背景图片的…

Android network - NUD检测机制(Android 14)

Android network - NUD检测机制 1. 前言2. 源码分析2.1 ClientModeImpl2.2 IpClient2.3 IpReachabilityMonitor 1. 前言 在Android系统中&#xff0c;NUD&#xff08;Neighbor Unreachable Detection&#xff09;指的是网络中的邻居不可达检测机制&#xff0c;它用于检测设备是…

一文了解常见DNS问题

当企业的DNS出现故障时&#xff0c;为不影响企业的正常运行&#xff0c;团队需要能够快速确定问题的性质和范围。那么有哪些常见的DNS问题呢&#xff1f; 域名解析失败&#xff1a; 当您输入一个域名&#xff0c;但无法获取到与之对应的IP地址&#xff0c;导致无法访问相应的网…

获取VC账号,是成为亚马逊供应商的全面准备与必要条件

成为亚马逊的供应商&#xff0c;拥有VC&#xff08;Vendor Central&#xff09;账号&#xff0c;是众多制造商和品牌所有者的共同目标。这不仅代表了亚马逊对供应商的高度认可&#xff0c;也意味着获得了更多的销售机会和更广阔的市场前景。 全面准备与必要条件是获取VC账号的关…

代码转换成AST语法树移除无用代码console.log、import

公司中代码存在大量,因此产生 可以使用 @babel/parser 解析代码生成 AST (抽象语法树),然后使用 @babel/traverse 进行遍历并删除所有的 console.log 语句,最后使用 @babel/generator 生成修改后的代码。 这里有一个网址,可以线上解析代码转换成AST语法树: https://astex…

Python爬虫康复训练——笔趣阁《神魂至尊》

还是话不多说&#xff0c;很久没写爬虫了&#xff0c;来个bs4康复训练爬虫&#xff0c;正好我最近在看《神魂至尊》&#xff0c;爬个txt文件下来看看 直接上代码 """ 神魂至尊网址-https://www.bqgui.cc/book/1519/ """ import requests from b…

【C++】 解决 C++ 语言报错:未定义行为(Undefined Behavior)

文章目录 引言 未定义行为&#xff08;Undefined Behavior, UB&#xff09;是 C 编程中非常危险且难以调试的错误之一。未定义行为发生时&#xff0c;程序可能表现出不可预测的行为&#xff0c;导致程序崩溃、安全漏洞甚至硬件损坏。本文将深入探讨未定义行为的成因、检测方法…