直接上代码
- DDP forward
# Excerpt from the DDP forward pass: choose how to run the wrapped module
# depending on how many device ids this wrapper was configured with.
if self.device_ids:
    if len(self.device_ids) == 1:
        # Single device: to_kwargs returns per-device sequences, of which only
        # entry 0 is used here (presumably it moves the arguments onto
        # device_ids[0] -- confirm against to_kwargs).
        inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
        output = self.module(*inputs[0], **kwargs[0])
    else:
        # Several devices: split the arguments across the devices, run one
        # module copy per scattered chunk, then collect the per-device
        # outputs on output_device.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
        output = self.gather(outputs, self.output_device)
else:
    # No device ids configured: call the module on the arguments unchanged.
    output = self.module(*inputs, **kwargs)
def scatter(self, inputs, kwargs, device_ids):
    """Scatter positional and keyword arguments across ``device_ids``.

    Thin wrapper that forwards to ``scatter_kwargs`` and passes ``self.dim``
    as the ``dim`` argument.

    Args:
        inputs: Positional arguments destined for the wrapped module.
        kwargs: Keyword arguments destined for the wrapped module.
        device_ids: Target devices to scatter onto.

    Returns:
        Whatever ``scatter_kwargs`` returns: a pair
        ``(scattered_inputs, scattered_kwargs)``.
    """
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary.

    Scatters ``inputs`` and ``kwargs`` independently with ``scatter`` and then
    pads the shorter result so both have the same number of entries.

    Args:
        inputs: Positional arguments; a falsy value is treated as "nothing to
            scatter" and yields an empty list before padding.
        kwargs: Keyword arguments; a falsy value is handled the same way.
        target_gpus: Devices passed through to ``scatter``.
        dim: Dimension passed through to ``scatter``.

    Returns:
        A pair ``(inputs, kwargs)`` of equal-length tuples.
    """
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    # Pad the shorter side so every entry has both an args tuple and a
    # kwargs dict.  A fresh {} is built per slot so the dicts are not shared.
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.

    Args:
        inputs: A tensor, or a (possibly nested) namedtuple / tuple / list /
            dict of objects to scatter.
        target_gpus: Target devices; the result has one entry per device.
        dim: Dimension handed to ``Scatter.apply`` for tensors.

    Returns:
        For a tensor, the result of ``Scatter.apply``.  For a non-empty
        container, a list of containers of the same kind, rebuilt from the
        scattered elements.  For anything else (including empty containers),
        a list holding the same object once per entry of ``target_gpus``.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        # Containers: scatter every element, then zip the per-element results
        # back together so each output entry is a container of the same kind.
        if is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        # Non-tensor leaf or empty container: repeat the same reference once
        # per target.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
torch/nn/parallel/_functions.py:默认情况下,tensor 从 CPU 拷贝到 GPU(tensor to gpu)时已经有了 stream 的加持,即拷贝在后台 stream 中进行。
# How callers import these autograd functions; the class body below is the
# definition of Scatter quoted for reference.
from torch.nn.parallel._functions import Scatter, Gather


class Scatter(Function):
    """Autograd function that scatters a tensor across several devices.

    ``forward`` splits ``input`` across ``target_gpus`` via ``comm.scatter``;
    ``backward`` gathers the incoming gradients back with ``Gather.apply``.
    """

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        """Scatter ``input`` onto ``target_gpus`` along ``dim``.

        Args:
            ctx: Autograd context; ``dim`` and ``input_device`` are saved on
                it for use in ``backward``.
            target_gpus: Target devices, normalised to device indices.
            chunk_sizes: Passed through unchanged to ``comm.scatter``.
            dim: Dimension passed to ``comm.scatter``.
            input: Tensor to scatter.

        Returns:
            The outputs produced by ``comm.scatter``.
        """
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        # -1 is used as the marker for "input lives on the CPU".
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(device) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    # Make the device's current stream wait for the copy
                    # stream, and register the output as used on that stream.
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        """Gather the per-device gradients back onto the original input device.

        Returns ``None`` in the positions of ``target_gpus``, ``chunk_sizes``
        and ``dim``; only ``input`` receives a gradient.
        """
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
结论:目前DDP模式下,已经有了prefetch + stream的加持。