PyTorch扩展自定义PyThonC++(CUDA)算子的若干方法总结

PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结

转自:https://zhuanlan.zhihu.com/p/158643792

作者:奔腾的黑猫

在做毕设的时候需要实现一个PyTorch原生代码中没有的并行算子,所以用到了这部分的知识,再不总结就要忘光了= =,本文内容主要是PyTorch的官方教程的各种传送门,这些官方教程写的都很好,以后就可以不用再浪费时间在百度上了。由于图神经网络计算框架PyG的代码实现也是采用了扩展的方法,因此也可以当成下面总结PyG源码文章的前导知识吧 。

第一种情况:使用PyThon扩展PyTorch

使用PyThon扩展PyTorch准确的来说是在PyTorch的Python前端实现自定义算子或者模型,不涉及底层C++的实现。这种扩展方式是所有扩展方式中最简单的,也是官方首先推荐的,这是因为PyTorch在NVIDIA cuDNN,Intel MKL或NNPACK之类的库的支持下已经对可能出现的CPU和GPU操作进行了高度优化,因此用Python扩展的代码通常足够快。

比如要扩展一个新的PyThon算子(torch.nn)只需要继承torch**.nn.Module并实现其forward方法即可。详细的过程请参考**官方教程传送门:

Extending PyTorchpytorch.org/docs/master/notes/extending.html

第二种情况:使用**pybind11**构建共享库形式的C++和CUDA扩展

但是如果我们想对代码进行进一步优化,比如对自己的算子添加并行的CUDA实现或者连接个OpenCV的库什么的,那么仅仅使用Python进行扩展就不能满足需求;其次如果我们想序列化模型,在一个没有Python环境的生产环境下部署,也需要我们使用C++重写算法;最后考虑到考虑到多线程执行和性能原因,一般Python代码也并不适合做部署。因此在对性能有要求或者需要序列化模型的场景下我们还是会用到C++扩展。

下面我先把官方教程传送门放在这里:

CUSTOM C++ AND CUDA EXTENSIONSpytorch.org/tutorials/advanced/cpp_extension.html

对于一种典型的扩展情况,比如我们要设计一个全新的C++底层算子,其过程其实就三步:

第一步:使用C++编写算子的forward函数和backward函数

第二步:将该算子的forward函数和backward函数使用**pybind11**绑定到python上

第三步:使用setuptools/JIT/CMake编译打包C++工程为so文件

注意到在第一步中,我们不仅仅要实现forward函数也要实现backward函数,这是因为在C++端PyTorch目前不支持自动根据forward函数推导出backward函数,所以我们必须要对自己算子的反向传播过程完全清楚。一个需要注意的地方是,你可以选择直接在C++中继承torch::autograd类进行扩展;也可以像官方教程中那样在C++代码中实现forward和backward的核心过程,而在python端继承PyTorch的torch**.autograd.**Function类。

在C++端扩展forward函数和backward函数的需要注意以下规则:

(1)首先无论是forward函数还是backward函数都需要声明为静态函数

(2)forward函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将torch::autograd::AutogradContext 作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以<std::string,at::IValue>pairs的形式保存在一个map中。

(3)backward函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward输出的变量数量相等。它应该返回和forward输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。

请注意,backward的输入参数是自动微分系统反传回来的参数梯度值,其需要和forward函数的返回值位置一一对应的;而backward的返回值是对各参数根据自动微分规则求导后的梯度值,其需要和forward函数的输入参数位置一一对应,对于不需要求导的参数也需要使用空Variable占位。

// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展
// 下面是PyG的一个ScatterSum算子的扩展示例
// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:// AutogradContext *ctx指针可以操作static variable_list forward(AutogradContext *ctx, Variable src,Variable index, int64_t dim,torch::optional<Variable> optional_out,torch::optional<int64_t> dim_size) {dim = dim < 0 ? src.dim() + dim : dim;ctx->saved_data["dim"] = dim;ctx->saved_data["src_shape"] = src.sizes();index = broadcast(index, src, dim);auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");auto out = std::get<0>(result);ctx->save_for_backward({index});// 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记if (optional_out.has_value())ctx->mark_dirty({optional_out.value()});  return {out};}// grad_outs是out参数反传回来的梯度值static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {auto grad_out = grad_outs[0];auto saved = ctx->get_saved_variables();auto index = saved[0];auto dim = ctx->saved_data["dim"].toInt();auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());auto grad_in = torch::gather(grad_out, dim, index, false);// 不需要求导的参数需要空Variable占位return {grad_in, Variable(), Variable(), Variable(), Variable()};}
};

由于涉及到在C++环境下操作张量和反向传播等操作,因此我们需要对PyTorch的C++后端的库有所了解,主要就是Torch和Aten这两个库,下面我简要介绍一下这两兄弟。

img

其中Torch是PyTorch的C++底层实现(PS:其实是先有的Torch后有的PyTorch,从名字也能看出来),FB在编码PyTorch的时候就有意将PyTorch的接口和Torch的接口设计的十分类似,因此如果你对PyTorch很熟悉的话那么你也会很快的对Torch上手。

Torch官方文档传送门:

The C++ Frontendpytorch.org/cppdocs/frontend.html

安装PyTorch的C++前端的官方教程:

INSTALLING C++ DISTRIBUTIONS OF PYTORCHpytorch.org/cppdocs/installing.html

而Aten是ATen从根本上讲是一个张量库,在PyTorch中几乎所有其他Python和C ++接口都在其上构建。它提供了一个核心Tensor类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。

Aten源代码传送门:

https://github.com/zdevito/ATen/tree/master/aten/srcgithub.com/zdevito/ATen/tree/master/aten/src

使用Aten声明和操作张量的教程:

TENSOR BASICSpytorch.org/cppdocs/notes/tensor_basics.html

由于Pyorch的C++后端文档比较少,因此要多参考官方的例子,尝试去模仿官方教程的代码,同时可以通过Python前端的接口猜测后端接口的功能,如果没有文档了就读一读源码,还是有不少注释的,还能理解实现的逻辑。

第三种情况:为TORCHSCRIPT添加C++和CUDA扩展

首先简单解释一下TorchScript是什么,如果用官方的定义来说:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从一个Python进程中保存并可以在一个没有Python环境的进程中被加载。”通俗来说TorchScript就是一个序列化模型(即Inference)的工具,它可以让你的PyTorch代码方便的在生产环境中部署,同时在将PyTorch代码转化TorchScript代码时还会对你的模型进行一些性能上的优化。使用TorchScript完成模型的部署要比我们之前提到的使用C++重写要简单的多,因为是自动生成的。

TorchScript包含两种序列化模型的方法:tracingscript,两种方法各有其适用场景,由于和本文关系不大就不详细展开了,具体的官方教程传送门在此:

INTRODUCTION TO TORCHSCRIPTpytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

但是,TorchScript只能自动化的构造PyTorch的原生代码,如果我们需要序列化自定义的C++扩展算子,则需要我们显式的将这些自定义算子注册到TorchScript中,所幸的是,这一过程其实非常简单,整个过程和第二小节中使用pybind11构建共享库的形式的C++和CUDA扩展十分类似。官方教程传送门如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORSpytorch.org/tutorials/advanced/torch_script_custom_ops.html

而对于自定义的C++类,如果要注册到TorchScript要稍微复杂一些,官方教程传送门如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ CLASSESpytorch.org/tutorials/advanced/torch_script_custom_classes.html?highlight=registeroperators

另外需要注意的是,如果想要编写能够被TorchScript编译器理解的代码,需要注意在C++自定义扩展算子参数中的数据类型,目前被TorchScript支持的参数数据类型有torch::Tensortorch::Scalar(标量类型),doubleint64_tstd::vector,而像float,int,short这些是不能作为自定义扩展算子的参数数据类型的。

目前就先总结这么多吧,这点东西居然写了一天,好累啊(*  ̄︿ ̄)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532457.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

给 Python 算法插上性能的翅膀——pybind11 落地实践

给 Python 算法插上性能的翅膀——pybind11 落地实践 转自&#xff1a;https://zhuanlan.zhihu.com/p/444805518 作者&#xff1a;jesonxiang&#xff08;向乾彪&#xff09;&#xff0c;腾讯 TEG 后台开发工程师 1. 背景 目前 AI 算法开发特别是训练基本都以 Python 为主&…

