PyTorch 添加 C++ 拓展

参考内容:pytorch添加C++拓展简单实战编写及基本功能测试

文章目录

  • 第一步:编写 C++ 模块
    • test.h
    • test.cpp
  • 第二步:编写 setup.py
  • 第三步:安装 C++ 模块
  • 第四步:验证安装
  • 第五步:C++ 模块使用
    • test_cpp1.py
    • test_cpp2.py
  • 运行结果
  • 扩展阅读

编译安装前的文件目录:

这里的 csrc 应该不是指 pytorch 项目中的 /torch/csrc

csrc
├─ cpu
│    ├─ test.cpp
│    └─ test.h
└─ setup.py

第一步:编写 C++ 模块

test.h

#include <torch/extension.h>
#include <vector>// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);

test.cpp

#include "test.h"// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y){AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y");torch::Tensor z = torch::zeros(x.sizes());z = 2 * x + y;return z;
}// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput){torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());return {gradOutputX, gradOutputY};
}// pybind11 绑定
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){m.def("forward", &Test_forward_cpu, "TEST forward");m.def("backward", &Test_backward_cpu, "TEST backward");
}

第二步:编写 setup.py

from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
# 源代码目录
source_cpu = glob.glob(os.path.join(include_dirs, 'cpu', '*.cpp'))setup(name='test_cpp', # 模块名称,需要在 python 中调用version="0.1",ext_modules=[CppExtension('test_cpp', sources=source_cpu, include_dirs=[include_dirs]),],cmdclass={'build_ext': BuildExtension}
)

第三步:安装 C++ 模块

在 csrc 文件夹下运行命令

python setup.py install

第一次尝试的报错信息:

/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` directly.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.********************************************************************************!!self.initialize_options()
/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` and ``easy_install``.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://github.com/pypa/setuptools/issues/917 for details.********************************************************************************!!self.initialize_options()

参考 SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip 后得知是 setuptools 版本太高,于是降低 setuptools 版本,pip install setuptools==58.2.0

第二次尝试的运行结果:

running install
running bdist_egg
running egg_info
writing test_cpp.egg-info/PKG-INFO
writing dependency_links to test_cpp.egg-info/dependency_links.txt
writing top-level names to test_cpp.egg-info/top_level.txt
reading manifest file 'test_cpp.egg-info/SOURCES.txt'
writing manifest file 'test_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'test_cpp' extension
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu
Emitting ninja build file /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] c++ -MMD -MF /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o.d -pthread -B /home/zjma/.conda/envs/debugtest/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/zjma/pytorch_v1.13.1/csrc -I/home/zjma/pytorch_v1.13.1/torch/include -I/home/zjma/pytorch_v1.13.1/torch/include/torch/csrc/api/include -I/home/zjma/pytorch_v1.13.1/torch/include/TH -I/home/zjma/pytorch_v1.13.1/torch/include/THC -I/home/zjma/.conda/envs/debugtest/include/python3.8 -c -c /home/zjma/pytorch_v1.13.1/csrc/cpu/test.cpp -o /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=test_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
cc1plus: warning: command-line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.8
g++ -pthread -shared -B /home/zjma/.conda/envs/debugtest/compiler_compat -L/home/zjma/.conda/envs/debugtest/lib -Wl,-rpath=/home/zjma/.conda/envs/debugtest/lib -Wl,--no-as-needed -Wl,--sysroot=/ /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -L/home/zjma/pytorch_v1.13.1/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for test_cpp.cpython-38-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/test_cpp.py to test_cpp.cpython-38.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.test_cpp.cpython-38: module references __file__
creating 'dist/test_cpp-0.1-py3.8-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing test_cpp-0.1-py3.8-linux-x86_64.egg
removing '/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg' (and everything under it)
creating /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Extracting test_cpp-0.1-py3.8-linux-x86_64.egg to /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages
test-cpp 0.1 is already the active version in easy-install.pthInstalled /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Processing dependencies for test-cpp==0.1
Finished processing dependencies for test-cpp==0.1

编译安装后的文件目录:

csrc
├─ build
│    ├─ bdist.linux-x86_64
│    ├─ lib.linux-x86_64-3.8
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ lib.linux-x86_64-cpython-38
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ temp.linux-x86_64-3.8
│    │    ├─ .ninja_deps
│    │    ├─ .ninja_log
│    │    ├─ build.ninja
│    │    └─ home
│    └─ temp.linux-x86_64-cpython-38
│           ├─ .ninja_deps
│           ├─ .ninja_log
│           ├─ build.ninja
│           └─ home
├─ cpu
│    ├─ test.cpp
│    └─ test.h
├─ dist
│    └─ test_cpp-0.1-py3.8-linux-x86_64.egg
├─ setup.py
└─ test_cpp.egg-info├─ PKG-INFO├─ SOURCES.txt├─ dependency_links.txt└─ top_level.txt

第四步:验证安装

