Pytorch的C++接口实践

Pytorch1.1版本已经提供了相对稳定的c++接口,网上也有了众多的资料供大家参考,进行c++的接口的初步尝试。

可以按照对应的选项下载,下面我们要说的是:

如何利用已经编译好的官方libtorch库和其他的opencv库等联合编写应用?

其实很简单,大概的步骤有三步:

第一步:在python环境下将模型导出为jit的模型

第二步:编写对应的c++ inference 程序。

第三步:直接在VS上(已经成功实验VS2015,高版本的应该也可以)配置相应的libtorch环境,主要是:

dll路径: 

PATH=H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib%3bD:\opencv\build\x64\vc14\bin%3b$(PATH)  相应地去修改即可,不需要在PC的path环境下加入libtorch的路径,而是在这里加更加简单。

include路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include\torch\csrc\api\include;H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include;D:\opencv\build\include\opencv2;D:\opencv\build\include\opencv;D:\opencv\build\include;%(AdditionalIncludeDirectories)

主要是加粗线那两个。

注意一定要去掉SDL的检查项,否则会出现错误警告。

lib路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib;D:\opencv\build\x64\vc14\lib;%(AdditionalLibraryDirectories)

详细的工程见:https://download.csdn.net/download/xiamentingtao/11486608

这里我们主要改编自:《Win10+VS2017+PyTorch(libtorch) C++ 基本应用》

主要代码参考: https://github.com/zhpmatrix/load-pytorch-model-with-c-

一些 常见的问题:

1. opencv的mat读入libtorch

根据我的实践,这里的最佳写法是:

src = imread(s, cv::IMREAD_COLOR);  //读图// 图像预处理 注意需要和python训练时的预处理一致
int org_w = src.cols;
int org_h = src.rows;torch::Tensor img_tensor = torch::from_blob(src.data, { org_h, org_w,3 }, torch::kByte); //将cv::Mat转成tensor,大小为448,448,3
img_tensor = img_tensor.permute({ 2, 0, 1 });  //调换顺序变为torch输入的格式 3,448,448
img_tensor = img_tensor.toType(torch::kFloat32).div_(255);

注意要先将uint8的图像先读入,再转换成float型。

2. Tensor 转换成cv::Mat

cv::Mat input(img_tensor.size(1), img_tensor.size(2), CV_32FC1, img_tensor.data<float>());

注意这里一定是CV_32FC1而不是CV_32FC3

另外的方式见:https://discuss.pytorch.org/t/convert-torch-tensor-to-cv-mat/42751/2

torch::Tensor out_tensor = module->forward(inputs).toTensor();
assert(out_tensor.device().type() == torch::kCUDA);
out_tensor=out_tensor.squeeze().detach().permute({1,2,0});
out_tensor=out_tensor.mul(255).clamp(0,255).to(torch::kU8);
out_tensor=out_tensor.to(torch::kCPU);
cv::Mat resultImg(512, 512,CV_8UC3);
std::memcpy((void*)resultImg.data,out_tensor.data_ptr(),sizeof(torch::kU8)*out_tensor.numel());

3. model的输出处理

如果只有一个返回值,可以直接转tensor:auto outputs = module->forward(inputs).toTensor();如果有多个返回值,需要先转tuple:auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();

4.Tracing fails because of “parameter sharing”?

看这个案例:https://discuss.pytorch.org/t/help-tracing-fails-because-of-parameter-sharing/40324

其中的部分代码如上,问题就出现在这些画框的地方,主要是这里初始化重复使用了相同的模块进行赋值,例如self.encoder与self.conv1。

解决的办法就是在构造slef.conv1时,对self.encoder[0]加入deepcopy修饰。

即:

from copy import deepcopy
self.conv1 = nn.Sequential(deepcopy(self.encoder[0]),deepcopy(self.relu),deepcopy(self.encoder[2]),deepcopy(self.relu))

参考:https://github.com/pytorch/pytorch/issues/8392#issuecomment-431863763

5. 关于python导出模型的问题

如果训练的pytorch模型保存在cpu上,想在测试时使用gpu模式,则我们需要设置python端保存模型在gpu上,然后才能c++上使用gpu测试。

主要的方法就是:

    checkpoint = torch.load(model_path, map_location="cuda:0")  #very important# create modelmodel = TheModelClass(*args, **kwargs)model.load_state_dict(checkpoint)model.to(device)model.eval()x = torch.rand(1, 3, 448, 448)x = x.to(device)  # very importanttraced_script_module = torch.jit.trace(model.model, x)traced_script_module.save("**.pt")

