如题,在本地跑xai的grok-1的时候遇见的问题。
首先你的cuda应该是安装好的,也就是bash下nvidia-smi
可以显示、python下torch.cuda.is_available()
返回True
。
在执行
import jaxjax.local_device_count()
的时候报错这个。
在/usr/local/cuda/extras/CUPTI/lib64
应该是包含so文件的,这时候在~/.bashrc
加上export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64
,然后source ~/.bashrc
就行了。
import jax
jax.devices()
显示应该是[cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]
这样的,有几个卡就显示几个。
如果没有显示,那要看你是不是安装的jax的gpu版本,参考JAX: 库安装和GPU使用,解决不能识别gpu问题