参考
在 C++ 中注册调度运算符
使用自定义 C++ 运算符扩展 TorchScript
环境:
- NVIDIA Driver Version : 545.23.08
- CUDA Version: 12.1
- Python Version: 3.11
- PyTorch Version: 2.1
- CMake Version: 3.18.1
- 工作目录: workspace/test
一、 C++ 自定义运算符
创建 workspace/test/add2.cpp:
#include <stdio.h>
#include "add2.cuh"
#include "torch/script.h"

namespace {

using torch::DeviceType;
using torch::Tensor;

// Shared argument validation for both backends: the element-wise add below
// indexes both inputs with the same flat index and reads them as float, so
// the shapes must match and both tensors must be float32.
void check_myadd_inputs(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes(),
              "myadd: shape mismatch, got ", self_.sizes(), " and ",
              other_.sizes());
  TORCH_CHECK(self_.scalar_type() == c10::ScalarType::Float &&
                  other_.scalar_type() == c10::ScalarType::Float,
              "myadd: only float32 tensors are supported, got ",
              self_.scalar_type(), " and ", other_.scalar_type());
}

// CPU implementation of myops::myadd.
//
// Returns a new contiguous tensor holding self_ + other_ element-wise.
// Raises c10::Error if the shapes differ or either input is not float32.
Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  check_myadd_inputs(self_, other_);
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  printf("cpu\n");

  // The loop below uses flat indexing, so work on contiguous copies/views.
  const Tensor self = self_.contiguous();
  const Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());

  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  // Read the element count once instead of on every loop iteration.
  const int64_t count = result.numel();
  for (int64_t i = 0; i < count; ++i) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}

// CUDA implementation of myops::myadd.
//
// Returns a new contiguous tensor on the inputs' device holding
// self_ + other_ element-wise. Raises c10::Error if the shapes differ, either
// input is not float32, or the inputs live on different CUDA devices.
Tensor myadd_cuda(const Tensor& self_, const Tensor& other_) {
  check_myadd_inputs(self_, other_);
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CUDA);
  TORCH_CHECK(self_.device() == other_.device(),
              "myadd: tensors must be on the same CUDA device, got ",
              self_.device(), " and ", other_.device());
  printf("cuda\n");

  // Make the inputs' device current for the duration of the call so the
  // kernel is launched on the device that owns the data (matters when the
  // tensors are not on the currently selected device, e.g. cuda:1).
  const c10::DeviceGuard device_guard(self_.device());

  const Tensor self = self_.contiguous();
  const Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());

  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  // Skip the launch for empty tensors: a launch with zero blocks is an
  // invalid configuration.
  const int64_t count = result.numel();
  if (count > 0) {
    launch_add2(result_ptr, self_ptr, other_ptr, count);
  }
  return result;
}

}  // namespace

// Declares the operator schema; the backend kernels are attached separately
// through TORCH_LIBRARY_IMPL.
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
// Attach myadd_cpu as the implementation of myops::myadd for the CPU
// dispatch key.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
// Attach myadd_cuda as the implementation of myops::myadd for the CUDA
// dispatch key.
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}
创建 workspace/test/add2.cu:
#include "add2.cuh"

// Element-wise add: c[i] = a[i] + b[i] for every i in [0, n).
//
// Written as a grid-stride loop, so the result is correct for any grid size,
// including grids with fewer threads than elements. The index is 64-bit so
// that it cannot overflow when n exceeds INT_MAX.
__global__ void add2_kernel(float* c, const float* a, const float* b, long n) {
  const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
  for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    c[i] = a[i] + b[i];
  }
}

// Launches add2_kernel over n floats.
//
// c, a and b are passed straight to the kernel, so they must be
// device-accessible pointers to at least n floats each. The launch uses the
// default stream and is not synchronised here.
// NOTE(review): launch errors are not checked here (no cudaGetLastError call).
void launch_add2(float* c, const float* a, const float* b, long n) {
  // Nothing to do for an empty range; a zero-block grid would be an invalid
  // launch configuration.
  if (n <= 0) {
    return;
  }

  constexpr int kThreadsPerBlock = 1024;
  // Upper bound for gridDim.x; the grid-stride loop covers the remainder
  // when the element count needs more blocks than this.
  constexpr long long kMaxBlocks = 2147483647LL;

  long long blocks = (static_cast<long long>(n) + kThreadsPerBlock - 1) / kThreadsPerBlock;
  if (blocks > kMaxBlocks) {
    blocks = kMaxBlocks;
  }

  const dim3 grid(static_cast<unsigned int>(blocks));
  const dim3 block(kThreadsPerBlock);
  add2_kernel<<<grid, block>>>(c, a, b, n);
}
创建 workspace/test/add2.cuh:
#ifndef ADD2_CUH_
#define ADD2_CUH_