然后才能在c++上使用gpu模式,方法为:

    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);module->to(at::kCUDA);assert(module != nullptr);std::cout << "ok\n";// 建立一个输入,维度为(1,3,224,224),并移动至cudastd::vector<torch::jit::IValue> inputs;inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));// Execute the model and turn its output into a tensor.at::Tensor output = module->forward(inputs).toTensor();

参考:

 

pytorch跨设备保存和加载模型(变量类型(cpu/gpu)不匹配原因之一)

https://pytorch.org/tutorials/beginner/saving_loading_models.html

https://blog.csdn.net/IAMoldpan/article/details/85057238

参考文献:

1.利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

2.Pytorch的C++端(libtorch)在Windows中的使用

3. https://pytorch.org/tutorials/advanced/cpp_frontend.html

4. https://zhpmatrix.github.io/2019/03/01/c++-with-pytorch/

5. Windows使用C++调用Pytorch1.0模型

6. 用cmake构建基于qt5,opencv,libtorch项目

7. c++调用pytorch模型并使用GPU进行预测 (较好的例子)

8. Ptorch 与libTorch 使用过程中问题记录

9. c++ load pytorch 的数据转换

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/258502.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一次惨痛的装机经历

最近不小心把我的联想一体机电脑系统搞坏了&#xff0c;就不得不重装系统&#xff0c;之前的系统是win7&#xff0c;于是开始的时候想着直接装win10&#xff0c;升级一下系统。但是装的过程中总是卡在了win10的正在准备系统中&#xff0c;进度环不转了。后来转了多次都不行&…

unity让对象作为参数_unity-container – 一个unity容器可以将自身的引用作为构造函数参数传递吗?...

简短的答案是肯定的。当您使用Resolve方法时&#xff0c;这应该自动传递。例如&#xff1a;IUnityContainer container new UnityContainer();var something container.Resolve();另外&#xff0c;如果您想查看&#xff0c;这与Prism(CodePlex)使用的技术相同。更新增加测试&…

KnockoutJS + My97DatePicker

如何将Knockoutjs和其他脚本库结合使用&#xff1f;这里给出一个Knockoutjs与my97datepicker配合使用的例子&#xff0c;例子中使用了ko的自定义绑定功能&#xff1a; ko.bindingHandlers.my97DatePicker {init: function (element, valueAccessor) {$(element).on(click, fun…

HttpClient v4.5 简单抓取主页数据

由于工作原因&#xff0c;需要每隔半小时刷新一些网页&#xff0c;并查看上面的数据是否有更新。这件事能否自动化进行呢&#xff1f;查找了下Java相关的资料&#xff0c;蹦出一个关键词&#xff1a;HttpClient。 HttpClient是常用Http客户端库&#xff0c;相关的资料也不少&am…

matlab局部放大的图中图画法

【亲测有效】 在作图过程中&#xff0c;如果想将局部信息展示出来并且画在同一张图中&#xff0c;一般的MATLAB作图法就比较拙计了&#xff0c;好在MATLAB还是很强大的&#xff0c;当然&#xff0c;除了不能当女朋友之外 .... ╮(╯▽╰)╭ function showdetail()% 在当前的ax…

进入Python世界——Python基础知识

本文通过实例练习Python基础语法, python版本2.7 # -*- coding: utf-8 -*- import randomimport re import requests from bs4 import BeautifulSoup# 爬取糗事百科的首页内容 def qiushibaike():content requests.get(http://www.qiushibaike.com/).contentsoup BeautifulS…

db2 版本发布历史_数据库各厂商的发展历史(2. DB2 of IBM)

如若转载&#xff0c;请务必注明出处&#xff0c;iihero 2008.9.26于CSDN1973年&#xff0c;IBM研究中心启动System R项目&#xff0c;为DB2的诞生打下良好基础。System R 是 IBM 研究部门开发的一种产品&#xff0c;这种原型语言促进了技术的发展并最终在1983年将 DB2 带到了商…

android---简单的通讯录

遗留问题:获取头像及其他信息 利用adapter和Cursor来获取联系人的姓名和手机号,重在复习之前学过的内容加深自己的理解. 其中需要注意的部分: 1.adapter中的getview的优化问题,用到tag这一属性 2.onBackPressed()返回方法的重写,使得程序更加人性化 下面是主要代码 1.adapte…

win phone 获取并且处理回车键事件

参考自&#xff1a;http://www.cnblogs.com/mohe/archive/2013/03/18/2966540.html 实用场景,比如输入帐号和密码啦,输入搜索关键字啦.protected override void OnKeyDown(KeyEventArgs e) {if (e.Key Key.Enter){MessageBox.Show("我是windows phone 回车键"); …