1、在虚拟环境的路径 /lib/python3.8/site-packages 下看到 test_cpp-0.1-py3.8-linux-x86_64.egg 文件
在这里插入图片描述
2、conda list 查看当前虚拟环境下已经安装的包
在这里插入图片描述3、进入 python 的交互模式,import test_cpp 后报错:

>>> import test_cpp
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ImportError: libc10.so: cannot open shared object file: No such file or directory

参考 通过Python setup.py install的第三方包,import时却无法导入是什么问题呢? - 神经的网络里挣扎的回答 - 知乎,因为编译的 test_cpp 包需要依赖 torch 包,导致无法导入。所以,在 import test_cpp 前要先 import torch

第五步:C++ 模块使用

test_cpp1.py

import torch
import test_cpp
from torch.autograd import Functionclass TestFunction(Function):@staticmethoddef forward(ctx, x, y):return test_cpp.forward(x, y)@staticmethoddef backward(ctx, gradOutput):gradX, gradY = test_cpp.backward(gradOutput)return gradX, gradYclass Test(torch.nn.Module):def __init__(self):super(Test, self).__init__()def forward(self, inputA, inputB):return TestFunction.apply(inputA, inputB)

test_cpp2.py

import torch
from torch.autograd import Variablefrom test_cpp1 import Testx = Variable(torch.Tensor([1,2,3]), requires_grad=True)
y = Variable(torch.Tensor([4,5,6]), requires_grad=True)test = Test()
z = test(x, y)
z.sum().backward()print('x: ', x)
print('y: ', y)
print('z: ', z)
print('x.grad: ', x.grad)
print('y.grad: ', y.grad)

运行结果

/home/zjma/.conda/envs/debugtest/bin/python /home/zjma/PycharmProjects/pythonProject/test_cpp2.py 
x:  tensor([1., 2., 3.], requires_grad=True)
y:  tensor([4., 5., 6.], requires_grad=True)
z:  tensor([ 6.,  9., 12.], grad_fn=<TestFunctionBackward>)
x.grad:  tensor([2., 2., 2.])
y.grad:  tensor([1., 1., 1.])进程已结束,退出代码为 0

运行结果符合预期。

扩展阅读

  • pytorch之c++/cuda拓展(讲得很详细,举的例子和上文基本一样,但用到了CUDA,很多内容可以扩展去看)
  • 官方教程 相关内容的笔记(后面可以复现一下)
    • PyTorch进阶1:C++扩展
    • pytorch 的C++扩展

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/648513.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息安全认证首选CISP-PTE

&#x1f525;在信息安全领域&#xff0c;CISP-PTE认证正逐渐成为行业的新星。作为中国信息安全测评中心推出的专业认证&#xff0c;CISP-PTE为信息安全从业者提供了国内Z高标准的资质培训。 &#x1f3af;为什么选择CISP-PTE&#xff1f; 1️⃣业界认可&#xff1a;CISP-PTE是…

oracle10g rac节点启动没进程没日志

一节点正常运行&#xff0c;二节点通过crsctl start crs启动&#xff0c;发现alert日志及所有日志都没生成&#xff0c;oracle用户下连一个相关进程都没有清理缓存&#xff1a;rm -rf /tmp/.oracle/服务挨个启动也无效&#xff1a;/etc/init.evmd run >/dev/null 2>&…

如何训练和导出模型

介绍如何通过DI-engine使用DQN算法训练强化学习模型 一、什么是DQN算法 DQN算法&#xff0c;全称为Deep Q-Network算法&#xff0c;是一种结合了Q学习&#xff08;一种价值基础的强化学习算法&#xff09;和深度学习的算法。该算法是由DeepMind团队在2013年提出的&#xff0c;…

2024亚马逊开店教程:开店准备与注册流程指南

随着新一年的到来&#xff0c;亚马逊开启了新一轮的卖家入驻&#xff0c;并且针对新卖家优化了入驻流程&#xff0c;下面为大家简单整理一下最新亚马逊入驻教程&#xff0c;有想要入驻开店的小伙伴速速看过来&#xff01; 一、开店前准备 1、账号环境准备 为了防止账号由于网…

将 Amazon Bedrock 与 Elasticsearch 和 Langchain 结合使用

Amazon Bedrock 是一项完全托管的服务&#xff0c;通过单一 API 提供来自 AI21 Labs、Anthropic、Cohere、Meta、Stability AI 和 Amazon 等领先 AI 公司的高性能基础模型 (FMs) 选择&#xff0c;以及广泛的 构建生成式 AI 应用程序所需的功能&#xff0c;简化开发&#xff0c;…

MS2510:8 比特高速模数(ADC)转换器

描述&#xff1a; MS2510 是 8 比特&#xff0c; 20MSPS 模数转换器&#xff08; ADCs &#xff09; , 同时使用一个半闪速结构。 MS2510 在 5V 的电源电压下工作&#xff0c;其典型功耗只有 130mW &#xff0c;包括一个内部的采样保持电路&#xff0c;具有 高阻抗方…

斐波那契数列

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

ffmpeg和opencv一些容易影响图片清晰度的操作

