Custom C++ and CUDA Extensions - PyTorch

0. Abstract

经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C++/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C++ 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容:

  • LibTorch 初步;
  • C++ Extension 例子;

1. LibTorch 初步

在 PyTorch 的首页安装指引中就可以看到 PyTorch 是支持 C++/Java 的:

下载后解压到一个地方, 如 /opt/libtorch. 然后就可以使用 C++ 编写 PyTorch 程序了. 官方给的有相关例子, 我们选择最经典的 MNIST 手写数字识别项目来看一看:

mnist/
├── CMakeLists.txt
├── README.md
└── mnist.cpp

1.1 CMake 项目

CMakeLists.txt 是构建 cpp 项目的说明文件:

cmake_minimum_required(VERSION 3.5)
project(mnist)
set(CMAKE_CXX_STANDARD 17)find_package(Torch REQUIRED)option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)message(STATUS "Downloading MNIST dataset")execute_process(COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py -d ${CMAKE_BINARY_DIR}/dataERROR_VARIABLE DOWNLOAD_ERROR)if (DOWNLOAD_ERROR)message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")endif()
endif()add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

为了下载 MNIST 数据集, 这里用到了一个 Python 文件 ../tools/download_mnist.py, 执行 cmake 后, 编译根目录(build)会出现一个 data 数据文件夹.

  • find_package(Torch REQUIRED) 查找 libtorch 时可能需要指定路径:
    find_package(Torch REQUIRED PATHS "path/to/libtorch/")
  • make 时, Ubuntu18.04 下出现错误: undefined reference to symbol ‘pthread_create@@GLIBC_2.2.5’.
    => 经查阅资料, 说: pthread 不是 linux 下的默认的库, 也就是在链接的时候, 无法找到 phread 库中线程函数的入口地址, 于是链接会失败.
    => 解决方案: target_link_libraries(mnist ${TORCH_LIBRARIES} -lpthread -lm)

make 之后, 执行 ./mnist 就能进行训练与测试了:

CUDA available! Training on GPU.
Train Epoch: 1 [59584/60000] Loss: 0.2078
Test set: Average loss: 0.2062 | Accuracy: 0.935
Train Epoch: 2 [59584/60000] Loss: 0.2039
Test set: Average loss: 0.1304 | Accuracy: 0.959
...

1.2 PyTorch C++ API

接下来看 C++ 代码:

struct Net : torch::nn::Module
{Net() : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),fc1(320, 50),fc2(50, 10){register_module("conv1", conv1);register_module("conv2", conv2);register_module("conv2_drop", conv2_drop);register_module("fc1", fc1);register_module("fc2", fc2);}torch::Tensor forward(torch::Tensor &x){x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));x = torch::relu(torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));x = x.view({-1, 320});x = torch::relu(fc1->forward(x));x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());x = fc2->forward(x);return torch::log_softmax(x, /*dim=*/1);}torch::nn::Conv2d conv1;torch::nn::Conv2d conv2;torch::nn::Dropout2d conv2_drop;torch::nn::Linear fc1;torch::nn::Linear fc2;
};template<typename DataLoader>
void train(size_t epoch,Net &model,torch::Device device,DataLoader &data_loader,torch::optim::Optimizer &optimizer,size_t dataset_size
)
{model.train();size_t batch_idx = 0;for (auto &batch: data_loader){auto data = batch.data.to(device), targets = batch.target.to(device);auto output = model.forward(data);auto loss = torch::nll_loss(output, targets);AT_ASSERT(!std::isnan(loss.template item<float>()));optimizer.zero_grad();loss.backward();optimizer.step();...}
}

可以看到, 代码非常简单, 几乎和 Python 接口一致, 如果把 :: 换成 ., 就更像了. 不一样的是多了些类型限制以及一些语法. 具体的我们不多研究, 终究还是没有 Python 简洁好用. 但简单了解一下 PyTorch C++ API 的文档说明还是有必要的:

所以, 这个 LibTorch 既能用来写 C++ 项目, 也能用来给 PyTorch 写扩展. 不过官方还是推荐使用 Python 接口:

