之前做实验报了一个错误,卡了很久。
具体就是这行代码
from torch_scatter import scatter_add
这个torch_scatter是非官方的库,经常安装失败,
找了很多的安装方法,都不好使,特别是对新版的pytorch+cuda环境
机缘巧合发现torch_scatter有github
https://github.com/rusty1s/pytorch_scatter/tree/master/torch_scatter
然后找到我调用的函数,再深挖他的依赖,新建一个名字叫torch_scatter.py的脚本。
然后就可以调用本地的脚本了。
我需要的是scatter_add,对应的脚本如下。
#torch_scatter.py
import torch
from typing import Optionaldef scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,out: Optional[torch.Tensor] = None,dim_size: Optional[int] = None) -> torch.Tensor:index = broadcast(index, src, dim)if out is None:size = list(src.size())if dim_size is not None:size[dim] = dim_sizeelif index.numel() == 0:size[dim] = 0else:size[dim] = int(index.max()) + 1out = torch.zeros(size, dtype=src.dtype, device=src.device)return out.scatter_add_(dim, index, src)else:return out.scatter_add_(dim, index, src)def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,out: Optional[torch.Tensor] = None,dim_size: Optional[int] = None) -> torch.Tensor:return scatter_sum(src, index, dim, out, dim_size)def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):if dim < 0:dim = other.dim() + dimif src.dim() == 1:for _ in range(0, dim):src = src.unsqueeze(0)for _ in range(src.dim(), other.dim()):src = src.unsqueeze(-1)src = src.expand(other.size())return src