// Host-side entry point for the element-wise add c[i] = a[i] + b[i] over
// n floats. Declared here so that add2.cpp can call the launcher defined in
// add2.cu.
//
//   c - output buffer
//   a - first input buffer
//   b - second input buffer
//   n - number of elements in each buffer
void launch_add2(float* c, const float* a, const float* b, long n);

#endif  // ADD2_CUH_
二、 cmake编译动态库
创建 workspace/test/CMakeLists.txt:
# The cxx_std_17 compile feature used below does not exist in CMake 3.1;
# 3.18 is the CMake version this example was built with.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(add2)

# Locate LibTorch; CMAKE_PREFIX_PATH must point at the torch CMake config.
find_package(Torch REQUIRED)
#find_package(CUDA REQUIRED)

# Define our library target
add_library(add2 SHARED add2.cpp add2.cu)
# Enable C++17
target_compile_features(add2 PRIVATE cxx_std_17)
# Link against LibTorch
target_link_libraries(add2 "${TORCH_LIBRARIES}")
新建目录build,编译:
# Create the build directory; -p keeps the command from failing on a re-run
# when the directory already exists.
mkdir -p build
cd build
# Ask the installed torch package where its CMake config lives and pass that
# location to CMake so find_package(Torch) can resolve.
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
make
创建 workspace/test/test.py:
import time
import ctypes
import numpy as np
import torch

print(torch.__version__)
# Load the shared library built from add2.cpp / add2.cu so that the
# myops::myadd operator becomes available under torch.ops.
torch.ops.load_library("build/libadd2.so")
print(torch.ops.myops.myadd)

# c = a + b (shape: [n])
n = 1024 * 1024
a1 = torch.rand(n, device="cpu")
b1 = torch.rand(n, device="cpu")
a2 = torch.rand(n, device="cuda:0")
b2 = torch.rand(n, device="cuda:0")


def run_torch():
    """Run the custom operator on the CPU tensors and return the result."""
    c = torch.ops.myops.myadd(a1, b1)
    return c


def run_cuda():
    """Run the custom operator on the CUDA tensors and return the result."""
    c = torch.ops.myops.myadd(a2, b2)
    return c


print("\nRunning cpu...")
print(a1)
print(b1)
start_time = time.time()
c_cpu = run_torch()
end_time = time.time()
print(c_cpu)
# Elapsed time in microseconds.
print((end_time - start_time) * 1e6)
# Check the custom operator against PyTorch's own addition.
assert torch.allclose(c_cpu, a1 + b1)

print("\nRunning cuda...")
print(a2)
print(b2)
# CUDA work is queued asynchronously: synchronise before starting and before
# stopping the clock so the measurement covers the kernel execution itself,
# not just the time needed to enqueue it.
torch.cuda.synchronize()
start_time = time.time()
c_cuda = run_cuda()
torch.cuda.synchronize()
end_time = time.time()
print(c_cuda)
# Elapsed time in microseconds.
print((end_time - start_time) * 1e6)
# Check the custom operator against PyTorch's own addition.
assert torch.allclose(c_cuda, a2 + b2)
结果如下
$ python3 test.py
2.1.0+cu121
myops.myadd

Running cpu...
tensor([0.5668, 0.9394, 0.5168, ..., 0.3057, 0.0873, 0.6022])
tensor([0.1668, 0.8012, 0.4616, ..., 0.7969, 0.7210, 0.8589])
cpu
tensor([0.7335, 1.7406, 0.9784, ..., 1.1026, 0.8083, 1.4611])
9006.977081298828

Running cuda...
tensor([0.3864, 0.3490, 0.5892, ..., 0.4237, 0.4182, 0.6051], device='cuda:0')
tensor([0.3069, 0.7079, 0.1878, ..., 0.7639, 0.6509, 0.5006], device='cuda:0')
cuda
tensor([0.6933, 1.0568, 0.7770, ..., 1.1876, 1.0690, 1.1058], device='cuda:0')
362.396240234375