【2020年】最新中国科学院大学学位论文写作规范

最近在完成国科大博士论文写作的时候&#xff0c;有一些心得体会&#xff0c;特此总结下来&#xff0c;以飨读者&#xff0c;尤其是可爱的学弟学妹们。需要注意的是&#xff0c; 以下仅仅是我自己的心得而已&#xff0c;仅供参考。 1. 首先推荐大家使用国科大的Latex模板&…

谈谈Java基础数据类型

Java的基本数据类型 类型意义取值boolean布尔值true或falsebyte8位有符号整型-128~127short16位有符号整型-pow(2,15)~pow(2,15)-1int32位有符号整型-pow(2,31)~pow(2,31)-1long64位有符号整型-pow(2,63)~pow(2,63)-1float32位浮点数IEEE754标准单精度浮点数double64位浮点数IE…

用fft对信号进行频谱分析实验报告_示波器上的频域分析利器,Spectrum View测试分析...

简介&#xff1a;【Spectrum View技术文章系列】从基础篇开始&#xff0c;讲述利用示波器上的Spectrum View功能观测多通道信号频谱分析正文&#xff1a;示波器和频谱仪都是电子测试测量中必不可少的测试设备&#xff0c;分别用于观察信号的时域波形和频谱。时域波形是信号最原…

DataTable RowFilter 过滤数据

用Rowfilter加入过滤条件 eg&#xff1a; string sql "select Name,Age,Sex from UserInfo"; DataTable dt DataAccess.GetDataTable(sql);//外部方法&#xff08;通过一条查询语句返回一个DataTable&#xff09; dt.DefaultView.RowFilter "Sex女"; dt…

platform_device与platform_driver

做Linux方面也有三个多月了&#xff0c;对代码中的有些结构一直不是非常明确&#xff0c;比方platform_device与platform_driver一直分不清关系。在网上搜了下&#xff0c;做个总结。两者的工作顺序是先定义platform_device -> 注冊 platform_device->&#xff0c;再定义…

复盘caffe安装

最近因之前的服务器上的caffe奔溃了&#xff0c;不得已重新安装这一古老的深度学习框架&#xff0c;之前也尝试了好几次&#xff0c;每次都失败&#xff0c;这次总算是成功了&#xff0c;因此及时地总结一下。 以下安装的caffe主要是针对之前虹膜分割和巩膜分割所需的caffe版本…

HP P2000 RAID-5两块盘离线的数据恢复报告

1. 故障描述本案例是HP P2000的存储vmware exsi虚拟化平台&#xff0c;由RAID-5由10块lT硬盘组成&#xff0c;其中6号盘是热备盘&#xff0c;由于故障导致RAID-5磁盘阵列的两块盘掉线&#xff0c;表现为两块硬盘亮黄灯。 经用户方维护人员检测&#xff0c;故障硬盘应为物理故障…

微智魔盒骗局_微智魔盒官宣

原标题&#xff1a;微智魔盒官宣微智魔盒官方宣传视频微达国际集团创建于2011年&#xff0c;是一家坚持创新的集科研、产销、服务为一体的智能化产业平台&#xff0c;致力于国际领先的专注人工智能领域的产业投资、项目孵化、教育培训&#xff0c;并提供终极解决方案。集团创新…

瑞柏匡丞_移动互联的发展现状与未来

互联网作为人类文明史上最伟大、最重要的科技发明之一&#xff0c;发展到今天&#xff0c;用翻天覆地来形容并不过分。而作为传统互联网的延伸和演进方向&#xff0c;移动互联网更是在近两年得到了迅猛的发展。如今&#xff0c;越来越多的用户得以通过高速的移动网络和强大的智…

android 进程间通信数据(一)------parcel的起源

关于parcel&#xff0c;我们先来讲讲它的“父辈” Serialize。 Serialize 是java提供的一套序列化机制。但是为什么要序列化&#xff0c;怎么序列化&#xff0c;序列化是怎么做到的&#xff0c;我们将在本文探讨下。 一&#xff1a;java 中的serialize 关于Serialize这个东东&a…

为什么torch.nn.Linear的表达形式为y=xA^T+b而不是常见的y=Ax+b?

今天看代码&#xff0c;对比了常见的公式表达与代码的表达&#xff0c;发觉torch.nn.Linear的数学表达与我想象的有点不同&#xff0c;于是思索了一番。 众多周知&#xff0c;torch.nn.Linear作为全连接层&#xff0c;将下一层的每个结点与上一层的每一节点相连&#xff0c;用…