I am done

commit 40e2a747cf (parent 720dc28c09)
2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,28 @@
# mypy: allow-untyped-defs
from typing_extensions import deprecated
from torch.nn.parallel.data_parallel import data_parallel, DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter
__all__ = [
"replicate",
"scatter",
"parallel_apply",
"gather",
"data_parallel",
"DataParallel",
"DistributedDataParallel",
]
@deprecated(
"`torch.nn.parallel.DistributedDataParallelCPU` is deprecated, "
"please use `torch.nn.parallel.DistributedDataParallel` instead.",
category=FutureWarning,
)
class DistributedDataParallelCPU(DistributedDataParallel):
pass
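
A minimal usage sketch (not part of the diff) for the package's main entry point, DataParallel; the tiny linear layer and shapes are illustrative only and assume at least one visible CUDA device.

import torch
from torch import nn
from torch.nn.parallel import DataParallel

if torch.cuda.is_available():
    model = nn.Linear(16, 4).to("cuda:0")  # parameters must live on device_ids[0]
    dp_model = DataParallel(model)         # replicates across all visible GPUs
    x = torch.randn(8, 16)                 # CPU input is fine; it is scattered along dim 0
    y = dp_model(x)                        # per-device outputs are gathered on device_ids[0]
    assert y.shape == (8, 4)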

@@ -0,0 +1,135 @@
import warnings
from typing import List, Optional
import torch
from torch._utils import _get_device_index
from torch.autograd import Function
from torch.nn.parallel import comm
class Broadcast(Function):
@staticmethod
def forward(ctx, target_gpus, *inputs):
assert all(
i.device.type != "cpu" for i in inputs
), "Broadcast function not implemented for CPU tensors"
target_gpus = [_get_device_index(x, True) for x in target_gpus]
ctx.target_gpus = target_gpus
if len(inputs) == 0:
return ()
ctx.num_inputs = len(inputs)
ctx.input_device = inputs[0].get_device()
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
non_differentiables = []
for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
if not input_requires_grad:
for output in outputs:
non_differentiables.append(output[idx])
ctx.mark_non_differentiable(*non_differentiables)
return tuple([t for tensors in outputs for t in tensors])
@staticmethod
def backward(ctx, *grad_outputs):
return (None,) + ReduceAddCoalesced.apply(
ctx.input_device, ctx.num_inputs, *grad_outputs
)
class ReduceAddCoalesced(Function):
@staticmethod
def forward(ctx, destination, num_inputs, *grads):
ctx.target_gpus = [
grads[i].get_device() for i in range(0, len(grads), num_inputs)
]
grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
return comm.reduce_add_coalesced(grads_, destination)
@staticmethod
def backward(ctx, *grad_outputs):
return (
None,
None,
) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
class Gather(Function):
@staticmethod
def forward(ctx, target_device, dim, *inputs):
assert all(
i.device.type != "cpu" for i in inputs
), "Gather function not implemented for CPU tensors"
if target_device == "cpu":
ctx.target_device = "cpu"
else:
target_device = _get_device_index(target_device, True)
ctx.target_device = target_device
ctx.dim = dim
ctx.input_gpus = tuple(i.get_device() for i in inputs)
if all(t.dim() == 0 for t in inputs) and dim == 0:
inputs = tuple(t.view(1) for t in inputs)
warnings.warn(
"Was asked to gather along dimension 0, but all "
"input tensors were scalars; will instead unsqueeze "
"and return a vector."
)
ctx.unsqueezed_scalar = True
else:
ctx.unsqueezed_scalar = False
ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
return comm.gather(inputs, ctx.dim, ctx.target_device)
@staticmethod
def backward(ctx, grad_output):
scattered_grads = Scatter.apply(
ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
)
if ctx.unsqueezed_scalar:
scattered_grads = tuple(g[0] for g in scattered_grads)
return (None, None) + scattered_grads
class Scatter(Function):
@staticmethod
def forward(ctx, target_gpus, chunk_sizes, dim, input):
target_gpus = [_get_device_index(x, True) for x in target_gpus]
ctx.dim = dim
ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
streams = None
if torch.cuda.is_available() and ctx.input_device == -1:
# Perform CPU to GPU copies in a background stream
streams = [
_get_stream(torch.device("cuda", device)) for device in target_gpus
]
outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
# Synchronize with the copy stream
if streams is not None:
for i, output in enumerate(outputs):
with torch.cuda.device(target_gpus[i]):
main_stream = torch.cuda.current_stream()
main_stream.wait_stream(streams[i])
output.record_stream(main_stream)
return outputs
@staticmethod
def backward(ctx, *grad_output):
return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None
def _get_stream(device: torch.device):
"""Get a background stream for copying between CPU and target device."""
global _streams
if device.type == "cpu":
return None
device_mod = getattr(torch, device.type, None)
if device_mod is None:
return None
if _streams is None:
_streams = [None] * device_mod.device_count()
if _streams[device.index] is None:
_streams[device.index] = device_mod.Stream(device.index)
return _streams[device.index]
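
A minimal sketch (not part of the diff) of the Broadcast / ReduceAddCoalesced autograd pair defined above, assuming two visible CUDA devices; the toy parameter is illustrative only.

import torch
from torch.nn.parallel._functions import Broadcast

if torch.cuda.device_count() >= 2:
    w = torch.randn(4, device="cuda:0", requires_grad=True)
    w0, w1 = Broadcast.apply([0, 1], w)      # one copy of w per device
    loss = w0.sum() + w1.sum().to("cuda:0")  # use both replicas
    loss.backward()
    # each replica contributes a gradient of ones; ReduceAddCoalesced sums them back onto w
    assert torch.allclose(w.grad, torch.full_like(w, 2.0))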

@@ -0,0 +1,260 @@
# mypy: allow-untyped-defs
import warnings
from typing import List
import torch
from torch._utils import (
_flatten_dense_tensors,
_get_device_index,
_handle_complex,
_reorder_tensors_as,
_take_tensors,
_unflatten_dense_tensors,
)
from torch.cuda import nccl
def broadcast(tensor, devices=None, *, out=None):
r"""Broadcasts a tensor to specified GPU devices.
Args:
tensor (Tensor): tensor to broadcast. Can be on CPU or GPU.
devices (Iterable[torch.device, str or int], optional): an iterable of
GPU devices, among which to broadcast.
out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
store output results.
.. note::
Exactly one of :attr:`devices` and :attr:`out` must be specified.
Returns:
- If :attr:`devices` is specified,
a tuple containing copies of :attr:`tensor`, placed on
:attr:`devices`.
- If :attr:`out` is specified,
a tuple containing :attr:`out` tensors, each containing a copy of
:attr:`tensor`.
"""
tensor = _handle_complex(tensor)
if not ((devices is None) ^ (out is None)):
raise RuntimeError(
f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}"
)
if devices is not None:
devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices)
else:
return torch._C._broadcast_out(tensor, out)
def broadcast_coalesced(tensors, devices, buffer_size=10485760):
"""Broadcast a sequence of tensors to the specified GPUs.
Small tensors are first coalesced into a buffer to reduce the number of synchronizations.
Args:
tensors (sequence): tensors to broadcast. Must be on the same device,
either CPU or GPU.
devices (Iterable[torch.device, str or int]): an iterable of GPU
devices, among which to broadcast.
buffer_size (int): maximum size of the buffer used for coalescing
Returns:
A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`.
"""
devices = [_get_device_index(d) for d in devices]
tensors = [_handle_complex(t) for t in tensors]
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
def reduce_add(inputs, destination=None):
"""Sum tensors from multiple GPUs.
All inputs should have matching shapes, dtype, and layout. The output tensor
will be of the same shape, dtype, and layout.
Args:
inputs (Iterable[Tensor]): an iterable of tensors to add.
destination (int, optional): a device on which the output will be
placed (default: current device).
Returns:
A tensor containing an elementwise sum of all inputs, placed on the
:attr:`destination` device.
"""
destination = _get_device_index(destination, optional=True)
input_size = inputs[0].size()
root_index = None # index of input tensor that already is on the correct device
for i, inp in enumerate(inputs):
assert inp.device.type != "cpu", "reduce_add expects all inputs to be on GPUs"
if inp.get_device() == destination:
root_index = i
if inp.size() != input_size:
got = "x".join(str(x) for x in inp.size())
expected = "x".join(str(x) for x in input_size)
raise ValueError(
f"input {i} has invalid size: got {got}, but expected {expected}"
)
if root_index is None:
raise RuntimeError(
"reduce_add expects destination to be on the same GPU with one of the tensors"
)
if len(inputs) == 1:
return inputs[0]
if nccl.is_available(inputs):
result = torch.empty_like(inputs[root_index])
nccl.reduce(inputs, output=result, root=root_index)
else:
destination_device = torch.device(inputs[root_index].device.type, destination)
nonroot = [t for i, t in enumerate(inputs) if i != root_index]
# make a new tensor w/o clone
result = inputs[root_index] + nonroot[0].to(
device=destination_device, non_blocking=True
)
for other in nonroot[1:]:
result.add_(other.to(device=destination_device, non_blocking=True))
return result
def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
"""Sum tensors from multiple GPUs.
Small tensors are first coalesced into a buffer to reduce the number
of synchronizations.
Args:
inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
contain tensors from a single device.
destination (int, optional): a device on which the output will be
placed (default: current device).
buffer_size (int): maximum size of the buffer used for coalescing
Returns:
A tuple of tensors containing an elementwise sum of each group of
inputs, placed on the ``destination`` device.
"""
# TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
# return `inputs`.
dense_tensors: List[List] = [[] for _ in inputs] # shape (num_gpus, num_tensors)
output = []
ref_order = []
# process sparse ones first since they may have different sizes on different gpus
for tensor_at_gpus in zip(*inputs):
if all(t.is_sparse for t in tensor_at_gpus):
result = reduce_add(tensor_at_gpus, destination) # this will be sparse too
output.append(result)
ref_order.append(tensor_at_gpus[0])
else:
for coll, t in zip(dense_tensors, tensor_at_gpus):
coll.append(t.to_dense() if t.is_sparse else t)
ref_order.append(dense_tensors[0][-1])
itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
# now the dense ones, which have consistent sizes
for chunks in zip(*itrs):
flat_tensors = [
_flatten_dense_tensors(chunk) for chunk in chunks
] # (num_gpus,)
flat_result = reduce_add(flat_tensors, destination)
for t in _unflatten_dense_tensors(flat_result, chunks[0]):
# The unflattened tensors do not share storage, and we don't expose
# base flat tensor anyways, so give them different version counters.
# See NOTE [ Version Counter in comm.*_coalesced ]
output.append(t.data)
return tuple(_reorder_tensors_as(output, ref_order))
def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
"""Scatters tensor across multiple GPUs.
Args:
tensor (Tensor): tensor to scatter. Can be on CPU or GPU.
devices (Iterable[torch.device, str or int], optional): an iterable of
GPU devices, among which to scatter.
chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
each device. It should match :attr:`devices` in length and sums to
``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
into equal chunks.
dim (int, optional): A dimension along which to chunk :attr:`tensor`.
Default: ``0``.
streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among
which to execute the scatter. If not specified, the default stream will
be utilized.
out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
store output results. Sizes of these tensors must match that of
:attr:`tensor`, except for :attr:`dim`, where the total size must
sum to ``tensor.size(dim)``.
.. note::
Exactly one of :attr:`devices` and :attr:`out` must be specified. When
:attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
will be inferred from sizes of :attr:`out`.
Returns:
- If :attr:`devices` is specified,
a tuple containing chunks of :attr:`tensor`, placed on
:attr:`devices`.
- If :attr:`out` is specified,
a tuple containing :attr:`out` tensors, each containing a chunk of
:attr:`tensor`.
"""
tensor = _handle_complex(tensor)
if out is None:
devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else:
if devices is not None:
raise RuntimeError(
f"'devices' must not be specified when 'out' is specified, but got devices={devices}"
)
if chunk_sizes is not None:
raise RuntimeError(
f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}"
)
return tuple(torch._C._scatter_out(tensor, out, dim, streams))
def gather(tensors, dim=0, destination=None, *, out=None):
r"""Gathers tensors from multiple GPU devices.
Args:
tensors (Iterable[Tensor]): an iterable of tensors to gather.
Tensor sizes in all dimensions other than :attr:`dim` have to match.
dim (int, optional): a dimension along which the tensors will be
concatenated. Default: ``0``.
destination (torch.device, str, or int, optional): the output device.
Can be CPU or CUDA. Default: the current CUDA device.
out (Tensor, optional, keyword-only): the tensor to store gather result.
Its sizes must match those of :attr:`tensors`, except for :attr:`dim`,
where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``.
Can be on CPU or CUDA.
.. note::
:attr:`destination` must not be specified when :attr:`out` is specified.
Returns:
- If :attr:`destination` is specified,
a tensor located on :attr:`destination` device, that is a result of
concatenating :attr:`tensors` along :attr:`dim`.
- If :attr:`out` is specified,
the :attr:`out` tensor, now containing results of concatenating
:attr:`tensors` along :attr:`dim`.
"""
tensors = [_handle_complex(t) for t in tensors]
if out is None:
if destination == -1:
warnings.warn(
"Using -1 to represent CPU tensor is deprecated. Please use a "
'device object or string instead, e.g., "cpu".',
FutureWarning,
stacklevel=2,
)
destination = _get_device_index(destination, allow_cpu=True, optional=True)
return torch._C._gather(tensors, dim, destination)
else:
if destination is not None:
raise RuntimeError(
f"'destination' must not be specified when 'out' is specified, but got destination={destination}"
)
return torch._C._gather_out(tensors, out, dim)
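
A minimal sketch (not part of the diff) exercising the comm helpers above, assuming at least two visible CUDA devices; the tensor values are illustrative only.

import torch
from torch.nn.parallel import comm

if torch.cuda.device_count() >= 2:
    t = torch.arange(8.0)                                 # CPU source tensor
    copies = comm.broadcast(t, [0, 1])                    # one full copy per GPU
    chunks = comm.scatter(t, [0, 1], dim=0)               # two chunks of four elements
    summed = comm.reduce_add(copies, destination=0)       # elementwise sum on GPU 0
    back = comm.gather(chunks, dim=0, destination="cpu")  # reassembled on the CPU
    assert torch.equal(back, t) and torch.equal(summed.cpu(), 2 * t)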

@@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
import operator
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
import torch
from torch._utils import (
_get_all_device_indices,
_get_available_device_type,
_get_device_index,
_get_devices_properties,
)
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
__all__ = ["DataParallel", "data_parallel"]
def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
imbalance_warn = """
There is an imbalance between your GPUs. You may want to exclude GPU {} which
has less than 75% of the memory or cores of GPU {}. You can do so by setting
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
environment variable."""
device_ids = [_get_device_index(x, True) for x in device_ids]
dev_props = _get_devices_properties(device_ids)
def warn_imbalance(get_prop):
values = [get_prop(props) for props in dev_props]
min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
if min_val / max_val < 0.75:
warnings.warn(
imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])
)
return True
return False
if warn_imbalance(lambda props: props.total_memory):
return
if warn_imbalance(lambda props: props.multi_processor_count):
return
T = TypeVar("T", bound=Module)
class DataParallel(Module, Generic[T]):
r"""Implements data parallelism at the module level.
This container parallelizes the application of the given :attr:`module` by
splitting the input across the specified devices by chunking in the batch
dimension (other objects will be copied once per device). In the forward
pass, the module is replicated on each device, and each replica handles a
portion of the input. During the backwards pass, gradients from each replica
are summed into the original module.
The batch size should be larger than the number of GPUs used.
.. warning::
It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
instead of this class, to do multi-GPU training, even if there is only a single
node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.
Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel but some types are specially handled. Tensors will be
    **scattered** along the specified dim (default 0). Tuple, list and dict
    types will be shallow copied. Other types will be shared among the
    different threads and can be corrupted if written to in the model's
    forward pass.
The parallelized :attr:`module` must have its parameters and buffers on
``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
module.
.. warning::
In each forward, :attr:`module` is **replicated** on each device, so any
updates to the running module in ``forward`` will be lost. For example,
if :attr:`module` has a counter attribute that is incremented in each
``forward``, it will always stay at the initial value because the update
is done on the replicas which are destroyed after ``forward``. However,
:class:`~torch.nn.DataParallel` guarantees that the replica on
``device[0]`` will have its parameters and buffers sharing storage with
the base parallelized :attr:`module`. So **in-place** updates to the
parameters or buffers on ``device[0]`` will be recorded. E.g.,
:class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
rely on this behavior to update the buffers.
.. warning::
Forward and backward hooks defined on :attr:`module` and its submodules
will be invoked ``len(device_ids)`` times, each with inputs located on
a particular device. Particularly, the hooks are only guaranteed to be
executed in correct order with respect to operations on corresponding
devices. For example, it is not guaranteed that hooks set via
:meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
`all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
that each such hook be executed before the corresponding
:meth:`~torch.nn.Module.forward` call of that device.
.. warning::
When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to the
        number of devices used in data parallelism, containing the result from
        each device.
.. note::
There is a subtlety in using the
``pack sequence -> recurrent network -> unpack sequence`` pattern in a
:class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
details.
Args:
module (Module): module to be parallelized
device_ids (list of int or torch.device): CUDA devices (default: all devices)
output_device (int or torch.device): device location of output (default: device_ids[0])
Attributes:
module (Module): the module to be parallelized
Example::
>>> # xdoctest: +SKIP
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var) # input_var can be on any device, including CPU
"""
# TODO: update notes/cuda.rst when this class handles 8+ GPUs well
def __init__(
self,
module: T,
device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
output_device: Optional[Union[int, torch.device]] = None,
dim: int = 0,
) -> None:
super().__init__()
torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
device_type = _get_available_device_type()
if device_type is None:
self.module = module
self.device_ids = []
return
if device_ids is None:
device_ids = _get_all_device_indices()
if device_ids is None:
raise RuntimeError("no available devices were found")
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True)
self.src_device_obj = torch.device(device_type, self.device_ids[0])
if device_type == "cuda":
_check_balance(self.device_ids)
if len(self.device_ids) == 1:
self.module.to(self.src_device_obj)
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
with torch.autograd.profiler.record_function("DataParallel.forward"):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
"module must have its parameters and buffers "
f"on device {self.src_device_obj} (device_ids[0]) but found one of "
f"them on device: {t.device}"
)
inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
# for forward function without any inputs, empty list and dict will be created
# so the module can be executed on one device which is the first one in device_ids
if not inputs and not module_kwargs:
inputs = ((),)
module_kwargs = ({},)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **module_kwargs[0])
replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
return self.gather(outputs, self.output_device)
def replicate(
self, module: T, device_ids: Sequence[Union[int, torch.device]]
) -> List[T]:
return replicate(module, device_ids, not torch.is_grad_enabled())
def scatter(
self,
inputs: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
device_ids: Sequence[Union[int, torch.device]],
) -> Any:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def parallel_apply(
self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
) -> List[Any]:
return parallel_apply(
replicas, inputs, kwargs, self.device_ids[: len(replicas)]
)
def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
return gather(outputs, output_device, dim=self.dim)
def data_parallel(
module: Module,
inputs: Any,
device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
output_device: Optional[Union[int, torch.device]] = None,
dim: int = 0,
module_kwargs: Optional[Any] = None,
) -> torch.Tensor:
r"""Evaluate module(input) in parallel across the GPUs given in device_ids.
This is the functional version of the DataParallel module.
Args:
module (Module): the module to evaluate in parallel
inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): GPU location of the output. Use -1
            to indicate the CPU. (default: device_ids[0])
        dim (int, optional): dimension along which to scatter :attr:`inputs`. Default: 0.
        module_kwargs (dict, optional): keyword arguments to pass to :attr:`module`.
Returns:
a Tensor containing the result of module(input) located on
output_device
"""
if not isinstance(inputs, tuple):
inputs = (inputs,) if inputs is not None else ()
device_type = _get_available_device_type()
if device_type is None:
raise RuntimeError("device type could not be determined")
if device_ids is None:
device_ids = _get_all_device_indices()
if device_ids is None:
raise RuntimeError("no available devices were found")
if output_device is None:
output_device = device_ids[0]
device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True)
src_device_obj = torch.device(device_type, device_ids[0])
for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj:
raise RuntimeError(
"module must have its parameters and buffers "
f"on device {src_device_obj} (device_ids[0]) but found one of "
f"them on device: {t.device}"
)
inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
# for module without any inputs, empty list and dict will be created
# so the module can be executed on one device which is the first one in device_ids
if not inputs and not module_kwargs:
inputs = ((),)
module_kwargs = ({},)
assert module_kwargs is not None
if len(device_ids) == 1:
return module(*inputs[0], **module_kwargs[0])
used_device_ids = device_ids[: len(inputs)]
replicas = replicate(module, used_device_ids)
outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
return gather(outputs, output_device, dim)
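
A minimal sketch (not part of the diff) of the functional form documented above, assuming CUDA devices 0 and 1 are available; the linear layer is illustrative only.

import torch
from torch import nn
from torch.nn.parallel.data_parallel import data_parallel

if torch.cuda.device_count() >= 2:
    module = nn.Linear(8, 2).to("cuda:0")  # parameters must live on device_ids[0]
    x = torch.randn(6, 8, device="cuda:0")
    out = data_parallel(module, x, device_ids=[0, 1], output_device=0)
    assert out.shape == (6, 2) and out.device.index == 0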

File diff suppressed because it is too large.

@@ -0,0 +1,128 @@
import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from torch.nn.modules import Module
__all__ = ["get_a_var", "parallel_apply"]
def get_a_var(
obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]],
) -> Optional[torch.Tensor]:
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, (list, tuple)):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
def parallel_apply(
modules: Sequence[Module],
inputs: Sequence[Any],
kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.
Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(
inputs
), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
devices = [_get_device_index(x, True) for x in devices]
streams = [torch.cuda.current_stream(x) for x in devices]
lock = threading.Lock()
results = {}
grad_enabled, autocast_enabled = (
torch.is_grad_enabled(),
torch.is_autocast_enabled(),
)
def _worker(
i: int,
module: Module,
input: Any,
kwargs: Dict[str, Any],
device: Optional[Union[int, torch.device]] = None,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
torch.set_grad_enabled(grad_enabled)
if device is None:
t = get_a_var(input)
if t is None:
with lock:
results[i] = ExceptionWrapper(
where=f"in replica {i}, no device was provided and no tensor input was found; "
"device cannot be resolved"
)
return
device = t.get_device()
if stream is None:
stream = torch.cuda.current_stream(device)
try:
with torch.cuda.device(device), torch.cuda.stream(
stream
), torch.amp.autocast("cuda", enabled=autocast_enabled):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = module(*input, **kwargs)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(
where=f"in replica {i} on device {device}"
)
if len(modules) > 1:
threads = [
threading.Thread(
target=_worker, args=(i, module, input, kwargs, device, stream)
)
for i, (module, input, kwargs, device, stream) in enumerate(
zip(modules, inputs, kwargs_tup, devices, streams)
)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs
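
A minimal sketch (not part of the diff) calling parallel_apply directly with two independent replicas, assuming two visible CUDA devices; the modules and shapes are illustrative only.

import torch
from torch import nn
from torch.nn.parallel.parallel_apply import parallel_apply

if torch.cuda.device_count() >= 2:
    m0 = nn.Linear(4, 4).to("cuda:0")
    m1 = nn.Linear(4, 4).to("cuda:1")
    x0 = torch.randn(2, 4, device="cuda:0")
    x1 = torch.randn(2, 4, device="cuda:1")
    # one (module, input) pair per device; each pair runs in its own thread and stream
    out0, out1 = parallel_apply([m0, m1], [x0, x1], devices=[0, 1])
    assert out0.device.index == 0 and out1.device.index == 1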

@@ -0,0 +1,212 @@
from collections import OrderedDict
from typing import (
cast,
Dict,
Iterator,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
import torch
from torch._utils import _get_device_index
from torch.nn.modules import Module
from torch.nn.parallel import comm
if TYPE_CHECKING:
from torch.jit import ScriptModule
from torch.jit._state import EnabledProxy
__all__ = ["replicate"]
def _is_script_module(module: Module) -> bool:
import torch.jit
return isinstance(module, torch.jit.ScriptModule)
def _is_script_method(module: Module) -> bool:
import torch.jit
return isinstance(module, torch._C.ScriptMethod)
def _init_script_module() -> "ScriptModule":
import torch.jit
return torch.jit.ScriptModule()
def _is_jit_enabled() -> "EnabledProxy":
import torch.jit._state
return torch.jit._state._enabled
# Check if we can safely replicate the module.
# there are two types of module:
# 1. python modules
# 2. ScriptModule
#
# currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above)
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
# module.modules() contains module itself as the first element
def descendant_modules(module: Module) -> Iterator[Module]:
gen = module.modules()
next(gen)
return gen
if not _is_jit_enabled():
return True
if memo is None:
memo = set()
# memoize visited modules
memo.add(module)
if _is_script_module(module):
memo.update(descendant_modules(module))
return all(
_is_script_module(descendant) for descendant in descendant_modules(module)
)
for child in module.children():
# since any unreplicatable module will cause the check to return
# False early, visited modules here can be safely ignored.
if child in memo:
continue
if not _replicatable_module(child, memo):
return False
return True
def _broadcast_coalesced_reshape(
tensors: Sequence[torch.Tensor],
devices: Sequence[Union[int, torch.device]],
detach: bool = False,
) -> List[List[torch.Tensor]]:
from torch.nn.parallel._functions import Broadcast
if detach:
return comm.broadcast_coalesced(tensors, devices)
else:
# Use the autograd function to broadcast if not detach
if len(tensors) > 0:
tensor_copies = Broadcast.apply(devices, *tensors)
return [
tensor_copies[i : i + len(tensors)]
for i in range(0, len(tensor_copies), len(tensors))
]
else:
return []
T = TypeVar("T", bound=Module)
def replicate(
network: T,
devices: Sequence[Union[int, torch.device]],
detach: bool = False,
) -> List[T]:
if not _replicatable_module(network):
raise RuntimeError(
"Cannot replicate network where python modules are "
"childrens of ScriptModule"
)
if not devices:
return []
devices = [_get_device_index(x, True) for x in devices]
num_replicas = len(devices)
params = list(network.parameters())
param_indices = {param: idx for idx, param in enumerate(params)}
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
buffers = list(network.buffers())
buffers_rg: List[torch.Tensor] = []
buffers_not_rg: List[torch.Tensor] = []
for buf in buffers:
if buf.requires_grad and not detach:
buffers_rg.append(buf)
else:
buffers_not_rg.append(buf)
buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}
buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
buffer_copies_not_rg = _broadcast_coalesced_reshape(
buffers_not_rg, devices, detach=True
)
modules = list(network.modules())
module_copies: List[List[Module]] = [[] for _ in devices]
module_indices: Dict[Module, int] = {}
for i, module in enumerate(modules):
module_indices[module] = i
for j in range(num_replicas):
replica = module._replicate_for_data_parallel()
# This is a temporary fix for DDP. DDP needs to access the
# replicated model parameters. It used to do so through
# `mode.parameters()`. The fix added in #33907 for DP stops the
# `parameters()` API from exposing the replicated parameters.
# Hence, we add a `_former_parameters` dict here to support DDP.
replica._former_parameters = OrderedDict()
module_copies[j].append(replica)
for i, module in enumerate(modules):
for key, child in module._modules.items():
if child is None:
for j in range(num_replicas):
replica = module_copies[j][i]
replica._modules[key] = None
else:
module_idx = module_indices[child]
for j in range(num_replicas):
replica = module_copies[j][i]
setattr(replica, key, module_copies[j][module_idx])
for key, param in module._parameters.items():
if param is None:
for j in range(num_replicas):
replica = module_copies[j][i]
replica._parameters[key] = None
else:
param_idx = param_indices[param]
for j in range(num_replicas):
replica = module_copies[j][i]
param_copy = param_copies[j][param_idx]
# parameters in replicas are no longer leaves,
# so setattr them as non-parameter attributes
setattr(replica, key, param_copy)
# expose the parameter for DDP
replica._former_parameters[key] = param_copy
for key, buf in module._buffers.items(): # type: ignore[assignment]
if buf is None:
for j in range(num_replicas):
replica = module_copies[j][i]
replica._buffers[key] = None
else:
if buf.requires_grad and not detach:
buffer_copies = buffer_copies_rg
buffer_idx = buffer_indices_rg[buf]
else:
buffer_copies = buffer_copies_not_rg
buffer_idx = buffer_indices_not_rg[buf]
for j in range(num_replicas):
replica = module_copies[j][i]
setattr(replica, key, buffer_copies[j][buffer_idx])
return [cast(T, module_copies[j][0]) for j in range(num_replicas)]
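
A minimal sketch (not part of the diff) of replicate(), assuming two visible CUDA devices; the small network is illustrative only.

import torch
from torch import nn
from torch.nn.parallel.replicate import replicate

if torch.cuda.device_count() >= 2:
    net = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).to("cuda:0")
    replicas = replicate(net, [0, 1])
    # the replica on device 0 shares storage with the original parameters,
    # while the replica on device 1 holds broadcast copies
    assert replicas[0][0].weight.device.index == 0
    assert replicas[1][0].weight.device.index == 1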

@@ -0,0 +1,138 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from typing_extensions import deprecated
import torch
from torch.nn.parallel._functions import Gather, Scatter
__all__ = ["scatter", "scatter_kwargs", "gather"]
@deprecated(
"`is_namedtuple` is deprecated, please use the python checks instead",
category=FutureWarning,
)
def is_namedtuple(obj: Any) -> bool:
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
return _is_namedtuple(obj)
def _is_namedtuple(obj: Any) -> bool:
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
return (
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
)
T = TypeVar("T", dict, list, tuple)
# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
@overload
def scatter(
inputs: torch.Tensor,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
...
@overload
def scatter(
inputs: T,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> List[T]:
...
def scatter(inputs, target_gpus, dim=0):
r"""Slice tensors into approximately equal chunks and distributes them across given GPUs.
Duplicates references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
res = scatter_map(inputs)
finally:
scatter_map = None # type: ignore[assignment]
return res
def scatter_kwargs(
inputs: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
target_gpus: Sequence[Union[int, torch.device]],
dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
r"""Scatter with support for kwargs dictionary."""
scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(scattered_inputs) < len(scattered_kwargs):
scattered_inputs.extend(
() for _ in range(len(scattered_kwargs) - len(scattered_inputs))
)
    elif len(scattered_kwargs) < len(scattered_inputs):
scattered_kwargs.extend(
{} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
)
return tuple(scattered_inputs), tuple(scattered_kwargs)
def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
r"""Gather tensors from different GPUs on a specified device.
This function is useful for gathering the results of a distributed computation.
It takes a sequence of objects, one for each GPU, and returns a single object
on the specified device.
Args:
outputs (Any): A sequence of objects (potentially tensors) to gather.
target_device (Union[int, torch.device]): The device to gather the tensors to.
Use 'cpu' for CPU to avoid a deprecation warning.
dim (int, optional): The dimension along which to gather. Default: 0.
Returns:
Any: A gathered object (potentially tensor) on the specified device.
"""
def gather_map(outputs):
out = outputs[0]
if isinstance(out, torch.Tensor):
return Gather.apply(target_device, dim, *outputs)
if out is None:
return None
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
return type(out)._make(map(gather_map, zip(*outputs)))
return type(out)(map(gather_map, zip(*outputs)))
# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.
try:
res = gather_map(outputs)
finally:
gather_map = None # type: ignore[assignment]
return res
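
A minimal sketch (not part of the diff) of scatter_kwargs and gather on a nested input, assuming two visible CUDA devices; the tensors and the "mask" keyword are illustrative only.

import torch
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs

if torch.cuda.device_count() >= 2:
    inputs = (torch.randn(4, 3),)            # a batch of 4 rows on the CPU
    kwargs = {"mask": torch.ones(4, 1)}
    scat_in, scat_kw = scatter_kwargs(inputs, kwargs, [0, 1], dim=0)
    # two shards, each holding a 2-row slice of every tensor
    assert len(scat_in) == 2 and scat_kw[0]["mask"].shape == (2, 1)
    outs = [scat_in[i][0] * scat_kw[i]["mask"] for i in range(2)]
    merged = gather(outs, "cpu", dim=0)      # concatenated back on the CPU
    assert merged.shape == (4, 3)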