Pytorch的C++接口实践

Pytorch1.1版本已经提供了相对稳定的c++接口,网上也有了众多的资料供大家参考,进行c++的接口的初步尝试。

可以按照对应的选项下载,下面我们要说的是:

如何利用已经编译好的官方libtorch库和其他的opencv库等联合编写应用?

其实很简单,大概的步骤有三步:

第一步:在python环境下将模型导出为jit的模型

第二步:编写对应的c++ inference 程序。

第三步:直接在VS上(已经成功实验VS2015,高版本的应该也可以)配置相应的libtorch环境,主要是:

dll路径: 

PATH=H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib%3bD:\opencv\build\x64\vc14\bin%3b$(PATH)  相应地去修改即可,不需要在PC的path环境下加入libtorch的路径,而是在这里加更加简单。

include路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include\torch\csrc\api\include;H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include;D:\opencv\build\include\opencv2;D:\opencv\build\include\opencv;D:\opencv\build\include;%(AdditionalIncludeDirectories)

主要是加粗线那两个。

注意一定要去掉SDL的检查项,否则会出现错误警告。

lib路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib;D:\opencv\build\x64\vc14\lib;%(AdditionalLibraryDirectories)

详细的工程见:https://download.csdn.net/download/xiamentingtao/11486608

这里我们主要改编自:《Win10+VS2017+PyTorch(libtorch) C++ 基本应用》

主要代码参考: https://github.com/zhpmatrix/load-pytorch-model-with-c-

一些 常见的问题:

1. opencv的mat读入libtorch

根据我的实践,这里的最佳写法是:

src = imread(s, cv::IMREAD_COLOR);  //读图// 图像预处理 注意需要和python训练时的预处理一致
int org_w = src.cols;
int org_h = src.rows;torch::Tensor img_tensor = torch::from_blob(src.data, { org_h, org_w,3 }, torch::kByte); //将cv::Mat转成tensor,大小为448,448,3
img_tensor = img_tensor.permute({ 2, 0, 1 });  //调换顺序变为torch输入的格式 3,448,448
img_tensor = img_tensor.toType(torch::kFloat32).div_(255);

注意要先将uint8的图像先读入,再转换成float型。

2. Tensor 转换成cv::Mat

cv::Mat input(img_tensor.size(1), img_tensor.size(2), CV_32FC1, img_tensor.data<float>());

注意这里一定是CV_32FC1而不是CV_32FC3

另外的方式见:https://discuss.pytorch.org/t/convert-torch-tensor-to-cv-mat/42751/2

torch::Tensor out_tensor = module->forward(inputs).toTensor();
assert(out_tensor.device().type() == torch::kCUDA);
out_tensor=out_tensor.squeeze().detach().permute({1,2,0});
out_tensor=out_tensor.mul(255).clamp(0,255).to(torch::kU8);
out_tensor=out_tensor.to(torch::kCPU);
cv::Mat resultImg(512, 512,CV_8UC3);
std::memcpy((void*)resultImg.data,out_tensor.data_ptr(),sizeof(torch::kU8)*out_tensor.numel());

3. model的输出处理

如果只有一个返回值,可以直接转tensor:auto outputs = module->forward(inputs).toTensor();如果有多个返回值,需要先转tuple:auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();

4.Tracing fails because of “parameter sharing”?

看这个案例:https://discuss.pytorch.org/t/help-tracing-fails-because-of-parameter-sharing/40324

其中的部分代码如上,问题就出现在这些画框的地方,主要是这里初始化重复使用了相同的模块进行赋值,例如self.encoder与self.conv1。

解决的办法就是在构造slef.conv1时,对self.encoder[0]加入deepcopy修饰。

即:

from copy import deepcopy
self.conv1 = nn.Sequential(deepcopy(self.encoder[0]),deepcopy(self.relu),deepcopy(self.encoder[2]),deepcopy(self.relu))

参考:https://github.com/pytorch/pytorch/issues/8392#issuecomment-431863763

5. 关于python导出模型的问题

如果训练的pytorch模型保存在cpu上,想在测试时使用gpu模式,则我们需要设置python端保存模型在gpu上,然后才能c++上使用gpu测试。

主要的方法就是:

    checkpoint = torch.load(model_path, map_location="cuda:0")  #very important# create modelmodel = TheModelClass(*args, **kwargs)model.load_state_dict(checkpoint)model.to(device)model.eval()x = torch.rand(1, 3, 448, 448)x = x.to(device)  # very importanttraced_script_module = torch.jit.trace(model.model, x)traced_script_module.save("**.pt")

