PyTorch 添加 C++ 拓展

参考内容:pytorch添加C++拓展简单实战编写及基本功能测试

文章目录

  • 第一步:编写 C++ 模块
    • test.h
    • test.cpp
  • 第二步:编写 setup.py
  • 第三步:安装 C++ 模块
  • 第四步:验证安装
  • 第五步:C++ 模块使用
    • test_cpp1.py
    • test_cpp2.py
  • 运行结果
  • 扩展阅读

编译安装前的文件目录:

这里的 csrc 应该不是指 pytorch 项目中的 /torch/csrc

csrc
├─ cpu
│    ├─ test.cpp
│    └─ test.h
└─ setup.py

第一步:编写 C++ 模块

test.h

#include <torch/extension.h>
#include <vector>// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);

test.cpp

#include "test.h"// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y){AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y");torch::Tensor z = torch::zeros(x.sizes());z = 2 * x + y;return z;
}// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput){torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());return {gradOutputX, gradOutputY};
}// pybind11 绑定
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){m.def("forward", &Test_forward_cpu, "TEST forward");m.def("backward", &Test_backward_cpu, "TEST backward");
}

第二步:编写 setup.py

from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
# 源代码目录
source_cpu = glob.glob(os.path.join(include_dirs, 'cpu', '*.cpp'))setup(name='test_cpp', # 模块名称,需要在 python 中调用version="0.1",ext_modules=[CppExtension('test_cpp', sources=source_cpu, include_dirs=[include_dirs]),],cmdclass={'build_ext': BuildExtension}
)

第三步:安装 C++ 模块

在 csrc 文件夹下运行命令

python setup.py install

第一次尝试的报错信息:

/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` directly.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.********************************************************************************!!self.initialize_options()
/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` and ``easy_install``.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://github.com/pypa/setuptools/issues/917 for details.********************************************************************************!!self.initialize_options()

参考 SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip 后得知是 setuptools 版本太高,于是降低 setuptools 版本,pip install setuptools==58.2.0

第二次尝试的运行结果:

running install
running bdist_egg
running egg_info
writing test_cpp.egg-info/PKG-INFO
writing dependency_links to test_cpp.egg-info/dependency_links.txt
writing top-level names to test_cpp.egg-info/top_level.txt
reading manifest file 'test_cpp.egg-info/SOURCES.txt'
writing manifest file 'test_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'test_cpp' extension
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu
Emitting ninja build file /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] c++ -MMD -MF /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o.d -pthread -B /home/zjma/.conda/envs/debugtest/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/zjma/pytorch_v1.13.1/csrc -I/home/zjma/pytorch_v1.13.1/torch/include -I/home/zjma/pytorch_v1.13.1/torch/include/torch/csrc/api/include -I/home/zjma/pytorch_v1.13.1/torch/include/TH -I/home/zjma/pytorch_v1.13.1/torch/include/THC -I/home/zjma/.conda/envs/debugtest/include/python3.8 -c -c /home/zjma/pytorch_v1.13.1/csrc/cpu/test.cpp -o /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=test_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
cc1plus: warning: command-line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.8
g++ -pthread -shared -B /home/zjma/.conda/envs/debugtest/compiler_compat -L/home/zjma/.conda/envs/debugtest/lib -Wl,-rpath=/home/zjma/.conda/envs/debugtest/lib -Wl,--no-as-needed -Wl,--sysroot=/ /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -L/home/zjma/pytorch_v1.13.1/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for test_cpp.cpython-38-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/test_cpp.py to test_cpp.cpython-38.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.test_cpp.cpython-38: module references __file__
creating 'dist/test_cpp-0.1-py3.8-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing test_cpp-0.1-py3.8-linux-x86_64.egg
removing '/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg' (and everything under it)
creating /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Extracting test_cpp-0.1-py3.8-linux-x86_64.egg to /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages
test-cpp 0.1 is already the active version in easy-install.pthInstalled /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Processing dependencies for test-cpp==0.1
Finished processing dependencies for test-cpp==0.1

编译安装后的文件目录:

csrc
├─ build
│    ├─ bdist.linux-x86_64
│    ├─ lib.linux-x86_64-3.8
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ lib.linux-x86_64-cpython-38
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ temp.linux-x86_64-3.8
│    │    ├─ .ninja_deps
│    │    ├─ .ninja_log
│    │    ├─ build.ninja
│    │    └─ home
│    └─ temp.linux-x86_64-cpython-38
│           ├─ .ninja_deps
│           ├─ .ninja_log
│           ├─ build.ninja
│           └─ home
├─ cpu
│    ├─ test.cpp
│    └─ test.h
├─ dist
│    └─ test_cpp-0.1-py3.8-linux-x86_64.egg
├─ setup.py
└─ test_cpp.egg-info├─ PKG-INFO├─ SOURCES.txt├─ dependency_links.txt└─ top_level.txt

第四步:验证安装

