简单适配torch_npu不支持的ATen算子 一、背景说明 二、实现步骤详解 2.1 实现前向、反向传播算子 2.2 编译生成动态库 2.3 测试验证程序 三、关键点解析 四、验证结果
一、背景说明
1.1 PyTorch扩展机制
PrivateUse1
是PyTorch为第三方设备扩展设计的保留设备类型,允许开发者添加新硬件支持当算子在当前设备(如NPU)未实现时,PyTorch会自动回退(fallback)到CPU执行 本文以native_batch_norm
算子为例,演示如何为NPU设备添加自定义实现
1.2 核心概念
ATen :PyTorch的核心张量运算库,提供超过2000个基础算子内存格式 :描述张量在内存中的排布方式,如NCHW(批处理x通道x高度x宽度)自动微分 :PyTorch通过记录计算图实现反向传播,需要同时实现前向和反向算子
二、实现步骤详解
2.1 实现前向、反向传播算子
cat > native_batch_norm_npu. cpp << - 'EOF'
# include <torch/library.h>
# include <ATen/EmptyTensor.h>
# include <ATen/Device.h>
# include <ATen/Utils.h>
# include <ATen/native/Resize.h>
# include <c10/core/DeviceType.h> std:: tuple< at:: Tensor, at:: Tensor, at:: Tensor> native_batch_norm_npu ( const at:: Tensor& input, const c10:: optional< at:: Tensor> & weight, const c10:: optional< at:: Tensor> & bias, const c10:: optional< at:: Tensor> & running_mean, const c10:: optional< at:: Tensor> & running_var, bool training, double momentum, double eps)
{ at:: Tensor output = at:: empty_like ( input) ; at:: Tensor dummy_mean = at:: empty