2. C++ Extension 例子

官方文档给的例子比较复杂, 这里举一个简单的例子, 把计算:

y = torch.relu(torch.matmul(x, w.t()) + b)

整合到一个操作里, 也就是使用 LibTorch C++ 编写一个等价的运算, 并导出 Python 接口. 这么做的理由是:

大概意思就是 Python 比较慢, 由 Python 一次次调用操作而频繁启动 CUDA 核会拖慢速度.

其实我觉得只有用 CUDA 编程把序列操作整合起来才能真正减少 CUDA 核的频繁启动, LibTorch 能加速可能就是因为 C++ 更快而已.

直接上代码吧, 整个项目的解构是这样子的:

LinearAct/
├── linearfun.py
├── linearact.cpp
└── setup.py

linearact.cpp 包含了组合操作的 forward 过程和 backward 过程, 前者计算正向的正常计算, 后者计算反向的梯度计算:

#include <torch/extension.h>  // 注意这里头文件和直接写 C++ 项目不一样
#include <vector>std::vector<at::Tensor> forward(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias)
{auto relu_input = input.mm(weight.t()) + bias;auto output = torch::relu(relu_input);return {relu_input, output};  // relu_input 会在梯度计算时用到
}std::vector<torch::Tensor>
backward(torch::Tensor &grad_output, torch::Tensor &relu_input, torch::Tensor &input, torch::Tensor &weight)
{   // 求导链式法则auto grad_relu = grad_output.masked_fill(relu_input < 0, 0);auto grad_input = grad_relu.mm(weight);auto grad_weight = grad_relu.t().mm(input);auto grad_bias = grad_relu.sum(0);return {grad_input, grad_weight, grad_bias};
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &forward, "Custom forward");m.def("backward", &backward, "Custom backward");
}

这涉及到 pybind11 的用法, 详情见《pybind11 学习笔记》, 还涉及到使用 torch.autograd.Function 自定义运算的梯度计算, 详情见《PyTorch 中的 apply [autograd.Function]》. 总之, 现在我们使用 LibTorch 写了组合操作, 并写了其参数的梯度计算. linearfun.py 是利用 torch.autograd.Functionforwardbackward 整合到一起, 组成一个完整的可以进行反向梯度传播的组合运算:

import torch  # 注意, 导入 linearact 前, 应先导入 torch
import linearactclass LinearActFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input, weights, bias):relu_input, output = linearact.forward(input, weights, bias)  # c++ 函数variables = [relu_input, input, weights]ctx.save_for_backward(*variables)return output@staticmethoddef backward(ctx, grad_output):outputs = linearact.backward(grad_output, *ctx.saved_tensors)  # c++ 函数grad_x, grad_w, grad_b = outputsreturn grad_x, grad_w, grad_bmylinear = LinearActFunction.apply

LibTorch C++ 代码由 setuptools 导出 Python 接口:

from setuptools import setup
from torch.utils import cpp_extensionsetup(name='linearact',ext_modules=[cpp_extension.CppExtension('linearact', ['linearact.cpp'])],cmdclass={'build_ext': cpp_extension.BuildExtension}  # 整合了 pybind11 的功能
)

在命令行执行:

python setup.py install

就可以将 linearact 包安装到 Python 系统中, 任务完成. 下面进行验证:

import torch
from linearfun import mylinearx = torch.randn(2, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
# 复制一份一样的参数
x1 = torch.from_numpy(x.detach().numpy())
w1 = torch.from_numpy(w.detach().numpy())
b1 = torch.from_numpy(b.detach().numpy())
x1.requires_grad_(True)
w1.requires_grad_(True)
b1.requires_grad_(True)# %% pytorch
y = torch.relu(torch.matmul(x, w.t()) + b)
y = y.norm(p=2)
print(y)y.backward()
print(x.grad)
print(w.grad)
print(b.grad)# %% custom
print('---------------------------')
y = mylinear(x1, w1, b1)
y = y.norm(p=2)
print(y)y.backward()
print(x1.grad)
print(w1.grad)
print(b1.grad)

执行一次:

tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])
---------------------------
tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])