chrome自动提交文件_收集文档及提交名单统计

知乎文章若有排版问题请见谅&#xff0c;原文放在个人博客中【欢迎互踩&#xff01;】文叔叔文档收集使用动机在我们的学习工作中&#xff0c;少不了要让大家集体提交文件的情况&#xff0c;举个最简单的例子&#xff1a;收作业。 传统的文件收集流程大致是&#xff1a;群内发出…

惠普800g1支持什么内存_惠普黑白激光打印机哪种好 惠普黑白激光打印机推荐【图文详解】...

打印机的出现让我们在生活和日常工作中变得越来越方便&#xff0c;不过随着科技的发展&#xff0c;打印机的类型也变得非常多&#xff0c;其中就有黑白激光打印机&#xff0c;而黑白激光打印机的品牌也有很多&#xff0c;比如我们的惠普黑白激光打印机&#xff0c;今天小编就给…

控制台输出颜色控制

控制台输出颜色控制 转自&#xff1a;https://cloud.tencent.com/developer/article/1142372 前端时间&#xff0c;写了一篇 PHP 在 Console 模式下的进度显示 &#xff0c;正好最近的一个数据合并项目需要用到控制台颜色输出&#xff0c;所以就把相关的信息整理下&#xff0c;…

idea连接跳板机_跳板机服务(jumpserver)

一、跳板机服务作用介绍1、有效管理用户权限信息2、有效记录用户登录情况3、有效记录用户操作行为二、跳板机服务架构原理三、跳板机服务安装过程第一步&#xff1a;安装跳板机依赖软件yum -y install git python-pip mariadb-devel gcc automake autoconf python-devel readl…

【详细图解】再次理解im2col

【详细图解】再次理解im2col 转自&#xff1a;https://mp.weixin.qq.com/s/GPDYKQlIOq6Su0Ta9ipzig 一句话&#xff1a;im2col是将一个[C,H,W]矩阵变成一个[H,W]矩阵的一个方法&#xff0c;其原理是利用了行列式进行等价转换。 为什么要做im2col? 减少调用gemm的次数。 重要…

反思 大班 快乐的机器人_幼儿园大班教案《快乐的桌椅》含反思

大班教案《快乐的桌椅》含反思适用于大班的体育主题教学活动当中&#xff0c;让幼儿提高协调性和灵敏性&#xff0c;创新桌椅的玩法&#xff0c;正确爬的方法&#xff0c;学会匍匐前进&#xff0c;快来看看幼儿园大班《快乐的桌椅》含反思教案吧。幼儿园大班教案《快乐的桌椅》…

DCN可形变卷积实现1:Python实现

DCN可形变卷积实现1&#xff1a;Python实现 我们会先用纯 Python 实现一个 Pytorch 版本的 DCN &#xff0c;然后实现其 C/CUDA 版本。 本文主要关注 DCN 可形变卷积的代码实现&#xff0c;不会过多的介绍其思想&#xff0c;如有兴趣&#xff0c;请参考论文原文&#xff1a; …

蓝牙耳机声音一顿一顿的_线控耳机党阵地转移成功,OPPO这款TWS耳机体验满分...

“你看到我手机里3.5mm的耳机孔了吗”&#xff0c;这可能是许多线控耳机党最想说的话了。确实&#xff0c;如今手机在做“减法”&#xff0c;而厂商们首先就拿3.5mm耳机孔“开刀”&#xff0c;我们也丧失了半夜边充电边戴耳机打游戏的乐趣。竟然如此&#xff0c;那如何在耳机、…

AI移动端优化之Im2Col+Pack+Sgemm

AI移动端优化之Im2ColPackSgemm 转自&#xff1a;https://blog.csdn.net/just_sort/article/details/108412760 这篇文章是基于NCNN的Sgemm卷积为大家介绍Im2ColPackSgemm的原理以及算法实现&#xff0c;希望对算法优化感兴趣或者做深度学习模型部署的读者带来帮助。 1. 前言 …

elementui的upload组件怎么获取上传的文本流、_抖音feed流直播间引流你还不会玩?实操讲解...