然后才能在c++上使用gpu模式,方法为:

    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);module->to(at::kCUDA);assert(module != nullptr);std::cout << "ok\n";// 建立一个输入,维度为(1,3,224,224),并移动至cudastd::vector<torch::jit::IValue> inputs;inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));// Execute the model and turn its output into a tensor.at::Tensor output = module->forward(inputs).toTensor();

参考:

 

pytorch跨设备保存和加载模型(变量类型(cpu/gpu)不匹配原因之一)

https://pytorch.org/tutorials/beginner/saving_loading_models.html

https://blog.csdn.net/IAMoldpan/article/details/85057238

参考文献:

1.利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

2.Pytorch的C++端(libtorch)在Windows中的使用

3. https://pytorch.org/tutorials/advanced/cpp_frontend.html

4. https://zhpmatrix.github.io/2019/03/01/c++-with-pytorch/

5. Windows使用C++调用Pytorch1.0模型

6. 用cmake构建基于qt5,opencv,libtorch项目

7. c++调用pytorch模型并使用GPU进行预测 (较好的例子)

8. Ptorch 与libTorch 使用过程中问题记录

9. c++ load pytorch 的数据转换

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/258502.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HttpClient v4.5 简单抓取主页数据

由于工作原因&#xff0c;需要每隔半小时刷新一些网页&#xff0c;并查看上面的数据是否有更新。这件事能否自动化进行呢&#xff1f;查找了下Java相关的资料&#xff0c;蹦出一个关键词&#xff1a;HttpClient。 HttpClient是常用Http客户端库&#xff0c;相关的资料也不少&am…

matlab局部放大的图中图画法

【亲测有效】 在作图过程中&#xff0c;如果想将局部信息展示出来并且画在同一张图中&#xff0c;一般的MATLAB作图法就比较拙计了&#xff0c;好在MATLAB还是很强大的&#xff0c;当然&#xff0c;除了不能当女朋友之外 .... ╮(╯▽╰)╭ function showdetail()% 在当前的ax…

【2020年】最新中国科学院大学学位论文写作规范

最近在完成国科大博士论文写作的时候&#xff0c;有一些心得体会&#xff0c;特此总结下来&#xff0c;以飨读者&#xff0c;尤其是可爱的学弟学妹们。需要注意的是&#xff0c; 以下仅仅是我自己的心得而已&#xff0c;仅供参考。 1. 首先推荐大家使用国科大的Latex模板&…

用fft对信号进行频谱分析实验报告_示波器上的频域分析利器,Spectrum View测试分析...

简介&#xff1a;【Spectrum View技术文章系列】从基础篇开始&#xff0c;讲述利用示波器上的Spectrum View功能观测多通道信号频谱分析正文&#xff1a;示波器和频谱仪都是电子测试测量中必不可少的测试设备&#xff0c;分别用于观察信号的时域波形和频谱。时域波形是信号最原…

复盘caffe安装

最近因之前的服务器上的caffe奔溃了&#xff0c;不得已重新安装这一古老的深度学习框架&#xff0c;之前也尝试了好几次&#xff0c;每次都失败&#xff0c;这次总算是成功了&#xff0c;因此及时地总结一下。 以下安装的caffe主要是针对之前虹膜分割和巩膜分割所需的caffe版本…

HP P2000 RAID-5两块盘离线的数据恢复报告

1. 故障描述本案例是HP P2000的存储vmware exsi虚拟化平台&#xff0c;由RAID-5由10块lT硬盘组成&#xff0c;其中6号盘是热备盘&#xff0c;由于故障导致RAID-5磁盘阵列的两块盘掉线&#xff0c;表现为两块硬盘亮黄灯。 经用户方维护人员检测&#xff0c;故障硬盘应为物理故障…

为什么torch.nn.Linear的表达形式为y=xA^T+b而不是常见的y=Ax+b?

今天看代码&#xff0c;对比了常见的公式表达与代码的表达&#xff0c;发觉torch.nn.Linear的数学表达与我想象的有点不同&#xff0c;于是思索了一番。 众多周知&#xff0c;torch.nn.Linear作为全连接层&#xff0c;将下一层的每个结点与上一层的每一节点相连&#xff0c;用…

Leetcode47: Palindrome Linked List

Given a singly linked list, determine if it is a palindrome. 推断一个链表是不是回文的&#xff0c;一个比較简单的办法是把链表每一个结点的值存在vector里。然后首尾比較。时间复杂度O(n)。空间复杂度O(n)。 /*** Definition for singly-linked list.* struct ListNode {…

【转】七个例子帮你更好地理解 CPU 缓存

