0,环境
ubuntu 22.04
pytorch 2.3.1
x86
RTX 3080
cuda 12.2
1，示例代码
以 potrs 为例：
hello_cholesky.py
"""
hello_cholesky.py
step1, Cholesky decompose;
step2, inverse A;
step3, Cholesky again;
python3 hello_cholesky.py --size 256 --cuda_device_id 0
"""
import torch
import time
import argparsedef cholesky_measure(A, cuda_dev=0):dev = torch.device(f"cuda:{cuda_dev}")A = A.to(dev)print(f'Which device to compute : {dev}')SY = 100* torch.mm(A, A.t()) + 200*torch.eye(N, device=dev)to_start = time.time() SY = torch.linalg.cholesky(SY)SY = torch.cholesky_inverse(SY)SY = torch.linalg.cholesky(SY, upper=True)run_time = time.time() - to_start print(f'The device: {dev}, run: {run_time:.3f} second')print(f'SY : {SY}')print(f'****'*20)return run_timeif __name__ == "__main__":parser = argparse.ArgumentParser(description='dim of A.')parser.add_argument('--N', type=int, default=512, required=True, help='dim of A')args = parser.parse_args()N = args.Nprint(f'A N : {N}') A = torch.randn(N, N)cuda_dev = 0time_dev0 = cholesky_measure(A, cuda_dev) time_dev1 = cholesky_measure(A, cuda_dev+1) print(f'time_dev0 /time_dev1 = {time_dev0/time_dev1:.2f} ')
运行效果:
2,调用栈跟踪
跟踪如下调用关系:
Tensor cholesky_inverse(const Tensor &input, bool upper)    // aten/src/ATen/native/BatchLinearAlgebra.cpp
  -> static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper)
  -> DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
  -> REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
     // CUDA 路径的注册；REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl) 对应的是 CPU 实现
  -> Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
  -> Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper)
  -> void _cholesky_inverse_cusolver_potrs_based(Tensor& result, Tensor& infos, bool upper)
  -> template<typename scalar_t>
     inline static void apply_cholesky_cusolver_potrs(Tensor& self_working_copy, const Tensor& A_column_major_copy, bool upper, Tensor& infos)
  -> at::cuda::solver::potrs<scalar_t>(handle, uplo, n_32, nrhs_32, A_ptr + i * A_matrix_stride, lda_32, self_working_copy_ptr + i * self_matrix_stride, ldb_32, infos_ptr);
一些细节: