Custom C++ and CUDA Extensions - PyTorch

0. Abstract

经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C++/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C++ 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容:

  • LibTorch 初步;
  • C++ Extension 例子;

1. LibTorch 初步

在 PyTorch 的首页安装指引中就可以看到 PyTorch 是支持 C++/Java 的:

下载后解压到一个地方, 如 /opt/libtorch. 然后就可以使用 C++ 编写 PyTorch 程序了. 官方给的有相关例子, 我们选择最经典的 MNIST 手写数字识别项目来看一看:

mnist/
├── CMakeLists.txt
├── README.md
└── mnist.cpp

1.1 CMake 项目

CMakeLists.txt 是构建 cpp 项目的说明文件:

cmake_minimum_required(VERSION 3.5)
project(mnist)
set(CMAKE_CXX_STANDARD 17)find_package(Torch REQUIRED)option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)message(STATUS "Downloading MNIST dataset")execute_process(COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py -d ${CMAKE_BINARY_DIR}/dataERROR_VARIABLE DOWNLOAD_ERROR)if (DOWNLOAD_ERROR)message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")endif()
endif()add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

为了下载 MNIST 数据集, 这里用到了一个 Python 文件 ../tools/download_mnist.py, 执行 cmake 后, 编译根目录(build)会出现一个 data 数据文件夹.

  • find_package(Torch REQUIRED) 查找 libtorch 时可能需要指定路径:
    find_package(Torch REQUIRED PATHS "path/to/libtorch/")
  • make 时, Ubuntu18.04 下出现错误: undefined reference to symbol ‘pthread_create@@GLIBC_2.2.5’.
    => 经查阅资料, 说: pthread 不是 linux 下的默认的库, 也就是在链接的时候, 无法找到 phread 库中线程函数的入口地址, 于是链接会失败.
    => 解决方案: target_link_libraries(mnist ${TORCH_LIBRARIES} -lpthread -lm)

make 之后, 执行 ./mnist 就能进行训练与测试了:

CUDA available! Training on GPU.
Train Epoch: 1 [59584/60000] Loss: 0.2078
Test set: Average loss: 0.2062 | Accuracy: 0.935
Train Epoch: 2 [59584/60000] Loss: 0.2039
Test set: Average loss: 0.1304 | Accuracy: 0.959
...

1.2 PyTorch C++ API

接下来看 C++ 代码:

struct Net : torch::nn::Module
{Net() : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),fc1(320, 50),fc2(50, 10){register_module("conv1", conv1);register_module("conv2", conv2);register_module("conv2_drop", conv2_drop);register_module("fc1", fc1);register_module("fc2", fc2);}torch::Tensor forward(torch::Tensor &x){x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));x = torch::relu(torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));x = x.view({-1, 320});x = torch::relu(fc1->forward(x));x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());x = fc2->forward(x);return torch::log_softmax(x, /*dim=*/1);}torch::nn::Conv2d conv1;torch::nn::Conv2d conv2;torch::nn::Dropout2d conv2_drop;torch::nn::Linear fc1;torch::nn::Linear fc2;
};template<typename DataLoader>
void train(size_t epoch,Net &model,torch::Device device,DataLoader &data_loader,torch::optim::Optimizer &optimizer,size_t dataset_size
)
{model.train();size_t batch_idx = 0;for (auto &batch: data_loader){auto data = batch.data.to(device), targets = batch.target.to(device);auto output = model.forward(data);auto loss = torch::nll_loss(output, targets);AT_ASSERT(!std::isnan(loss.template item<float>()));optimizer.zero_grad();loss.backward();optimizer.step();...}
}

可以看到, 代码非常简单, 几乎和 Python 接口一致, 如果把 :: 换成 ., 就更像了. 不一样的是多了些类型限制以及一些语法. 具体的我们不多研究, 终究还是没有 Python 简洁好用. 但简单了解一下 PyTorch C++ API 的文档说明还是有必要的:

所以, 这个 LibTorch 既能用来写 C++ 项目, 也能用来给 PyTorch 写扩展. 不过官方还是推荐使用 Python 接口:

2. C++ Extension 例子

官方文档给的例子比较复杂, 这里举一个简单的例子, 把计算:

y = torch.relu(torch.matmul(x, w.t()) + b)

整合到一个操作里, 也就是使用 LibTorch C++ 编写一个等价的运算, 并导出 Python 接口. 这么做的理由是:

大概意思就是 Python 比较慢, 由 Python 一次次调用操作而频繁启动 CUDA 核会拖慢速度.

其实我觉得只有用 CUDA 编程把序列操作整合起来才能真正减少 CUDA 核的频繁启动, LibTorch 能加速可能就是因为 C++ 更快而已.

直接上代码吧, 整个项目的解构是这样子的:

LinearAct/
├── linearfun.py
├── linearact.cpp
└── setup.py

linearact.cpp 包含了组合操作的 forward 过程和 backward 过程, 前者计算正向的正常计算, 后者计算反向的梯度计算:

#include <torch/extension.h>  // 注意这里头文件和直接写 C++ 项目不一样
#include <vector>std::vector<at::Tensor> forward(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias)
{auto relu_input = input.mm(weight.t()) + bias;auto output = torch::relu(relu_input);return {relu_input, output};  // relu_input 会在梯度计算时用到
}std::vector<torch::Tensor>
backward(torch::Tensor &grad_output, torch::Tensor &relu_input, torch::Tensor &input, torch::Tensor &weight)
{   // 求导链式法则auto grad_relu = grad_output.masked_fill(relu_input < 0, 0);auto grad_input = grad_relu.mm(weight);auto grad_weight = grad_relu.t().mm(input);auto grad_bias = grad_relu.sum(0);return {grad_input, grad_weight, grad_bias};
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &forward, "Custom forward");m.def("backward", &backward, "Custom backward");
}

这涉及到 pybind11 的用法, 详情见《pybind11 学习笔记》, 还涉及到使用 torch.autograd.Function 自定义运算的梯度计算, 详情见《PyTorch 中的 apply [autograd.Function]》. 总之, 现在我们使用 LibTorch 写了组合操作, 并写了其参数的梯度计算. linearfun.py 是利用 torch.autograd.Functionforwardbackward 整合到一起, 组成一个完整的可以进行反向梯度传播的组合运算:

import torch  # 注意, 导入 linearact 前, 应先导入 torch
import linearactclass LinearActFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input, weights, bias):relu_input, output = linearact.forward(input, weights, bias)  # c++ 函数variables = [relu_input, input, weights]ctx.save_for_backward(*variables)return output@staticmethoddef backward(ctx, grad_output):outputs = linearact.backward(grad_output, *ctx.saved_tensors)  # c++ 函数grad_x, grad_w, grad_b = outputsreturn grad_x, grad_w, grad_bmylinear = LinearActFunction.apply

LibTorch C++ 代码由 setuptools 导出 Python 接口:

from setuptools import setup
from torch.utils import cpp_extensionsetup(name='linearact',ext_modules=[cpp_extension.CppExtension('linearact', ['linearact.cpp'])],cmdclass={'build_ext': cpp_extension.BuildExtension}  # 整合了 pybind11 的功能
)

在命令行执行:

python setup.py install

就可以将 linearact 包安装到 Python 系统中, 任务完成. 下面进行验证:

import torch
from linearfun import mylinearx = torch.randn(2, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
# 复制一份一样的参数
x1 = torch.from_numpy(x.detach().numpy())
w1 = torch.from_numpy(w.detach().numpy())
b1 = torch.from_numpy(b.detach().numpy())
x1.requires_grad_(True)
w1.requires_grad_(True)
b1.requires_grad_(True)# %% pytorch
y = torch.relu(torch.matmul(x, w.t()) + b)
y = y.norm(p=2)
print(y)y.backward()
print(x.grad)
print(w.grad)
print(b.grad)# %% custom
print('---------------------------')
y = mylinear(x1, w1, b1)
y = y.norm(p=2)
print(y)y.backward()
print(x1.grad)
print(w1.grad)
print(b1.grad)

执行一次:

tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])
---------------------------
tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])