1、在虚拟环境的路径 /lib/python3.8/site-packages 下看到 test_cpp-0.1-py3.8-linux-x86_64.egg 文件
在这里插入图片描述
2、conda list 查看当前虚拟环境下已经安装的包
在这里插入图片描述3、进入 python 的交互模式,import test_cpp 后报错:

>>> import test_cpp
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ImportError: libc10.so: cannot open shared object file: No such file or directory

参考 通过Python setup.py install的第三方包,import时却无法导入是什么问题呢? - 神经的网络里挣扎的回答 - 知乎,因为编译的 test_cpp 包需要依赖 torch 包,导致无法导入。所以,在 import test_cpp 前要先 import torch

第五步:C++ 模块使用

test_cpp1.py

import torch
import test_cpp
from torch.autograd import Functionclass TestFunction(Function):@staticmethoddef forward(ctx, x, y):return test_cpp.forward(x, y)@staticmethoddef backward(ctx, gradOutput):gradX, gradY = test_cpp.backward(gradOutput)return gradX, gradYclass Test(torch.nn.Module):def __init__(self):super(Test, self).__init__()def forward(self, inputA, inputB):return TestFunction.apply(inputA, inputB)

test_cpp2.py

import torch
from torch.autograd import Variablefrom test_cpp1 import Testx = Variable(torch.Tensor([1,2,3]), requires_grad=True)
y = Variable(torch.Tensor([4,5,6]), requires_grad=True)test = Test()
z = test(x, y)
z.sum().backward()print('x: ', x)
print('y: ', y)
print('z: ', z)
print('x.grad: ', x.grad)
print('y.grad: ', y.grad)

运行结果

/home/zjma/.conda/envs/debugtest/bin/python /home/zjma/PycharmProjects/pythonProject/test_cpp2.py 
x:  tensor([1., 2., 3.], requires_grad=True)
y:  tensor([4., 5., 6.], requires_grad=True)
z:  tensor([ 6.,  9., 12.], grad_fn=<TestFunctionBackward>)
x.grad:  tensor([2., 2., 2.])
y.grad:  tensor([1., 1., 1.])进程已结束,退出代码为 0

运行结果符合预期。

扩展阅读

  • pytorch之c++/cuda拓展(讲得很详细,举的例子和上文基本一样,但用到了CUDA,很多内容可以扩展去看)
  • 官方教程 相关内容的笔记(后面可以复现一下)
    • PyTorch进阶1:C++扩展
    • pytorch 的C++扩展

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/648513.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息安全认证首选CISP-PTE

&#x1f525;在信息安全领域&#xff0c;CISP-PTE认证正逐渐成为行业的新星。作为中国信息安全测评中心推出的专业认证&#xff0c;CISP-PTE为信息安全从业者提供了国内Z高标准的资质培训。 &#x1f3af;为什么选择CISP-PTE&#xff1f; 1️⃣业界认可&#xff1a;CISP-PTE是…

fMRI数据处理(随时更新)

要开始学习处理fMRI的数据了。 一、使用matlab工具包SPM读取fMRI数据 &#xff08;1&#xff09;首先得安装工具包SPM&#xff0c;我参考的是下面这篇博客&#xff1a; 在matlab下安装spm工具_spmas包matlab-CSDN博客 &#xff08;2&#xff09;使用SPM读取数据&#xff0c…

oracle10g rac节点启动没进程没日志

一节点正常运行&#xff0c;二节点通过crsctl start crs启动&#xff0c;发现alert日志及所有日志都没生成&#xff0c;oracle用户下连一个相关进程都没有清理缓存&#xff1a;rm -rf /tmp/.oracle/服务挨个启动也无效&#xff1a;/etc/init.evmd run >/dev/null 2>&…

抖音详情API:视频内容获取与解析技巧

一、引言 抖音是一款广受欢迎的短视频分享平台&#xff0c;每天都有大量的用户在抖音上分享自己的生活点滴和创意作品。对于开发者而言&#xff0c;如何获取并解析抖音上的视频内容&#xff0c;是一项极具挑战性的任务。本文将详细介绍抖音详情API&#xff0c;以及如何使用它来…

CVPR 2023: Make-a-Story Visual Memory Conditioned Consistent Story Generation

我们采用以下 6 个分类标准来详细解释本文的研究主题: 1. 生成模型类型: 基于扩散的:这种方法通过前向扩散过程迭代地将噪声细化为图像。这允许生成高质量的图像,并控制特定方面,如场景元素和照明。基于注意力的:注意力机制有助于模型在生成每个帧时集中在文本描述和视觉…

如何训练和导出模型

介绍如何通过DI-engine使用DQN算法训练强化学习模型 一、什么是DQN算法 DQN算法&#xff0c;全称为Deep Q-Network算法&#xff0c;是一种结合了Q学习&#xff08;一种价值基础的强化学习算法&#xff09;和深度学习的算法。该算法是由DeepMind团队在2013年提出的&#xff0c;…

2024亚马逊开店教程:开店准备与注册流程指南