我的大多数读者都知道缓存是一种快速、小型、存储最近已访问的内存的地方。这个描述相当准确&#xff0c;但是深入处理器缓存如何工作的“枯燥”细节&#xff0c;会对尝试理解程序性能有很大帮助。在这篇博文中&#xff0c;我将通过示例代码来说明缓存是如何工作的&#xff0c;…

win10 平台VS2019最简安装实现C++/C开发

这两天一直在安装vs2015,总是卡在visual studio 2015 出现安装包丢失或损坏的现象&#xff0c;尽管按照网上很多方法尝试解决&#xff0c;但是一直不行。算了。还是使用最新版的VS 2019安装&#xff0c;没想到很顺利。 下面总结一下在win10平台上最简安装VS2019&#xff0c;实…

Hook的两个小插曲

看完了前面三篇文章后&#xff0c;这里我们来一个小插曲~~~~ 第一个小插曲。是前面文章一个CM精灵的分析。我们这里使用hook代码来搞定。 第二个小插曲&#xff0c;是如今一些游戏&#xff0c;都有了支付上限&#xff0c;比如每天仅仅能花20块钱来购买。好了。以下我们分开叙述…

微信小程序和vue双向绑定哪里不一样_个人理解Vue和React区别

本文转载自掘金&#xff0c;作者&#xff1a;binbinsilk&#xff0c;监听数据变化的实现原理不同Vue 通过 getter/setter 以及一些函数的劫持&#xff0c;能精确知道数据变化&#xff0c;不需要特别的优化就能达到很好的性能React 默认是通过比较引用的方式进行的&#xff0c;如…

JS 省,市,区

1 // 纯JS省市区三级联动2 // 2011-11-30 by http://www.cnblogs.com/zjfree3 var addressInit function (_cmbProvince, _cmbCity, _cmbArea, defaultProvince, defaultCity, defaultArea) {4 var cmbProvince document.getElementById(_cmbProvince);5 var cmbCity…

使用极链/AutoDL云服务器复盘caffe安装

继上一次倒腾caffe安装以后&#xff0c;因为博士毕业等原因&#xff0c;旧的服务器已经不能再使用&#xff0c;最近因论文等原因&#xff0c;不得不继续来安装一下我的caffe。这次运气比较好&#xff0c;经历了一晚上和一早上的痛苦之后&#xff0c;最终安装成功了&#xff0c;…

Samba服务

####################samba####################1.samba作用提供cifs协议实现共享文件2.安装yum install samba samba-common samba-client -ysystemctl start smb nmbsystemctl enable smb nmb3.添加smb用户smb用户必须是本机用户[rootlocalhost ~]# smbpasswd -a student New…

CodeForces 543D 树形DP Road Improvement

题意&#xff1a; 有一颗树&#xff0c;每条边是好边或者是坏边&#xff0c;对于一个节点为x&#xff0c;如果任意一个点到x的路径上的坏边不超过1条&#xff0c;那么这样的方案是合法的&#xff0c;求所有合法的方案数。 对于n个所有可能的x&#xff0c;输出n个答案。 分析&am…

理解Javascritp中的引用

Author: bugall Wechat: bugallF Email: 769088641qq.com Github: https://github.com/bugall一&#xff1a; 函数中的引用传递 我们看下下面的代码的正确输出是什么 function changeStuff(a, b, c) {a a * 10;b.item "changed";c {item: "changed"}; …

ONOS系统架构演进,实现高可用性解决方案

上一篇文章《ONOS高可用性和可扩展性实现初探》讲到了ONOS系统架构在高可用、可扩展方面技术概况&#xff0c;提到了系统在分布式集群中怎样保证数据的一致性。在数据终于一致性方面&#xff0c;ONOS採用了Gossip协议。这一部分的变化不大&#xff0c;而在强一致性方案的选择方…

Struts2_day01

Java Web开发常用框架 SSH(Struts2 Spring Hibernate)SSM(Struts2 Spring MyBatis)SSI(Struts2 Spring iBatis) 多种框架协同工作 Web层 -- Service层 -- Dao层 Struts2框架: Struts2是一个基于MVC设计模式的Web应用框架&#xff0c;它本质上相当于一个servlet&#xff0c;在MV…

使用 python 开发 Web Service

使用 python 开发 Web Service Python 是一种强大的面向对象脚本语言&#xff0c;用 python 开发应用程序往往十分快捷&#xff0c;非常适用于开发时间要求苛刻的原型产品。使用 python 开发 web service 同样有语言本身的简捷高速的特点&#xff0c;能使您快速地提供新的网络服…