可以看见两者一模一样. 至于测速什么的不在本博文的考虑范围之内, 只是想了解 PyTorch 如何进行 C++ 扩展.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/55726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【分布式微服务云原生】探索Dubbo:接口定义语言的多样性与选择

目录 探索Dubbo&#xff1a;接口定义语言的多样性与选择引言Dubbo的接口定义语言&#xff08;IDL&#xff09;1. Java接口2. XML配置3. 注解4. Protobuf IDL 流程图&#xff1a;Dubbo服务定义流程表格&#xff1a;Dubbo IDL方式比较结论呼吁行动Excel表格&#xff1a;Dubbo IDL…

合并村庄生活废水处理设备工艺流程

诸城市鑫淼环保小编带大家了解一下合并村庄生活废水处理设备工艺流程 设备的构造 该填料采用优质PVC材料制成&#xff0c;是一种新型的折波填料&#xff0c;间距为30mm&#xff0c;比表面积超过400m/m&#xff0c;具有不易堵塞的特点&#xff0c;表面波纹设计便于膜的附着。 该…

CSS3--美若天仙!?

免责声明&#xff1a;本文仅做分享~ 目录 CSS引入方式 选择器 盒子尺寸和背景色 文字控制属性 单行文字 垂直居中 字体族 font复合属性 文本对齐方式 文本修饰线 color 文字颜色 ----- 复合选择器 伪类选择器 超链接伪类 CSS特性 继承性 层叠性 优先级 Emmet …

H、Happy Number(2024牛客国庆集训派对day7)

题目链接&#xff1a; H-Happy Number_2024牛客国庆集训派对day7 (nowcoder.com) 题目描述&#xff1a; 翻译为中文&#xff1a; 数据范围&#xff1a; 输入样例&#xff1a; 680 输出样例&#xff1a; 326623 分析: 本来以为是dfs&#xff0c;但是看到数据范围1e9, 联想到是…

通信工程学习:什么是三网融合

三网融合 三网融合&#xff0c;又称“三网合一”&#xff0c;是指电信网、广播电视网、互联网在高层业务应用上的深度融合。这一概念在近年来随着信息技术的快速发展而逐渐受到重视&#xff0c;并成为推动信息化社会建设的重要力量。以下是对三网融合的详细解释&#xff1a; 一…

go基础面试题汇总第一弹

init函数是什么时候执行的? init的函数的作用是什么&#xff1f; 通常作为程序执行前包的初始化&#xff0c;例如mysql redis 等中间件的初始化 init函数的执行顺序是怎样的&#xff1f; 分不同情况来回答&#xff1a; 在同一个go文件里面如果有多个init方法&#xff0c;它们…

扩展、包含、泛化-系统架构师(七十七)

1&#xff08;&#xff09;是系统分析阶段结束后得到的工作产品&#xff0c;&#xff08;&#xff09;是系统测试阶段完成后的工作产品。 问题1 A系统设计规格说明 B系统方案建议书 C系统规格说明 D单元测试数据 问题2 A验收测试计划 B测试标准 C系统测试计划 D操作手…

git fetch 和 git pull 的区别

git fetch 和 git pull 的区别 git fetch 功能&#xff1a;git fetch 用于从远程仓库获取最新的代码和提交信息&#xff0c;并将其保存到本地仓库的相应远程跟踪分支中&#xff0c;不会自动合并或修改当前的工作目录或当前分支。 合并&#xff1a;此命令不会自动合并获取的更新…

社团活动助手系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;活动分类管理&#xff0c;用户管理&#xff0c;社团活动管理&#xff0c;报名信息管理&#xff0c;签到登记管理&#xff0c;投票项目管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首…

四款语音转文字神器,一键搞定会议记录!

嘿&#xff0c;朋友们&#xff0c;今天咱们来聊聊那些语音转文字的免费软件吧&#xff01;在这个快节奏的时代&#xff0c;谁不想省点时间&#xff0c;少敲几下键盘呢&#xff1f;尤其是那些开会、采访或者闲聊时&#xff0c;语音消息满天飞的日子&#xff0c;一个好用的语音转…

【业务场景】最全的购物车设计与实现

前言 博主最近在做一个购物商城&#xff0c;正好设计到购物车模块&#xff0c;于是乎全面的来聊一聊购物车模块实现的一些核心要点吧&#xff0c;很值得反复品味的设计&#xff0c;当需要实现购物车的时候&#xff0c;本文应该拿来就能用。 目录 1.需要解决的核心问题清单 2…

Mybatis-plus做了什么

Mybatis-plus做了什么 Mybatis回顾以前的方案Mybatis-plus 合集总览&#xff1a;Mybatis框架梳理 聊一下mybatis-plus。你是否有过疑问&#xff0c;Mybatis-plus中BaseMapper方法对应的SQL在哪里&#xff1f;它为啥会被越来越多人接受。在Mybatis已经足够灵活的情况下&…

【分布式微服务云原生】 RPC协议:超越HTTP的远程通信艺术

目录 RPC协议&#xff1a;超越HTTP的远程通信艺术引言RPC协议的实现方式RPC的核心机制流程图&#xff1a;RPC通信流程表格&#xff1a;不同RPC实现方式的比较结论呼吁行动Excel表格&#xff1a;RPC协议实现方式总结 RPC协议&#xff1a;超越HTTP的远程通信艺术 摘要 RPC&#…

pdsh:一个用于并行执行命令的工具

pdsh&#xff08;Parallel Distributed Shell&#xff09;是一个用于并行执行命令的工具&#xff0c;可以在多个远程主机上同时运行相同的命令。它对于需要在多台服务器上执行批量操作的系统管理员和开发人员非常有用。 pdsh 介绍 主要特性 并行执行&#xff1a; pdsh 可以在…

22.第二阶段x86游戏实战2-背包遍历REP指令详解

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要…

【Java 并发编程】多线程安全问题(上)

前言 虽然并发编程让我们的 CPU 核心能够得到充分的使用&#xff0c;程序运行效率更高效。但是也会引发一些问题。比如当进程中有多个并发线程进入一个重要数据的代码块时&#xff0c;在修改数据的过程中&#xff0c;很有可能引发线程安全问题&#xff0c;从而造成数据异常。 p…

mysql学习教程,从入门到精通,SQL 表、列别名(Aliases)(30)

1、SQL 表、列别名&#xff08;Aliases&#xff09; 在SQL中&#xff0c;表别名&#xff08;Table Aliases&#xff09;和列别名&#xff08;Column Aliases&#xff09;是两种非常有用的技术&#xff0c;可以使查询语句更加简洁和易读。它们还可以帮助处理复杂的查询&#xf…

免费 Oracle 各版本 离线帮助使用和介绍

文章目录 Oracle 各版本 离线帮助使用和介绍概要在线帮助下载离线文档包&#xff1a;解压离线文档&#xff1a;访问离线文档&#xff1a;导航使用&#xff1a;目录介绍Install and Upgrade&#xff08;安装和升级&#xff09;&#xff1a;Administration&#xff08;管理&#…

做无货源反向代购业务需要的代购系统功能需求讲解(一):商品数据接入

在电子商务领域&#xff0c;无货源反向代购业务逐渐崭露头角&#xff0c;成为许多创业者和中小企业拓展市场的新途径。这种业务模式的核心在于通过代购平台&#xff0c;将国外或特定地区的商品信息展示给国内消费者&#xff0c;并在消费者下单后&#xff0c;由代购方进行采购、…

第18场小白入门赛(蓝桥杯)

第 18 场 小白入门赛 6 武功秘籍 考察进制理解。 对于第 i i i 位&#xff0c;设 b i t i x bit_ix biti​x &#xff0c;每一位的最大值是 b j b_j bj​ &#xff0c;也就是说每一位是 b j 1 b_j1 bj​1 进制 &#xff0c;那么第 i i i 位的大小就是 x ∑ j i 1…