ffmpeg 转视频或者图片&#xff0c;不指定码率清晰度会下降 ffmpeg -i xxx.png xxx.mp4 码率也叫比特率&#xff08;Bit rate&#xff09;(也叫数据率)是一个确定整体视频/音频质量的参数&#xff0c;秒为单位处理的字节数&#xff0c;码率和视频质量成正比&#xff0c;在视频…

IP被封怎么办?访问网站时IP被阻止?解决IP禁令全方法

相信很多人遇到过IP禁令&#xff1a;比如你在访问社交媒体、搜索引擎或电子商务网站时会被限制访问&#xff0c;又或者你的的账号莫名被封&#xff0c;这些由于网络上的种种限制我们经常会遭遇IP被封的情况&#xff0c;导致无法使用继续进行网络行动。在本文中&#xff0c;我们…

04 约数

定义&#xff1a; 若整数n除以整数d的余数为0&#xff0c;即d能够整除n&#xff0c;n是d的倍数&#xff0c;记作d|n. 通过质因子求一个数的约数 如果n可以表示成 其中均为n的质因子 因为对于任意一个质因子都有选0个 选1个 选2个....选个共种可能&#xff0c; n的约数个数…

在DevEco开发工具中,使用Previewer预览界面中的UI组件

1、在DevEco工具中&#xff0c;点击并展开PreViewer预览器 2、在PreViewer预览器中&#xff0c;点击Tt按钮&#xff08;Inspector&#xff09;切换至组件查看模式 3、在组件查看模式下选择组件&#xff0c;代码呈现选中状态&#xff0c;右侧呈现组件树&#xff0c;右下方呈现组…

ARM 驱动 1.22

linux内核等待队列wait_queue_head_t 头文件 include <linux/wait.h> 定义并初始化 wait_queue_head_t r_wait; init_waitqueue_head(&cm_dev->r_wait); wait_queue_head_t 表示等待队列头&#xff0c;等待队列wait时&#xff0c;会导致进程或线程被休眠&…

倍增算法笔记

主要应用场景 RMQ&#xff1a;区间最值问题 LCA&#xff1a;最近公共祖先问题 RMQ问题——区间最值 如果用数组f[N]存储,用数组a[i][j]表示从第i个数起连续 2^j 个数中的最大值,[i,i 2^j - 1],显然a[i][0] f[i],则很容易得到状态转移方程: a[i][j] max(a[i][j - 1], a[i …

读书笔记-《数据结构与算法》-摘要11[Divide and Conquer - 分治法]

在计算机科学中&#xff0c;分治法是一种很重要的算法。分治法即『分而治之』&#xff0c;把一个复杂的问题分成两个或更多的相同或相似的子问题&#xff0c;再把子问题分成更小的子问题……直到最后子问题可以简单的直接求解&#xff0c;原问题的解即子问题的解的合并。这个思…

电商API接口|爬虫案例|采集某东商品评论信息

前言&#xff1a; 平常大家都有网上购物的习惯&#xff0c;在商品下面卖的好的产品基本都会有评论&#xff0c;当然也不排除有刷评论的情况&#xff0c;因为评论会影响我们的购物决策。今天主要分享用pythonre正则表达式获取京东商品评论。API接口获取京东平台商品详情SKU数据…

11k+ star 一款不错的笔记leanote安装教程

特点 支持普通模式 支持markdown模式 支持搜索 安装教程 1.安装mongodb 1.1.下载 #下载 cd /opt wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-3.0.1.tgz 1.2解压 tar -xvf mongodb-linux-x86_64-3.0.1.tgz 1.3配置mongodb环境变量 vim /etc/profile 增…

电脑可以连接wifi,甚至可以qq聊天,但就是不能用浏览器上网,一直显示未检测出入户网线的解决方案

今天回到家&#xff0c;准备办公却发现电脑可以连接wifi&#xff0c;甚至可以qq聊天&#xff0c;但就是不能用浏览器上网&#xff0c;一直显示未检测出入户网线的解决方案&#xff0c;小白也可以看懂 以下有几种解决方案&#xff0c;不妨都试试&#xff0c;估计可以解决95%的相…

C#-前后端分离连接mysql数据库封装接口

C#是世界上最好的语言 新建项目 如下图所示选择框红的项目 然后新建 文件夹 Common 并新建类文件 名字任意 文件内容如下 因为要连接的是mysql数据库 所以需要安装 MySql.Data.MySqlClient 依赖; using MySql.Data.MySqlClient; using System.Data;namespace WebApplication1.…

Django 为应用定制化admin独立后台

定制后界面 在应用目录下找到admin.py并进行编辑 from django.contrib.admin import AdminSite from .models import Question,Choiceclass PollsAdminSite(AdminSite):site_header"Admin-site-header"site_title"admin-site-title"index_title"admi…

Conda 使用environment.yml创建一个新的Python项目

Conda系列&#xff1a; 翻译: Anaconda 与 miniconda的区别Miniconda介绍以及安装Conda python运行的包和环境管理 入门Conda python管理环境environments 一 从入门到精通Conda python管理环境environments 二 从入门到精通Conda python管理环境environments 三 从入门到精通…