可以看见两者一模一样. 至于测速什么的不在本博文的考虑范围之内, 只是想了解 PyTorch 如何进行 C++ 扩展.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/55726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS3--美若天仙!?

免责声明&#xff1a;本文仅做分享~ 目录 CSS引入方式 选择器 盒子尺寸和背景色 文字控制属性 单行文字 垂直居中 字体族 font复合属性 文本对齐方式 文本修饰线 color 文字颜色 ----- 复合选择器 伪类选择器 超链接伪类 CSS特性 继承性 层叠性 优先级 Emmet …

H、Happy Number(2024牛客国庆集训派对day7)

题目链接&#xff1a; H-Happy Number_2024牛客国庆集训派对day7 (nowcoder.com) 题目描述&#xff1a; 翻译为中文&#xff1a; 数据范围&#xff1a; 输入样例&#xff1a; 680 输出样例&#xff1a; 326623 分析: 本来以为是dfs&#xff0c;但是看到数据范围1e9, 联想到是…

通信工程学习:什么是三网融合

三网融合 三网融合&#xff0c;又称“三网合一”&#xff0c;是指电信网、广播电视网、互联网在高层业务应用上的深度融合。这一概念在近年来随着信息技术的快速发展而逐渐受到重视&#xff0c;并成为推动信息化社会建设的重要力量。以下是对三网融合的详细解释&#xff1a; 一…

扩展、包含、泛化-系统架构师(七十七)

1&#xff08;&#xff09;是系统分析阶段结束后得到的工作产品&#xff0c;&#xff08;&#xff09;是系统测试阶段完成后的工作产品。 问题1 A系统设计规格说明 B系统方案建议书 C系统规格说明 D单元测试数据 问题2 A验收测试计划 B测试标准 C系统测试计划 D操作手…

社团活动助手系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;活动分类管理&#xff0c;用户管理&#xff0c;社团活动管理&#xff0c;报名信息管理&#xff0c;签到登记管理&#xff0c;投票项目管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首…

四款语音转文字神器,一键搞定会议记录!

嘿&#xff0c;朋友们&#xff0c;今天咱们来聊聊那些语音转文字的免费软件吧&#xff01;在这个快节奏的时代&#xff0c;谁不想省点时间&#xff0c;少敲几下键盘呢&#xff1f;尤其是那些开会、采访或者闲聊时&#xff0c;语音消息满天飞的日子&#xff0c;一个好用的语音转…

【业务场景】最全的购物车设计与实现

前言 博主最近在做一个购物商城&#xff0c;正好设计到购物车模块&#xff0c;于是乎全面的来聊一聊购物车模块实现的一些核心要点吧&#xff0c;很值得反复品味的设计&#xff0c;当需要实现购物车的时候&#xff0c;本文应该拿来就能用。 目录 1.需要解决的核心问题清单 2…

Mybatis-plus做了什么

Mybatis-plus做了什么 Mybatis回顾以前的方案Mybatis-plus 合集总览&#xff1a;Mybatis框架梳理 聊一下mybatis-plus。你是否有过疑问&#xff0c;Mybatis-plus中BaseMapper方法对应的SQL在哪里&#xff1f;它为啥会被越来越多人接受。在Mybatis已经足够灵活的情况下&…

22.第二阶段x86游戏实战2-背包遍历REP指令详解

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要…

【Java 并发编程】多线程安全问题(上)

前言 虽然并发编程让我们的 CPU 核心能够得到充分的使用&#xff0c;程序运行效率更高效。但是也会引发一些问题。比如当进程中有多个并发线程进入一个重要数据的代码块时&#xff0c;在修改数据的过程中&#xff0c;很有可能引发线程安全问题&#xff0c;从而造成数据异常。 p…

免费 Oracle 各版本 离线帮助使用和介绍

文章目录 Oracle 各版本 离线帮助使用和介绍概要在线帮助下载离线文档包&#xff1a;解压离线文档&#xff1a;访问离线文档&#xff1a;导航使用&#xff1a;目录介绍Install and Upgrade&#xff08;安装和升级&#xff09;&#xff1a;Administration&#xff08;管理&#…

做无货源反向代购业务需要的代购系统功能需求讲解(一):商品数据接入

在电子商务领域&#xff0c;无货源反向代购业务逐渐崭露头角&#xff0c;成为许多创业者和中小企业拓展市场的新途径。这种业务模式的核心在于通过代购平台&#xff0c;将国外或特定地区的商品信息展示给国内消费者&#xff0c;并在消费者下单后&#xff0c;由代购方进行采购、…

成都睿明智科技有限公司真实可靠吗?

在这个日新月异的电商时代&#xff0c;抖音作为短视频与直播电商的佼佼者&#xff0c;正以前所未有的速度重塑着消费者的购物习惯。而在这片充满机遇与挑战的蓝海中&#xff0c;成都睿明智科技有限公司以其独到的眼光和专业的服务&#xff0c;成为了众多商家信赖的合作伙伴。今…

【万字长文】Word2Vec计算详解(一)

【万字长文】Word2Vec计算详解&#xff08;一&#xff09; 写在前面 本文用于记录本人学习NLP过程中&#xff0c;学习Word2Vec部分时的详细过程&#xff0c;本文与本人写的其他文章一样&#xff0c;旨在给出Word2Vec模型中的详细计算过程&#xff0c;包括每个模块的计算过程&a…

Ubuntu-24.10无法安装Sunlogin-15.2的解决方案

目录 1. 报错信息2. 解决方案3. dpkg-deb命令帮助4. References 1. 报错信息 albertqeeZBG7W:/opt/albertqee/Downloads$ ls | egrep -i sun SunloginClient_11.0.1.44968_amd64.deb SunloginClient_15.2.0.63062_amd64.deb SunloginClient_15.2.0.63064_amd64.deb albertqeeZ…

JavaScript函数基础(通俗易懂篇)

10.函数 10.1 函数的基础知识 为什么会有函数&#xff1f; 在写代码的时候&#xff0c;有一些常用的代码需要书写很多次&#xff0c;如果直接复制粘贴的话&#xff0c;会造成大量的代码冗余&#xff1b; 函数可以封装一段重复的javascript代码&#xff0c;它只需要声明一次&a…

在虚拟机里试用了几个linux操作系统

在虚拟机里试用了几个操作系统。遇到一些问题。虚拟机有时候出错。有时候出现死机现象&#xff0c;有的不能播放视频。有的显示效果不太好。 试了debian12&#xff0c;ubuntu20.4&#xff0c;ubuntu22.4&#xff0c;ubuntu24.4&#xff0c;deepin。其中ubuntu20.4使用时没有出…

Jenkins打包,发布,部署

一、概念 Jenkins是一个开源的持续集成工具&#xff0c;主要用于自动构建和测试软件项目&#xff0c;以及监控外部任务的运行。与版本管理工具&#xff08;如SVN&#xff0c;GIT&#xff09;和构建工具&#xff08;如Maven&#xff0c;Ant&#xff0c;Gradle&#xff09;结合使…

武汉正向科技|无人值守起重机,采用格雷母线定位系统,扎根智能制造工业

武汉正向科技开发的无人值守起重机系统在原起重机系统的基础上&#xff0c;利用格雷母线位置检测技术&#xff0c;信息技术&#xff0c;网络技术及传感器技术为起重机系统添加管理层&#xff0c;控制层和基础层。实现起重机智能化&#xff0c;无人化作业的库区综合管理系统。 正…

【数据结构 | PTA】栈

文章目录 7-1 汉诺塔的非递归实现7-2 出栈序列的合法性**7-3 简单计算器**7-4 盲盒包装流水线 7-1 汉诺塔的非递归实现 借助堆栈以非递归&#xff08;循环&#xff09;方式求解汉诺塔的问题&#xff08;n, a, b, c&#xff09;&#xff0c;即将N个盘子从起始柱&#xff08;标记…