Using a decorator to run buggy PyTorch operators on the CPU, without modifying model code
Some PyTorch operators trigger device-side exceptions after being dispatched to the device. As a temporary workaround, they can be executed on the CPU instead. Editing the model source code for this is intrusive, so the operator can be patched as follows.

1. Code
import torch

def force_cpu(func):
    def wrapper(self, *args, **kwargs):
        # Move the tensor itself and all tensor arguments to the CPU
        self_cpu = self.cpu()
        args_cpu = [arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        # Run the op on CPU, then move the result back to the original device
        return func(self_cpu, *args_cpu, **kwargs_cpu).to(self.device)
    return wrapper

# Monkey-patch the buggy operator so every call site uses the CPU path
torch.Tensor.masked_fill = force_cpu(torch.Tensor.masked_fill)

device = "cuda"
causal_mask = torch.ones((1, 1, 2048, 2048), dtype=torch.float16).to(device)
padding_mask = torch.ones((1, 1, 2048, 14), dtype=torch.bool).to(device)
min_dtype = -65504.0  # most negative representable float16 value
mask_length = 14
# masked_fill now executes on the CPU; its result is moved back to `device`
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
torch.cuda.synchronize()