本文由艾奇在线明星优化师写作计划出品在这个全民惊恐多灾多难且带有魔幻的2020&#xff0c;一场突如其来的疫情改变了人们很多消费习惯&#xff0c;同时加速了直播电商的发展&#xff0c;现在直播已经成为商家必争的营销之地&#xff0c;直播虽然很火&#xff0c;但如果没有流…

FFmpeg 视频处理入门教程

FFmpeg 视频处理入门教程 转自&#xff1a;https://www.ruanyifeng.com/blog/2020/01/ffmpeg.html 作者&#xff1a; 阮一峰 日期&#xff1a; 2020年1月14日 FFmpeg 是视频处理最常用的开源软件。 它功能强大&#xff0c;用途广泛&#xff0c;大量用于视频网站和商业软件&…

checkbox wpf 改变框的大小_【论文阅读】倾斜目标范围框(标注)的终极方案

前言最常用的斜框标注方式是在正框的基础上加一个旋转角度θ&#xff0c;其代数表示为(x_c,y_c,w,h,θ)&#xff0c;其中(x_c,y_c )表示范围框中心点坐标&#xff0c;(w,h)表示范围框的宽和高[1,2,7]。对于该标注方式&#xff0c;如果将w和h的值互换&#xff0c;再将θ加上或者…

彻底理解BP之手写BP图像分类你也行

彻底理解BP之手写BP图像分类你也行 转自&#xff1a;https://zhuanlan.zhihu.com/p/397963213 第一节&#xff1a;用矩阵的视角&#xff0c;看懂BP的网络图 1.1、什么是BP反向传播算法 BP(Back Propagation)误差反向传播算法&#xff0c;使用反向传播算法的多层感知器又称为B…

梯度下降法和牛顿法计算开根号

梯度下降法和牛顿法计算开根号 本文将介绍如何不调包&#xff0c;只能使用加减乘除法实现对根号x的求解。主要介绍梯度下降和牛顿法者两种方法&#xff0c;并给出 C 实现。 梯度下降法 思路/步骤 转化问题&#xff0c;将 x\sqrt{x}x​ 的求解转化为最小化目标函数&#xff…

汇博工业机器人码垛机怎么写_全自动码垛机器人在企业生产中的地位越来越重要...

全自动码垛机器人在企业生产中的地位越来越重要在智能化的各种全自动生产线中&#xff0c;全自动码垛机器人成了全自动生产线的重要机械设备&#xff0c;在各种生产中发挥着不可忽视的作用。全自动码垛机器人主要用于生产线上的包装过程中&#xff0c;不仅能够提高企业的生产率…

小说中场景的功能_《流浪地球》:从小说到电影

2019年春节贺岁档冒出一匹黑马&#xff1a;国产科幻片《流浪地球》大年初一上映后口碑、票房双丰收&#xff1a;截至9日下午&#xff0c;票房已破15亿&#xff0c;并获得9.2的高评分。著名导演詹姆斯卡梅隆通过社交媒体对我国春节期间上映的科幻影片《流浪地球》发出的祝愿&…

线性回归与逻辑回归及其实现

线性回归与逻辑回归及其实现 回归与分类 预测值定性分析&#xff0c;即离散变量预测时&#xff0c;称之为分类&#xff1b;预测值定量分析&#xff0c;即连续变量预测时&#xff0c;称之为回归。 如预测一张图片是猫还是狗&#xff0c;是分类问题&#xff1b;预测明年的房价…

hbase 页面访问_HBase

HBase 特点 海量存储 Hbase 适合存储 PB 级别的海量数据&#xff0c;在 PB 级别的数据以及采用廉价 PC 存储的情况下&#xff0c;能在几十到百毫秒内返回数据。这与 Hbase 的极易扩展性息息相关。正式因为 Hbase 良好的扩展性&#xff0c;才为海量数据的存储提供了便利。 2&…

深入理解L1、L2正则化

深入理解L1、L2正则化 转自&#xff1a;【面试看这篇就够了】L1、L2正则化理解 一、概述 正则化&#xff08;Regularization&#xff09;是机器学习中一种常用的技术&#xff0c;其主要目的是控制模型复杂度&#xff0c;减小过拟合。正则化技术已经成为模型训练中的常用技术&a…