随着新一年的到来&#xff0c;亚马逊开启了新一轮的卖家入驻&#xff0c;并且针对新卖家优化了入驻流程&#xff0c;下面为大家简单整理一下最新亚马逊入驻教程&#xff0c;有想要入驻开店的小伙伴速速看过来&#xff01; 一、开店前准备 1、账号环境准备 为了防止账号由于网…

将 Amazon Bedrock 与 Elasticsearch 和 Langchain 结合使用

Amazon Bedrock 是一项完全托管的服务&#xff0c;通过单一 API 提供来自 AI21 Labs、Anthropic、Cohere、Meta、Stability AI 和 Amazon 等领先 AI 公司的高性能基础模型 (FMs) 选择&#xff0c;以及广泛的 构建生成式 AI 应用程序所需的功能&#xff0c;简化开发&#xff0c;…

MS2510:8 比特高速模数(ADC)转换器

描述&#xff1a; MS2510 是 8 比特&#xff0c; 20MSPS 模数转换器&#xff08; ADCs &#xff09; , 同时使用一个半闪速结构。 MS2510 在 5V 的电源电压下工作&#xff0c;其典型功耗只有 130mW &#xff0c;包括一个内部的采样保持电路&#xff0c;具有 高阻抗方…

斐波那契数列

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

网闸网络ip端口映射原理分析

今天我们进行网闸网络ip端口映射原理分析&#xff1a;即两个不同网段的网址进行网络通信&#xff0c;需要路由器、网关等配置。举例更形象一些。 一、相同端口映射 1、正向访问 比如两个不同网段的网络ip:A:10.18.3.119 需要和B:10.199.177.10 进行通信&#xff0c;A要访问…

ffmpeg和opencv一些容易影响图片清晰度的操作

ffmpeg 转视频或者图片&#xff0c;不指定码率清晰度会下降 ffmpeg -i xxx.png xxx.mp4 码率也叫比特率&#xff08;Bit rate&#xff09;(也叫数据率)是一个确定整体视频/音频质量的参数&#xff0c;秒为单位处理的字节数&#xff0c;码率和视频质量成正比&#xff0c;在视频…

PyTorch中self.layers的作用

self.layers 是一个用于存储网络层的属性。它是一个 nn.ModuleList 对象&#xff0c;这是PyTorch中用于存储 nn.Module 子模块的特殊列表。 为什么使用 nn.ModuleList&#xff1f; 在PyTorch中&#xff0c;当需要处理多个神经网络层时&#xff0c;通常使用 nn.ModuleList 或 …

TCP三次握手-普通话版

前言&#xff1a;UDP和TCP 总拿UDP和TCP进行比较&#xff0c;为什么呢&#xff1f;因为UDP是不可靠传输&#xff0c;数据过来后把数据分成小份后就发送出去了&#xff0c;我不管你们收没收到哈&#xff0c;反正我是发过去了&#xff0c;你能收到多少就看这网速行不行&#xff0…

IP被封怎么办?访问网站时IP被阻止?解决IP禁令全方法

相信很多人遇到过IP禁令&#xff1a;比如你在访问社交媒体、搜索引擎或电子商务网站时会被限制访问&#xff0c;又或者你的的账号莫名被封&#xff0c;这些由于网络上的种种限制我们经常会遭遇IP被封的情况&#xff0c;导致无法使用继续进行网络行动。在本文中&#xff0c;我们…

linux动态库,静态库

参考 链接 https://blog.csdn.net/Goforyouqp/article/details/132106168 /* ---------- h.h 文件 -------------- */ #ifndef H_H #define H_H void print(void); #endif /* ---------- h.c 文件 -------------- */ #include "h.h" #include &l…

牛客周赛 Round 29(A B C D E)

目录 A.小红大战小紫 题目大意&#xff1a; 解题思路&#xff1a; AC代码&#xff1a; B.小红的白日梦 题目大意&#xff1a; 解题思路&#xff1a; AC代码&#xff1a; C.小红的小小红 题目大意&#xff1a; AC代码&#xff1a; D.小红的中位数 题目大意&#xff…

04 约数

定义&#xff1a; 若整数n除以整数d的余数为0&#xff0c;即d能够整除n&#xff0c;n是d的倍数&#xff0c;记作d|n. 通过质因子求一个数的约数 如果n可以表示成 其中均为n的质因子 因为对于任意一个质因子都有选0个 选1个 选2个....选个共种可能&#xff0c; n的约数个数…

在DevEco开发工具中,使用Previewer预览界面中的UI组件

1、在DevEco工具中&#xff0c;点击并展开PreViewer预览器 2、在PreViewer预览器中&#xff0c;点击Tt按钮&#xff08;Inspector&#xff09;切换至组件查看模式 3、在组件查看模式下选择组件&#xff0c;代码呈现选中状态&#xff0c;右侧呈现组件树&#xff0c;右下方呈现组…

ARM 驱动 1.22

linux内核等待队列wait_queue_head_t 头文件 include <linux/wait.h> 定义并初始化 wait_queue_head_t r_wait; init_waitqueue_head(&cm_dev->r_wait); wait_queue_head_t 表示等待队列头&#xff0c;等待队列wait时&#xff0c;会导致进程或线程被休眠&…