I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1 @@
from .join import Join, Joinable, JoinHook

View File

@ -0,0 +1,323 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from enum import auto, Enum
from functools import partial
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu
from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
# Attribute name under which ``ActivationWrapper`` stores the wrapped module.
_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
# Name prefix contributed by that attribute. The wrapper classes in this file
# strip it from parameter names and state_dict keys and add it back on load.
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."
class CheckpointImpl(Enum):
    """Checkpoint flavour used by ``CheckpointWrapper`` when no custom ``checkpoint_fn`` is given."""

    # Selecting this member makes the wrapper pass ``use_reentrant=True``.
    REENTRANT = auto()
    # Selecting this member makes the wrapper pass ``use_reentrant=False``.
    NO_REENTRANT = auto()
class ActivationWrapper(torch.nn.Module, ABC):
    """
    Shared base class for the activation checkpoint and activation offload wrappers.

    The wrapped module is kept in the ``_checkpoint_wrapped_module`` attribute.
    Parameter names and ``state_dict`` keys are exposed without the resulting
    prefix so that they line up with those of the unwrapped module.

    Not meant to be instantiated directly.
    """

    def __init__(self, mod):
        super().__init__()
        self._checkpoint_wrapped_module = mod
        # Outgoing state_dicts: drop the wrapper prefix so the result can be
        # loaded into a module that is not checkpoint-wrapped.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # Incoming state_dicts: put the wrapper prefix back so keys from an
        # unwrapped module resolve against this wrapped module.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise ValueError("Subclasses should implement forward().")

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` on the wrapper first and fall back to the wrapped module."""
        try:
            # nn.Module's own lookup (parameters, buffers, submodules).
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._checkpoint_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Delegate indexing to the wrapped module (e.g. when it is an ``nn.Sequential``)."""
        wrapped = self._checkpoint_wrapped_module
        return wrapped.__getitem__(key)  # type: ignore[operator]

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Yield ``(name, parameter)`` pairs like :meth:`nn.Module.named_parameters`.

        Every occurrence of ``_CHECKPOINT_PREFIX`` is removed from the names.
        """
        for qualified_name, parameter in super().named_parameters(*args, **kwargs):
            clean_name = qualified_name.replace(_CHECKPOINT_PREFIX, "")
            yield clean_name, parameter

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        """
        State-dict hook registered in ``__init__`` via ``_register_state_dict_hook``.

        Rewrites keys starting with ``prefix + _CHECKPOINT_PREFIX`` so that they
        start with ``prefix`` only, which lets the result be loaded into a
        module that is not checkpoint-wrapped. Loading back into a wrapped
        module still works because ``_pre_load_state_dict_hook`` restores the
        prefix before loading.

        Returns:
            The same ``state_dict`` object, modified in place.
        """
        wrapped_prefix = f"{prefix}{_CHECKPOINT_PREFIX}"
        _replace_by_prefix(state_dict, wrapped_prefix, prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        Load-state-dict pre-hook registered in ``__init__``.

        Rewrites keys starting with ``prefix`` so that they start with
        ``prefix + _CHECKPOINT_PREFIX``, allowing a state_dict produced by an
        unwrapped module to be loaded into the wrapped one. ``state_dict`` is
        modified in place.
        """
        wrapped_prefix = prefix + _CHECKPOINT_PREFIX
        _replace_by_prefix(state_dict, prefix, wrapped_prefix)
class OffloadWrapper(ActivationWrapper):
    """Wrapper whose forward pass runs the wrapped module inside ``save_on_cpu(pin_memory=True)``."""

    def __init__(self, mod):
        super().__init__(mod)

    def forward(self, *args, **kwargs):
        # Everything the wrapped module executes happens inside the
        # ``save_on_cpu`` context.
        offload_ctx = save_on_cpu(pin_memory=True)
        with offload_ctx:
            return self._checkpoint_wrapped_module(*args, **kwargs)
class CheckpointWrapper(ActivationWrapper):
    """
    An ``nn.Module`` that runs another ``nn.Module`` through a checkpoint function.

    This class is not meant to be used directly; create instances through the
    ``checkpoint_wrapper`` function instead.
    """

    def __init__(
        self,
        mod: torch.nn.Module,
        checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
        checkpoint_fn=None,
        **checkpoint_fn_kwargs,
    ):
        super().__init__(mod)
        self.checkpoint_impl = checkpoint_impl
        if checkpoint_fn is not None:
            # User-supplied checkpoint function: only bind its extra kwargs.
            self.checkpoint_fn = partial(checkpoint_fn, **checkpoint_fn_kwargs)
        else:
            # Default: torch.utils.checkpoint, with ``use_reentrant`` derived
            # from the requested implementation.
            use_reentrant = self.checkpoint_impl == CheckpointImpl.REENTRANT
            self.checkpoint_fn = partial(
                torch_utils_checkpoint,
                use_reentrant=use_reentrant,
                **checkpoint_fn_kwargs,
            )

    def forward(self, *args, **kwargs):
        # Keyword arguments get special handling for the reentrant
        # implementation: they are flattened into positional arguments before
        # the checkpoint call. This relies on ``self.checkpoint_impl`` being
        # set by the user and on no custom ``checkpoint_fn`` being in use.
        must_pack_kwargs = (
            self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}
        )
        if not must_pack_kwargs:
            return self.checkpoint_fn(  # type: ignore[misc]
                self._checkpoint_wrapped_module, *args, **kwargs
            )

        # Flatten args and kwargs into one positional sequence plus the keys
        # needed to rebuild the kwargs.
        flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)

        def run_with_unpacked_inputs(*inputs):
            # Rebuild the original args/kwargs and call the wrapped module.
            restored_args, restored_kwargs = _unpack_kwargs(inputs, kwarg_keys)
            return self._checkpoint_wrapped_module(
                *restored_args, **restored_kwargs
            )

        # The checkpoint function only ever sees positional arguments.
        return self.checkpoint_fn(  # type: ignore[misc]
            run_with_unpacked_inputs,
            *flat_args,
        )
def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module:
    """
    Wrap a module for activation offloading to CPU.

    Offloads intermediate activations to the CPU for modules wrapped with this function.
    Wrappers with activation offload can be composed with ones that do recomputation-based
    checkpoint to trade off increased compute versus increased CPU
    memory usage and additional H2D transfers.

    Usage::
        offloaded_module = offload_wrapper(module)
        outputs = offloaded_module(inputs)

    Args:
        module (nn.Module):
            The module to be wrapped

    Returns:
        (nn.Module):
            Wrapped module (an ``OffloadWrapper`` instance)
    """
    wrapped = OffloadWrapper(module)
    return wrapped
def checkpoint_wrapper(
    module: torch.nn.Module,
    checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
    checkpoint_fn=None,
    **checkpoint_fn_kwargs,
) -> torch.nn.Module:
    """
    Wrap a module for activation checkpointing.

    Once a module is wrapped with this function, every call to it goes through
    the checkpoint function automatically, without the user having to call
    ``checkpoint`` explicitly.

    Usage::
        checkpointed_module = checkpoint_wrapper(module)
        outputs = checkpointed_module(inputs)

    Args:
        module (nn.Module):
            The module to be wrapped
        checkpoint_impl (Optional[CheckpointImpl]):
            The checkpointing implementation to use. It only influences the
            default ``torch.utils.checkpoint.checkpoint`` implementation and
            is ignored if a custom ``checkpoint_fn`` is specified. Note that
            for implementations using reentrant checkpoint from
            ``torch.utils.checkpoint``, keyword arguments will only be
            supported if ``checkpoint_impl`` is passed as
            ``CheckpointImpl.REENTRANT``.
        checkpoint_fn (Optional[Callable]):
            Functional checkpoint implementation to use. If this is specified,
            it will be used over the default ``torch.utils.checkpoint.checkpoint``
            implementation and the ``checkpoint_impl`` argument will be ignored.
        **checkpoint_fn_kwargs: (Dict[str, Any]):
            Keyword arguments to pass into ``checkpoint_fn``.

    Returns:
        (nn.Module):
            Wrapped module (a ``CheckpointWrapper`` instance)
    """
    wants_reentrant = checkpoint_impl == CheckpointImpl.REENTRANT
    if wants_reentrant:
        message = (
            f"Please specify {CheckpointImpl.NO_REENTRANT} as "
            f"{CheckpointImpl.REENTRANT} will soon be removed as "
            "the default and eventually deprecated."
        )
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(message, FutureWarning, stacklevel=2)

    wrapped = CheckpointWrapper(
        module,
        checkpoint_impl,
        checkpoint_fn,
        **checkpoint_fn_kwargs,
    )
    return wrapped
def apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda _: True,
    auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
):
    """
    Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration.

    For each module within `model`, the `check_fn` is used to decide
    whether `module` should be wrapped with :func:`checkpoint_wrapper` or not.

    Note::
        This function modifies `model` in place and replaces appropriate layers with
        their checkpoint-wrapped modules.

    Note::
        This function will not wrap the overall root module. If this is needed, please directly use
        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.

    Usage::
        model = nn.Sequential(
            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
        )
        check_fn = lambda l: isinstance(l, nn.Linear)
        # checkpoint activations
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
        # Or offload activations to CPU
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)

    Args:
        model (nn.Module):
            The model whose submodules should be wrapped with activation checkpointing.
        checkpoint_wrapper_fn (Optional[Callable[nn.Module]]):
            A ``Callable`` which will wrap modules.
        check_fn (Optional[Callable]):
            A function which will be passed each child submodule of ``model`` and returns
            ``True`` or ``False`` depending on whether the submodule should be wrapped.
        auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]):
            A policy to wrap model's submodules with AC. Note that if this is
            specified, it takes precedence over ``check_fn``.

    Returns: None (`model` is modified inplace)
    """
    # TODO: Importing inside function to avoid circular import issue between FSDP and
    # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code.
    from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply
    from torch.distributed.fsdp.wrap import (
        _Policy,
        _recursive_wrap,
        lambda_auto_wrap_policy,
    )

    # An explicit policy wins; otherwise build one from ``check_fn``.
    if auto_wrap_policy is None:
        policy = partial(lambda_auto_wrap_policy, lambda_fn=check_fn)
    else:
        policy = auto_wrap_policy

    if callable(policy):
        # Callable policies go through the recursive wrapper; only children
        # of ``model`` are wrapped, never ``model`` itself.
        _recursive_wrap(
            module=model,
            auto_wrap_policy=policy,  # type: ignore[arg-type]
            wrapper_cls=checkpoint_wrapper_fn,
            ignored_modules=set(),
            ignored_params=set(),
            only_wrap_children=True,
        )
        return

    # Non-callable policies must be pre-defined ``_Policy`` objects.
    if not isinstance(policy, _Policy):
        raise ValueError(
            f"Expected {policy} to be callable or be a pre-defined wrap policy"
        )
    target_module_to_kwargs = policy._run_policy(
        model, ignored_modules=set(), root_kwargs={}
    )
    wrap_fn = _construct_wrap_fn(
        model, target_module_to_kwargs, checkpoint_wrapper_fn
    )
    _post_order_apply(model, wrap_fn)

View File

@ -0,0 +1,7 @@
from . import default_hooks as default
# Hooks from ``default_hooks`` that are grouped as the low-precision ones.
LOW_PRECISION_HOOKS = [
    default.fp16_compress_hook,
    default.bf16_compress_hook,
]

View File

@ -0,0 +1,192 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional
import torch
import torch.distributed as dist
class DefaultState:
    r"""
    Holds the state a communication hook needs for the default communication algorithm.

    Besides the process group and its world size, two division factors are
    stored whose product equals the world size.

    Args:
        process_group (ProcessGroup): The process group to be used.

    Raises:
        ValueError: if ``process_group`` is ``None``.
    """

    __slots__ = [
        "process_group",
        "world_size",
        "gradient_predivide_factor",
        "gradient_postdivide_factor",
    ]

    def __init__(self, process_group: dist.ProcessGroup):
        if process_group is None:
            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        # Division by world_size is split into a pre- and a post-division
        # step to avoid underflow and overflow.
        predivide = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_predivide_factor = predivide
        self.gradient_postdivide_factor = self.world_size / predivide

    @staticmethod
    def _get_gradient_predivide_factor(world_size: int) -> float:
        """
        Return the pre-division factor for ``world_size`` as a float.

        Starting from 1, the factor is doubled for as long as it divides
        ``world_size`` and ``world_size / factor`` is still larger than it.
        """
        candidate: int = 1
        while True:
            keeps_growing = (
                world_size % candidate == 0 and world_size / candidate > candidate
            )
            if not keeps_growing:
                break
            candidate *= 2
        return float(candidate)
class LowPrecisionState(DefaultState):
    r"""
    State for communication hooks that communicate gradients in a lower precision.

    Extends :class:`DefaultState` with ``parameter_type``, the dtype that
    gradients are cast back to after communication (default: ``torch.float32``).

    Args:
        process_group (ProcessGroup): Forwarded to :class:`DefaultState`.
        parameter_type (torch.dtype): The precision of model's parameters.
            Required for a hook to cast gradients back to a parameter's precision.
    """

    __slots__ = [
        "parameter_type",
    ]

    def __init__(
        self,
        process_group,
        parameter_type=torch.float32,
    ):
        # The base class validates the process group and sets the factors.
        super().__init__(process_group)
        self.parameter_type = parameter_type
def _decompress(state: LowPrecisionState, grad: torch.Tensor):
    """
    Cast ``grad`` back to ``state.parameter_type`` in place.

    After the cast, the previous (low precision) data is recorded on the
    device backend's current stream.

    Raises:
        AttributeError: if ``torch`` has no attribute named after the
            gradient's device type.
    """
    low_precision_data = grad.data
    grad.data = low_precision_data.to(state.parameter_type)
    try:
        # "privateuse1" devices are looked up under their registered name.
        if grad.device.type == "privateuse1":
            backend_name = torch._C._get_privateuse1_backend_name()
        else:
            backend_name = grad.device.type
        backend = getattr(torch, backend_name)
    except AttributeError as e:
        raise AttributeError(
            f"Device {grad.device} does not have a "
            "corresponding backend registered as 'torch.device_type'."
        ) from e

    # Don't let this memory get reused until after the transfer.
    current_stream = backend.current_stream()
    low_precision_data.record_stream(current_stream)  # type: ignore[arg-type]
def allreduce_hook(state: DefaultState, grad: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``all_reduce`` algorithm.

    ``grad`` is divided in place by the pre-division factor, all-reduced over
    ``state.process_group`` and then divided in place by the post-division
    factor. Each division is skipped when its factor is not greater than 1.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
    """
    # Together, pre- and post-division average by world_size, as required for
    # consistency with PyTorch DDP. Doing it in two steps avoids potential
    # underflow and overflow.
    predivide = state.gradient_predivide_factor
    if predivide > 1:
        grad.div_(predivide)

    dist.all_reduce(grad, group=state.process_group)

    postdivide = state.gradient_postdivide_factor
    if postdivide > 1:
        grad.div_(postdivide)
def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``reduce_scatter`` algorithm.

    Used for sharded FSDP strategies. ``grad`` is divided in place by the
    pre-division factor, reduce-scattered into ``output`` over
    ``state.process_group``, and ``output`` is then divided in place by the
    post-division factor. Each division is skipped when its factor is not
    greater than 1.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
            communicated across ranks.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    predivide = state.gradient_predivide_factor
    if predivide > 1:
        grad.div_(predivide)

    dist.reduce_scatter_tensor(output, grad, group=state.process_group)

    # The post-division applies to the received shard, not to ``grad``.
    postdivide = state.gradient_postdivide_factor
    if postdivide > 1:
        output.div_(postdivide)
def _low_precision_hook(
    prec: torch.dtype,
    state: LowPrecisionState,
    grad: torch.Tensor,
    output: torch.Tensor,
):
    """
    Communicate ``grad`` in dtype ``prec`` and cast the result back afterwards.

    With ``output`` given, ``reduce_scatter_hook`` is used and ``output`` is
    decompressed; with ``output`` set to ``None``, ``allreduce_hook`` is used
    and ``grad`` is decompressed.
    """
    if grad.dtype != prec:
        grad.data = grad.data.to(prec)

    if output is None:
        allreduce_hook(state, grad)
        _decompress(state, grad)
        return

    if output.dtype != prec:
        output.data = output.data.to(prec)
    reduce_scatter_hook(state, grad, output)
    _decompress(state, output)
def fp16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement FSDP communication hook for a simple gradient compression approach.

    Casts ``grad`` to half-precision floating-point format (``torch.float16``).
    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by a ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Delegate to the shared implementation with the dtype fixed to float16.
    return _low_precision_hook(torch.float16, state, grad, output)
def bf16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement FSDP communication hook for a simple gradient compression approach.

    Casts ``grad`` to the ``torch.bfloat16`` floating-point format.
    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by a ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Delegate to the shared implementation with the dtype fixed to bfloat16.
    return _low_precision_hook(torch.bfloat16, state, grad, output)

View File

@ -0,0 +1 @@
from .optimizer_overlap import _as_overlapped_optim

View File

@ -0,0 +1,97 @@
# mypy: allow-untyped-defs
import inspect
from abc import ABC, abstractmethod
from typing import Dict, Type
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
_hook_then_optimizer,
_OptimizerHookState,
)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.optim import as_functional_optim
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}


def register_overlapped(optim_cls):
    """
    Return a class decorator that registers an overlapped optimizer for ``optim_cls``.

    The decorated class is stored in ``_registered_overlapped_optims`` under
    the key ``optim_cls`` and returned unchanged.

    Args:
        optim_cls: the regular optimizer type the decorated class handles.

    Raises:
        ValueError: (from the decorator) if ``optim_cls`` already has an
            overlapped optimizer registered.
    """

    def decorator(target_overlapped_optim_cls):
        # The registry is keyed by ``optim_cls``, so that is the key that has
        # to be checked to detect a duplicate registration.
        if optim_cls in _registered_overlapped_optims:
            raise ValueError(
                f"{optim_cls} already registered with overlapped optimizer "
                f"{_registered_overlapped_optims[optim_cls]}, trying to "
                f"re-register it for {target_overlapped_optim_cls} is not supported."
            )
        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
        return target_overlapped_optim_cls

    return decorator
class OverlappedOptimizer(ABC):
    """
    Base class describing how an optimizer type registers itself for overlap.

    Child classes implement ``register_ddp`` and ``register_fsdp`` to specify
    how different optimizers will register themselves with DDP and FSDP.
    """

    def __init__(self, optim_cls: Type) -> None:
        """Remember ``optim_cls``, the optimizer type this instance stands for."""
        self.optim_cls = optim_cls

    @abstractmethod
    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Registers the overlapped optimizer with DDP."""
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} does not support overlapped DDP.")

    @abstractmethod
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Registers the overlapped optimizer with FSDP."""
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} does not support overlapped FSDP.")
@register_overlapped(Optimizer)
class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """Overlaps a regular ``Optimizer``."""

    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
        super().__init__(optim_cls)
        # Build the functional counterpart of ``optim_cls`` and bundle it with
        # ``params`` into the state consumed by the fused hook.
        functional_optim = as_functional_optim(
            self.optim_cls, *optim_args, **optim_kwargs
        )
        self._opt_hook_state = _OptimizerHookState(functional_optim, params)

    def register_ddp(self, ddp_inst: DistributedDataParallel):
        """Register an allreduce-then-optimizer communication hook on ``ddp_inst``."""
        # NOTE: using a custom communication hook and fused optimizer is not
        # yet supported.
        fused_hook = _hook_then_optimizer(allreduce_hook, self._opt_hook_state)
        ddp_inst.register_comm_hook(  # type: ignore[operator]
            None,  # wrapped hook state
            fused_hook,
        )

    # TODO: register_fsdp once FSDP supports communication hook.
    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Register the overlapped optimizer with FSDP."""
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} does not support overlapped FSDP.")
def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
    """
    Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.

    The method resolution order of ``optim_cls`` is walked and the first class
    found in ``_registered_overlapped_optims`` decides which overlapped
    optimizer type is instantiated.

    Args:
        optim_cls: the regular optimizer type to overlap.
        params: forwarded to the overlapped optimizer's constructor.
        *args, **kwargs: forwarded to the overlapped optimizer's constructor.
    """
    for clz in inspect.getmro(optim_cls):
        # Only the registry lookup is guarded. An exception raised by the
        # constructor itself (including a KeyError) propagates to the caller
        # instead of being mistaken for "not registered".
        if clz in _registered_overlapped_optims:
            overlapped_cls = _registered_overlapped_optims[clz]
            return overlapped_cls(optim_cls, params, *args, **kwargs)

    # Fallback to standard overlapped optimizer, which will raise errors if user
    # is attempting to use an unsupported optimizer.
    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)

View File

@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
import functools
from enum import Enum
import torch
import torch.distributed as dist
# Smallest and largest finite values of torch.float16, used as clamp bounds
# before converting tensors to half precision.
TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max
class DQuantType(Enum):
    """
    Different quantization methods for auto_quantize API are identified here.

    auto_quantize API currently supports fp16 and bfp16 methods.
    """

    # Both values are plain strings. A trailing comma here would turn the
    # value into a tuple and make ``str(member)`` raise TypeError.
    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        """Return the member's string value (e.g. ``"fp16"``)."""
        return self.value
def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp ``tensor`` to ``[TORCH_HALF_MIN, TORCH_HALF_MAX]`` and convert it to half precision."""
    clamped = torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX)
    return clamped.half()
def _quantize_tensor(tensor, qtype):
    """
    Quantize ``tensor`` according to ``qtype``.

    Args:
        tensor (torch.Tensor): the tensor to quantize.
        qtype (DQuantType): ``FP16`` clamps and casts to half precision;
            ``BFP16`` calls ``torch.ops.quantization._FloatToBfloat16Quantized``.

    Raises:
        RuntimeError: if ``tensor`` is not a ``torch.Tensor`` or ``qtype`` is
            not supported.
    """
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    if qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    raise RuntimeError(f"Quantization type {qtype} is not supported")
def _quantize_tensor_list(tensor_list, qtype):
    """
    Quantize every tensor in ``tensor_list`` with ``_quantize_tensor``.

    Returns:
        A new list with the quantized tensors, in the same order.

    Raises:
        RuntimeError: if ``tensor_list`` is not a list consisting only of
            ``torch.Tensor`` objects.
    """
    is_tensor_list = isinstance(tensor_list, list) and all(
        isinstance(item, torch.Tensor) for item in tensor_list
    )
    if not is_tensor_list:
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    return [_quantize_tensor(item, qtype) for item in tensor_list]
def _dequantize_tensor(tensor, qtype, quant_loss=None):
    """
    Convert a quantized ``tensor`` back to float according to ``qtype``.

    Args:
        tensor (torch.Tensor): quantized tensor; must have dtype ``torch.float16``.
        qtype (DQuantType): quantization method the tensor was produced with.
        quant_loss (float, optional): for ``FP16`` only, the float result is
            divided by this value when it is not ``None``.

    Raises:
        RuntimeError: if ``tensor`` is not a ``torch.Tensor``, has an
            unexpected dtype, or ``qtype`` is not supported.
    """
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )

    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        restored = tensor.float()
        return restored if quant_loss is None else restored / quant_loss

    if qtype == DQuantType.BFP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)

    raise RuntimeError(f"Quantization type {qtype} is not supported")
def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    """
    Dequantize every tensor in ``tensor_list`` with ``_dequantize_tensor``.

    Args:
        tensor_list (list of torch.Tensor): quantized tensors.
        qtype (DQuantType): quantization method the tensors were produced with.
        quant_loss (float, optional): forwarded to ``_dequantize_tensor`` for
            every element.

    Returns:
        A new list with the dequantized tensors, in the same order.

    Raises:
        RuntimeError: if ``tensor_list`` is not a list consisting only of
            ``torch.Tensor`` objects.
    """
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    # ``quant_loss`` must be passed through; otherwise the value supplied by
    # the caller is silently ignored for list inputs.
    return [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
def auto_quantize(func, qtype, quant_loss=None):
    """
    Wrap a collective so that its tensors are quantized before and dequantized after communication.

    The returned wrapper quantizes the input and output tensors, runs the
    collective on the quantized tensors and writes the dequantized results
    back into the output argument (``args[0]``). The wrapper itself returns
    ``None``.

    Currently it only supports:
        . FP16 and BFP16 quantization method supported for gloo and nccl backends
        . all_gather, all_to_all and all_to_all_single collective ops
    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): A function representing collective operations.
        qtype (DQuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.

    Returns:
        (Callable): the same collective as func but enables automatic quantization/dequantization.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get("group", None)
        async_op = kwargs.get("async_op", False)
        if async_op is True:
            raise RuntimeError("The async_op=True mode is not supported yet.")

        def write_back(destination, dequantized):
            # Copy the dequantized results into the caller's output container.
            for position, restored in enumerate(dequantized):
                destination[position] = restored

        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            write_back(
                tensors,
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss),
            )
        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            write_back(
                tensors,
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss),
            )
        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get("out_splits", None)
            in_splits = kwargs.get("in_splits", None)
            # Here input and output are single tensors rather than lists.
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(
                out_tensors, input_tensors, out_splits, in_splits, group=group
            )
            # Iterating the dequantized tensor goes over its first dimension.
            write_back(
                tensors,
                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss),
            )
        else:
            raise RuntimeError(f"The collective op {func} is not supported yet")

    return wrapper

View File

@ -0,0 +1,110 @@
# mypy: allow-untyped-defs
from enum import Enum
from functools import partial
import torch.distributed as dist
from . import (
debugging_hooks as debugging,
default_hooks as default,
optimizer_overlap_hooks as optimizer_overlap,
powerSGD_hook as powerSGD,
quantization_hooks as quantization,
)
__all__ = ["DDPCommHookType", "register_ddp_comm_hook"]
def _ddp_comm_hook_wrapper(comm_hook, model, state):
    """Register ``comm_hook`` together with ``state`` on ``model``."""
    register = model.register_comm_hook
    register(state, comm_hook)
def _powerSGD_comm_hook_wrapper(
    comm_hook,
    model,
    state,
    matrix_approximation_rank,
    start_powerSGD_iter=1_000,
):
    """
    Wrap PowerSGD communication hook.

    To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
    which will be wrapped up with other state info.

    Args:
        comm_hook: the PowerSGD hook to register.
        model: object whose ``register_comm_hook`` is called.
        state: the process group handed to ``PowerSGDState``.
        matrix_approximation_rank: forwarded to ``PowerSGDState``.
        start_powerSGD_iter: forwarded to ``PowerSGDState`` (default 1000).
    """
    state_kwargs = {
        "process_group": state,
        "matrix_approximation_rank": matrix_approximation_rank,
        "start_powerSGD_iter": start_powerSGD_iter,
    }
    powerSGD_state = powerSGD.PowerSGDState(**state_kwargs)
    model.register_comm_hook(powerSGD_state, comm_hook)
class DDPCommHookType(Enum):
    """
    Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communication hook types.

    DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
    as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example,
    you can register allreduce hook by
    ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
    """

    # Every value is a partial with ``comm_hook`` already bound; calling it
    # with ``model`` and ``state`` performs the registration.
    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
    FP16_COMPRESS = partial(
        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
    )
    BF16_COMPRESS = partial(
        _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook
    )
    QUANTIZE_PER_TENSOR = partial(
        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
    )
    QUANTIZE_PER_CHANNEL = partial(
        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
    )
    # The PowerSGD variants additionally bind ``matrix_approximation_rank``.
    POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=1,
    )
    # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
    # but it runs slower and consumes more memory.
    POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=2,
    )
    # Batching can lead to a faster training at the cost of accuracy.
    BATCHED_POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=1,
    )
    BATCHED_POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=2,
    )
    NOOP = partial(
        _ddp_comm_hook_wrapper,
        comm_hook=debugging.noop_hook,
    )
def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None):
    """
    Register ``ddp_comm_hooks`` to DDP model.

    Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
    to the DDP model. The hook is selected through ``comm_hook_type``, a member
    of the ``DDPCommHookType`` enum, and ``state`` is passed on to the model.
    Uses Python comm hook implementations.

    Example::
        >>> # xdoctest: +SKIP
        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
    """
    # The enum value is the registration callable with the hook pre-bound.
    register = comm_hook_type.value
    register(model=model, state=state)

View File

@ -0,0 +1,458 @@
# mypy: allow-untyped-defs
import weakref
from typing import Any, Callable, List, Optional
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
from torch.nn.parallel.distributed import DistributedDataParallel
__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"]
# Functional optimizers require passing a list of gradients to their `step()`
# method, and ZeRO requires a functional optimizer to overlap with DDP
# Passing a `None` instead of an actual gradient indicates to the optimizer
# to not update the corresponding parameter
_NO_PARAM_UPDATE: None = None


def _perform_local_step(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
    rank: int,
):
    r"""
    Perform a local optimizer step using the gradients provided by ``bucket``.

    Arguments:
        bucket (dist.GradBucket): the bucket providing the gradients.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to perform the :meth:`_local_step`.
        rank (int): the calling process's rank.

    .. warning::
        This function assumes that appropriate synchronization has taken place
        so that the bucket's gradients can be used.
    """
    info = zero._overlap_info
    index = bucket.index()
    param_groups = zero.optim.param_groups
    assert (
        len(param_groups) == 1
    ), "Overlapping DDP with ZeRO only supports a single parameter group"

    # The local optimizer step takes one entry per local parameter; an entry
    # of `None` tells it to leave the corresponding parameter untouched.
    local_param_count = len(param_groups[0]["params"])
    gradients: List[Optional[torch.Tensor]] = [
        _NO_PARAM_UPDATE
    ] * local_param_count

    assert (
        index in info.offsets
    ), f"Bucket index {index} was not assigned to rank {rank}"
    first_slot = info.offsets[index]

    # Pick out the slice of the bucket's gradients that belongs to this rank.
    assignment = zero._bucket_assignments_per_rank[rank][index]
    start = assignment.offset
    stop = start + len(assignment.parameters)
    for slot, grad in enumerate(bucket.gradients()[start:stop], start=first_slot):
        gradients[slot] = grad

    zero._local_step(gradients)
def _broadcast_bucket(
    bucket_index: int,
    zero: ZeroRedundancyOptimizer,
):
    r"""
    Broadcasts a bucket's parameters.

    One asynchronous broadcast is issued for every rank assigned to the bucket,
    and the resulting handle is appended to ``broadcast_handles`` of the
    optimizer's overlap info.

    Arguments:
        bucket_index (int): the index of the bucket corresponding to the
            parameters to broadcast.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
    """
    info = zero._overlap_info
    ranks_per_bucket = info.assigned_ranks_per_bucket
    assert (
        len(ranks_per_bucket) > bucket_index
    ), "`assigned_ranks_per_bucket` is not fully constructed"

    # Sort to ensure the same ordering across ranks
    source_ranks = sorted(ranks_per_bucket[bucket_index])
    assert len(source_ranks) > 0, (
        f"Bucket {bucket_index} should be " "assigned to at least one rank"
    )

    for source_rank in source_ranks:
        assignments = zero._bucket_assignments_per_rank[source_rank]
        if bucket_index not in assignments:
            continue
        handle = dist.broadcast(
            assignments[bucket_index].tensor,
            src=dist.get_global_rank(zero.process_group, source_rank),
            group=zero.process_group,
            async_op=True,
        )
        info.broadcast_handles.append(handle)
def _save_ddp_bucket_info(
bucket: dist.GradBucket,
zero: ZeroRedundancyOptimizer,
):
r"""
Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
In particular, this function is meant to be called upon seeing each
gradient bucket to use when overlapping, meaning it does not save or compute any global
information.
Arguments:
bucket (dist.GradBucket): the current gradient bucket.
zero (ZeroRedundancyOptimizer): the calling process's
:class:`ZeroRedundancyOptimizer` instance.
"""
overlap_info = zero._overlap_info
bucket_params = bucket.parameters()
assert len(bucket_params) > 0, "Empty bucket"
# Save the parameters in the bucket
overlap_info.params_per_bucket.append(bucket_params)
if overlap_info.shard_buckets:
# Additionally save the bucket size for the assignment heuristic to use
bucket_size = 0
for param in bucket_params:
bucket_size += param.numel()
assert overlap_info.total_size is not None
overlap_info.total_size += bucket_size
def _hook_with_zero_step_setup(
ddp_ref: weakref.ReferenceType,
zero: ZeroRedundancyOptimizer,
bucket: dist.GradBucket,
):
r"""
Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
This means the logic to run in the
hook before the backward pass and optimizer step can actually be
overlapped. This is factored out since it is common to both
:func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
Arguments:
ddp_ref (weakref.ReferenceType): weak reference to the process's
:class:`DistributedDataParallel` instance.
zero (ZeroRedundancyOptimizer): the calling process's
:class:`ZeroRedundancyOptimizer` instance.
bucket (dist.GradBucket): the current gradient bucket.
"""
# Proceed as normal until the DDP buckets have been rebuilt
if not ddp_ref()._has_rebuilt_buckets: # type: ignore[union-attr]
assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
return
bucket_index = bucket.index()
overlap_info = zero._overlap_info
if overlap_info.status == _OverlapStatus.UNINITIALIZED:
overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0:
# This corresponds to the first bucket of the backward pass
# immediately after all information has been saved, so we
# can perform the delayed ZeRO initialization
zero._init_zero_for_overlap()
else:
# Once DDP buckets have been rebuilt but ZeRO has not been
# properly initialized yet, save the information needed
_save_ddp_bucket_info(bucket, zero)
def hook_with_zero_step(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
    ddp: DistributedDataParallel,
    zero: ZeroRedundancyOptimizer,
    shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.

    This approach overlaps the optimizer computation and communication with the
    backward communication. In particular, the backward computation proceeds
    contiguously, and the optimizer computation follows, overlapping with
    outstanding backward communication (i.e. all-reduces) and possibly other
    optimizer communication (i.e. broadcasts).
    The optimizer step computation begins after the last gradient bucket computation has finished.

    This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
    if communication is relatively slow compared to computation.

    Arguments:
        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
            to modify.
        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
            instance to use.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to use.
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity; if
            ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).

    Returns:
        The modified hook.

    Raises:
        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
        RuntimeError: if using any backend other than NCCL/HCCL since currently
            Gloo may hang.

    .. warning::
        Given the way that overlapping :class:`DistributedDataParallel` with
        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
        two or three training iterations do not perform parameter updates in
        the optimizer step, depending on if ``static_graph=False`` or
        ``static_graph=True``, respectively. This is because it needs
        information about the gradient bucketing strategy used by
        :class:`DistributedDataParallel`, which is not finalized until the
        second forward pass if ``static_graph=False`` or until the third
        forward pass if ``static_graph=True``.
    """
    if not zero._overlap_with_ddp:
        raise ValueError(
            "ZeroRedundancyOptimizer must be constructed with "
            "`overlap_with_ddp=True` to use this hook properly"
        )
    # Only a weak reference to ``ddp`` is captured by the returned closure
    ddp_ref = weakref.ref(ddp)

    # NOTE: Gloo may hang with this overlapping approach, so we require
    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
    # ``pg`` holds the backend of the DDP process group, not the group itself
    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
    if (pg != dist.Backend.NCCL) and (pg != "hccl"):
        raise RuntimeError(
            "Overlapping DDP with ZeRO using this approach currently requires "
            "NCCL/HCCL backend to avoid hangs"
        )

    if shard_buckets:
        # Start the running total that `_save_ddp_bucket_info` accumulates into
        zero._overlap_info.shard_buckets = True
        zero._overlap_info.total_size = 0

    def hook_with_zero_fn(
        state: Any,
        bucket: dist.GradBucket,
    ) -> torch.futures.Future[torch.Tensor]:
        r"""
        Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket.

        Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket.
        The returned future is the one produced by the wrapped ``hook``. Until
        the overlap status is ``INITIALIZED``, this only runs the setup logic
        that collects the bucket information used to implement the modified
        hook.

        Arguments:
            state (Any): any state for the hook.
            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
                gradient bucket.
        """
        fut = hook(state, bucket)
        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
            # Still in the setup iterations: behave exactly like ``hook``
            return fut

        overlap_info = zero._overlap_info
        bucket_index = bucket.index()
        rank = zero.global_rank

        assert overlap_info.status == _OverlapStatus.INITIALIZED
        assert (
            len(overlap_info.assigned_ranks_per_bucket) > bucket_index
        ), "`assigned_ranks_per_bucket` is not fully constructed"
        assigned_to_bucket = (
            rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
        )

        # Save the bucket reference and all-reduce future for the final bucket
        if assigned_to_bucket:
            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
            overlap_info.bucket_index_to_future[bucket_index] = fut

        # Check that buckets are indexed incrementally starting from 0 in the
        # order of their autograd hooks firing
        if len(overlap_info.bucket_indices_seen) > 0:
            assert (
                overlap_info.bucket_indices_seen[-1] == bucket_index - 1
            ), "Bucket indices are not in incremental order"
        else:
            assert bucket_index == 0, "Bucket indices do not start from 0"
        overlap_info.bucket_indices_seen.append(bucket_index)

        # Directly return the future without any optimizer computation if this
        # is not the last bucket
        num_buckets = len(overlap_info.params_per_bucket)
        is_last_bucket = bucket_index == num_buckets - 1
        if not is_last_bucket:
            return fut

        # Perform partial optimizer step on all buckets after the final
        # bucket has been computed
        # NOTE: This should not be chained as a callback to the last bucket's
        # all-reduce future since that would add synchronization that delays
        # all optimizer computation to wait for that last all-reduce
        # NOTE: the loop variable below rebinds ``bucket_index`` from above
        for bucket_index in range(num_buckets):
            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
            if rank in assigned_ranks:
                # Wait on the bucket's all-reduce future to ensure correct
                # gradients
                assert bucket_index in overlap_info.bucket_index_to_future, (
                    f"All-reduce future for bucket {bucket_index} not saved "
                    f"on rank {rank}"
                )
                allreduce_future = overlap_info.bucket_index_to_future[bucket_index]
                allreduce_future.wait()

                # Perform the partial optimizer step
                curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
                _perform_local_step(curr_bucket, zero, rank)

            # Called on every rank, whether or not it is assigned the bucket
            _broadcast_bucket(bucket_index, zero)

        # Ensure that all parameter updates are finished before the
        # next forward pass
        overlap_info.wait_for_broadcasts()
        overlap_info.clear_per_iter_info()

        return fut

    return hook_with_zero_fn
def hook_with_zero_step_interleaved(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
    ddp: DistributedDataParallel,
    zero: ZeroRedundancyOptimizer,
    shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.

    This approach overlaps the optimizer computation and communication with the
    backward computation and communication. In particular, once a bucket's
    gradients have been computed, the optimizer computation using those
    gradients is launched (though the actual computation must wait for the
    bucket's all-reduce to complete). This yields an interleaving of all-
    reduces and broadcasts in the communication stream.

    This approach may be preferred over :meth:`hook_with_zero_step` if
    communication is relatively fast compared to computation.

    Arguments:
        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
            to modify.
        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
            instance to use.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to use.
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity; if
            ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).

    Returns:
        The modified hook.

    Raises:
        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
        RuntimeError: if using any backend other than NCCL/HCCL since currently
            Gloo may hang.

    .. warning::
        Given the way that overlapping :class:`DistributedDataParallel` with
        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
        two or three training iterations do not perform parameter updates in
        the optimizer step, depending on if ``static_graph=False`` or
        ``static_graph=True``, respectively. This is because it needs
        information about the gradient bucketing strategy used by
        :class:`DistributedDataParallel`, which is not finalized until the
        second forward pass if ``static_graph=False`` or until the third
        forward pass if ``static_graph=True``.
    """
    if not zero._overlap_with_ddp:
        raise ValueError(
            "ZeroRedundancyOptimizer must be constructed with "
            "`overlap_with_ddp=True` to use this hook properly"
        )

    ddp_ref = weakref.ref(ddp)

    # NOTE: Gloo may hang with this overlapping approach, so we require
    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
    backend_name = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
    backend_supported = (backend_name == dist.Backend.NCCL) or (backend_name == "hccl")
    if not backend_supported:
        raise RuntimeError(
            "Overlapping DDP with ZeRO using this approach currently requires "
            "NCCL/HCCL backend to avoid hangs"
        )

    if shard_buckets:
        zero._overlap_info.shard_buckets = True
        zero._overlap_info.total_size = 0

    def hook_with_zero_interleaved_fn(
        state,
        bucket: dist.GradBucket,
    ) -> torch.futures.Future[torch.Tensor]:
        r"""
        Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`.

        The partial step uses the gradients in the given bucket and is chained
        onto the future returned by the wrapped ``hook``.

        Arguments:
            state: any state for the hook.
            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
                gradient bucket.
        """
        hook_fut = hook(state, bucket)
        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
            # Setup iterations: hand back the wrapped hook's future untouched
            return hook_fut

        def zero_step(fut: torch.futures.Future) -> torch.Tensor:
            r"""
            Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel` bucket.

            Returns:
                A :class:`torch.Tensor` representing the contents of the
                gradient bucket.
            """
            info = zero._overlap_info
            index = bucket.index()
            this_rank = zero.global_rank

            ranks_for_bucket = info.assigned_ranks_per_bucket[index]
            info.bucket_indices_seen.append(index)
            if this_rank in ranks_for_bucket:
                _perform_local_step(bucket, zero, this_rank)

            _broadcast_bucket(index, zero)

            total_buckets = len(info.params_per_bucket)
            if len(info.bucket_indices_seen) == total_buckets:
                # All buckets processed: make sure every parameter update
                # has landed before the next forward pass
                info.wait_for_broadcasts()
                info.clear_per_iter_info()

            return bucket.buffer()

        return hook_fut.then(zero_step)

    return hook_with_zero_interleaved_fn

View File

@ -0,0 +1,29 @@
from typing import Any
import torch
from torch.distributed import GradBucket
__all__ = ["noop_hook"]
def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    Return an already-completed future holding the bucket buffer, performing no communication.

    This hook should **only** be used for headroom analysis of allreduce optimization,
    instead of the normal gradient synchronization.
    For example, if only less than 10% speedup of training time can be observed after this hook is registered,
    it usually implies that allreduce is not a performance bottleneck for this case.
    Such instrumentation can be particularly useful
    if GPU traces cannot be easily retrieved or the trace analysis is complicated
    by factors such as the overlap between allreduce and computation or the desynchronization across ranks.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(None, noop_hook)
    """
    flat_grads = bucket.buffer()
    completed: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    completed.set_result(flat_grads)
    return completed

View File

@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, cast, Tuple
import torch
import torch.distributed as dist
__all__ = [
"allreduce_hook",
"fp16_compress_hook",
"bf16_compress_hook",
"fp16_compress_wrapper",
"bf16_compress_wrapper",
]
def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """
    Average ``tensor`` across the group with an async all-reduce and return a future.

    ``tensor`` is divided in place by the group size and then all-reduced.
    If ``process_group`` is ``None``, ``dist.group.WORLD`` is used. The
    returned future resolves to the first element of the all-reduce result.
    """
    if process_group is None:
        group_to_use = dist.group.WORLD
    else:
        group_to_use = process_group

    # Divide before reducing to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())

    work = dist.all_reduce(tensor, group=group_to_use, async_op=True)
    reduce_fut = work.get_future()
    return reduce_fut.then(lambda done: done.value()[0])
def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Call ``allreduce`` using ``GradBucket`` tensors.

    The bucket's flat buffer is handed to ``_allreduce_fut``, which divides it
    by the group size and all-reduces it asynchronously, so the resulting
    future carries the averaged gradients.

    If user registers this DDP communication hook,
    DDP results is expected to be same as the case where no hook was registered.
    Hence, this won't change behavior of DDP and user can use this as a reference
    or modify this hook to log useful information or any other purposes while
    unaffecting DDP behavior.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    flat_grads = bucket.buffer()
    return _allreduce_fut(process_group, flat_grads)
def _compress_hook(
    dtype: torch.dtype,
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Shared implementation of ``fp16_compress_hook`` and ``bf16_compress_hook``.

    Casts the bucket buffer to ``dtype``, divides it by the process group
    size, all-reduces the compressed tensor, and copies the result back into
    the original buffer.

    Args:
        dtype (torch.dtype): reduced-precision dtype used for communication.
        process_group (dist.ProcessGroup): group to all-reduce over;
            ``dist.group.WORLD`` is used when this is ``None``.
        bucket (dist.GradBucket): gradient bucket. A tuple is also accepted,
            in which case its first element is used as the buffer.

    Returns:
        When ``torch._utils.is_compiling()`` is true, the decompressed buffer
        itself; otherwise a future that resolves to the decompressed buffer.
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(dtype).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        # ``fut`` is a plain tensor on the compiled path and a future otherwise.
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size.

    This DDP communication hook implements a simple gradient compression
    approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
    and then divides it by the process group size.
    It allreduces those ``float16`` gradient tensors. Once compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    return _compress_hook(torch.float16, process_group, bucket)


def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This DDP communication hook implements a simple gradient compression
    approach that casts ``GradBucket`` tensor to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``)
    and then divides it by the process group size.
    It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
    """
    return _compress_hook(torch.bfloat16, process_group, bucket)
def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Wrap ``hook`` so that it operates on a ``torch.float16`` copy of the bucket buffer.

    The returned hook replaces the bucket buffer with its ``torch.float16``
    cast, runs ``hook``, and then copies the hook's result into the buffer
    returned by ``bucket.buffer()``, which becomes the value of the final
    future.
    NOTE(review): the intent is that the result ends up back in the input
    data type (such as ``float32``); that depends on what ``bucket.buffer()``
    returns after ``set_buffer`` -- confirm against ``GradBucket``.

    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
    """

    def fp16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Hand the wrapped hook an FP16 view of the gradients.
        half_buffer = bucket.buffer().to(torch.float16)
        bucket.set_buffer(half_buffer)

        inner_fut = hook(hook_state, bucket)

        def decompress(done):
            target = bucket.buffer()
            # Copy in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            target.copy_(done.value())
            return target

        # Runs once the wrapped hook's future has completed.
        return inner_fut.then(decompress)

    return fp16_compress_wrapper_hook
def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    Wrap ``hook`` so that it operates on a copy of the bucket buffer cast to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``).
    The returned hook replaces the bucket buffer with its ``torch.bfloat16``
    cast, runs ``hook``, and then copies the hook's result into the buffer
    returned by ``bucket.buffer()``, which becomes the value of the final
    future.
    NOTE(review): the intent is that the result ends up back in the input
    data type (such as ``float32``); that depends on what ``bucket.buffer()``
    returns after ``set_buffer`` -- confirm against ``GradBucket``.

    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
    """

    def bf16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Hand the wrapped hook a BF16 view of the gradients.
        bf16_buffer = bucket.buffer().to(torch.bfloat16)
        bucket.set_buffer(bf16_buffer)

        inner_fut = hook(hook_state, bucket)

        def decompress(done):
            target = bucket.buffer()
            # Copy in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            target.copy_(done.value())
            return target

        # Runs once the wrapped hook's future has completed.
        return inner_fut.then(decompress)

    return bf16_compress_wrapper_hook

View File

@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Any, no_type_check
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.distributed.utils import _free_storage
@dataclass
class _AllreduceUpcastHookState:
    """
    State to manage DDP mixed precision in backward / gradient communication.

    This contains a weakref to the DDP module for access to reducer and process
    group, and a stream to run parameter and gradient upcasts.
    """

    # Weak reference to the DDP module; called to obtain the live instance.
    ddp_weakref: Any
    # CUDA stream on which the post-allreduce upcast work is enqueued.
    upcast_stream: torch.cuda.Stream
    # True while an end-of-backward callback that waits on ``upcast_stream``
    # is already queued, so that it is only queued once per backward pass.
    wait_for_stream_enqueued: bool = False
@no_type_check
def _reducer_allreduce_and_upcast_hook(
    hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer.

    Performs allreduce in the reduced precision given by DDP's mixed precision
    reduce_dtype, and upcasts parameters and gradients to fp32 in preparation
    to run the optimizer.

    Args:
        hook_state (_AllreduceUpcastHookState): holds the DDP weak reference,
            the upcast stream, and the "callback already queued" flag.
        bucket (dist.GradBucket): the gradient bucket to reduce.

    Returns:
        A future whose result is the bucket buffer, set after the allreduce
        has been waited on and the buffer divided by the process group size.
    """
    ddp_weakref = hook_state.ddp_weakref
    reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
    # NOTE(review): this value is read but not used anywhere in this function.
    gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
    # Cast bucket if different than param_dtype.
    if (
        ddp_weakref().mixed_precision.param_dtype
        != ddp_weakref().mixed_precision.reduce_dtype
    ):
        # Cast bucket tensor to reduce_dtype
        bucket.set_buffer(
            bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)
        )
    fut = reducer._run_allreduce_hook(bucket)
    ret_fut = torch.futures.Future()
    stream = hook_state.upcast_stream
    # Everything inside this block is issued with ``stream`` as the current
    # CUDA stream.
    with torch.cuda.stream(stream):
        fut.wait()
        # Divide by the group size here
        # (presumably the allreduce above only sums -- TODO confirm).
        bucket.buffer().div_(process_group.size())
        ret_fut.set_result(bucket.buffer())

        # Upcast parameters and gradients so optimizer step can run in fp32.
        params, grads = bucket.parameters(), bucket.gradients()
        for p, g in zip(params, grads):
            p.data = p._fp_param
            # free storage for mp param as it will be allocated again in next
            # forward pass.
            _free_storage(p._mp_param)
            p.grad.data = p.grad.to(p.data.dtype)

    # enqueue a callback to wait for this stream at end of backward
    def wait_for_stream_cb():
        torch.cuda.current_stream().wait_stream(stream)
        # Remove post-backward hooks since they are re-installed in next
        # iteration, similar to FSDP.
        # Parameters that don't require grad still needed to be casted since
        # they may participate in computation. However, they would not be recast
        # by hook above as they don't have a grad hook installed, so cast them
        # back here.
        for n, p in ddp_weakref().module.named_parameters():
            if hasattr(p, "_ddp_mp_hook_state"):
                p._ddp_mp_hook_state[1].remove()
                delattr(p, "_ddp_mp_hook_state")
            if not p.requires_grad and not hasattr(p, "_ddp_ignored"):
                p.data = p._fp_param

        # reset for next backward pass
        hook_state.wait_for_stream_enqueued = False

    # Queue the callback only once, however many buckets call this hook
    # before the callback runs and resets the flag.
    if not hook_state.wait_for_stream_enqueued:
        Variable._execution_engine.queue_callback(wait_for_stream_cb)
        # mark that the callback is enqueued
        hook_state.wait_for_stream_enqueued = True

    return ret_fut

View File

@ -0,0 +1,160 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, no_type_check
import torch
import torch.distributed as dist
from torch.autograd import Variable
__all__: List[str] = []
_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
class _OptimizerHookState:
"""
Holds state for running optimizer in-line after DDP communication hook.
Currently contains only optimizer class which must have a method `step_param`.
"""
__slots__ = ["functional_optimizer", "params_to_optimize"]
def __init__(self, functional_optim, params=None):
self.functional_optimizer = functional_optim
self._check_valid_functional_optim()
self._set_params_to_optimize(params)
def _set_params_to_optimize(self, params):
if params is not None:
self.params_to_optimize = set(params)
def _check_valid_functional_optim(self):
if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
raise ValueError(
f"Class {type(self.functional_optimizer)} must implement method "
f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
)
@dataclass
class _OptimInBackwardHookState:
    # State for the optimizer-in-backward comm hook.
    # CUDA stream on which the in-backward optimizer steps are run.
    optim_stream: torch.cuda.Stream
    # True while an end-of-backward callback that waits on ``optim_stream``
    # is already queued.
    wait_for_optim_stream_enqueued: bool
@no_type_check
def _apply_optim_in_backward_hook(
    gradient_is_bucket_view: bool,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Build the comm hook that applies the optimizer during backward.

    If torch.distributed.optim._apply_optimizer_in_backward is used to overlap
    optimizer with backward pass, DDP will run the returned hook to run the
    optimizer step for parameters after gradient communication has taken place.

    Args:
        gradient_is_bucket_view (bool): when ``False``, each parameter's
            ``grad`` is pointed at the bucket's gradient before stepping.

    Returns:
        A ``functools.partial`` of the hook bound to a fresh
        ``_OptimInBackwardHookState``, carrying the hook's ``__name__`` and
        ``__qualname__``.
    """
    shared_stream_state = _OptimInBackwardHookState(
        optim_stream=torch.cuda.Stream(),
        wait_for_optim_stream_enqueued=False,
    )

    def apply_optim_in_backward_hook(
        hook_state: Any,
        bucket: dist.GradBucket,
        optim_stream_state,
    ) -> torch.futures.Future[torch.Tensor]:
        # ``hook_state`` is a weak reference to the DDP instance.
        ddp_inst = hook_state()
        reducer = ddp_inst.reducer
        process_group = ddp_inst.process_group

        # Run original hook
        allreduce_fut = reducer._run_allreduce_hook(bucket)

        with torch.cuda.stream(optim_stream_state.optim_stream):
            allreduce_fut.wait()
            # Apply gradient division since C++ side only allreduces and does
            # not average. TODO: (rohan-varma) the div factor may be different
            # when running with join hook
            bucket.buffer().div_(process_group.size())
            bucket_params = bucket.parameters()
            bucket_grads = bucket.gradients()
            # TODO (rohan-varma): upcast as needed for DDP mixed precision,
            # once optimizer in backward + DDP mixed precision is supported.
            for param, grad in zip(bucket_params, bucket_grads):
                if not hasattr(param, "_in_backward_optimizers"):
                    continue
                # Note: need to set grad to the bucket's grad, because
                # running allreduce results in the bucket's grad being
                # reduced, but not grad field.
                if not gradient_is_bucket_view:
                    param.grad = grad
                for optim in param._in_backward_optimizers:
                    optim.step()

        # Need to return a Future[Tensor] to obey comm hook API contract.
        result_fut = torch.futures.Future()
        result_fut.set_result(bucket.buffer())

        # Callback run at the end of backward: wait for the optimizer stream
        # and set all DDP managed grads to None.
        def wait_for_optim_stream_callback():
            torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream)
            # Set DDP managed grads to None
            for managed in ddp_inst._get_data_parallel_params(ddp_inst.module):
                if hasattr(managed, "_in_backward_optimizers"):
                    managed.grad = None

            # reset for the next backwards pass
            optim_stream_state.wait_for_optim_stream_enqueued = False

        if not optim_stream_state.wait_for_optim_stream_enqueued:
            Variable._execution_engine.queue_callback(wait_for_optim_stream_callback)
            # mark that the callback is enqueued
            optim_stream_state.wait_for_optim_stream_enqueued = True

        return result_fut

    comm_hook = partial(
        apply_optim_in_backward_hook, optim_stream_state=shared_stream_state
    )
    # These are needed for DDP's logging of comm hooks
    comm_hook.__name__ = apply_optim_in_backward_hook.__name__
    comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__

    return comm_hook
def _hook_then_optimizer(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
    optimizer_state: _OptimizerHookState,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Run optimizer in a functional fashion after DDP communication hook.

    The returned hook calls ``hook`` and chains a callback that invokes
    ``step_param`` on the functional optimizer for each parameter/gradient
    pair in the bucket. When ``optimizer_state`` carries a non-``None``
    ``params_to_optimize``, only parameters in that set are stepped.
    """
    # Decided once, when the wrapper is built.
    has_set_params = getattr(optimizer_state, "params_to_optimize", None) is not None

    def hook_then_optimizer_wrapper(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Run original hook
        comm_fut = hook(hook_state, bucket)

        def optimizer_step(fut):
            grads = bucket.gradients()
            params = bucket.parameters()
            for grad, param in zip(grads, params):
                if has_set_params and param not in optimizer_state.params_to_optimize:
                    continue
                optimizer_state.functional_optimizer.step_param(param, grad)
            return bucket.buffer()

        return comm_fut.then(optimizer_step)

    return hook_then_optimizer_wrapper

View File

@ -0,0 +1,124 @@
# mypy: allow-untyped-defs
import logging
import torch
import torch.distributed as dist
from . import default_hooks as default
logger = logging.getLogger(__name__)
class PostLocalSGDState:
    r"""
    Store state for all-reducing gradients globally until given step, then locally after.

    Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
    and all-reducing gradients locally using ``subgroup`` afterwards.

    If ``process_group`` is ``None``, the global process group will be used.
    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.

    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
    because both true and false may give a faster convergence.
    """

    __slots__ = [
        "process_group",
        "subgroup",
        "start_localSGD_iter",
        "post_local_gradient_allreduce",
        "iter",
    ]

    def __init__(
        self,
        process_group,
        subgroup,
        start_localSGD_iter,
        post_local_gradient_allreduce=True,
    ):
        """Store the given settings, start the iteration counter at 0, and log when local SGD will begin."""
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter
        )

        # Global group (used before `start_localSGD_iter`) and local group
        # (used afterwards) for all-reducing gradients.
        self.process_group = process_group
        self.subgroup = subgroup
        self.start_localSGD_iter = start_localSGD_iter
        # Whether to keep all-reducing gradients within the subgroup once
        # local SGD has started. This may help with the convergence efficiency
        # at the cost of relatively cheap intra-subgroup communication.
        self.post_local_gradient_allreduce = post_local_gradient_allreduce
        # Iteration/step in the training loop.
        self.iter = 0

    def maybe_increase_iter(self, bucket):
        """Advance the iteration counter on the last bucket and log when local SGD starts."""
        # One training iteration has completed once the bucket reporting
        # `is_last()` has been processed, so only count that one.
        finished_iteration = bucket.is_last()
        if finished_iteration:
            self.iter += 1

        if self.start_localSGD_iter == self.iter:
            logger.info("Start to apply local SGD after %s iterations.", self.iter)
def post_localSGD_hook(
    state: PostLocalSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Run the post-localSGD algorithm as a DDP communication hook.

    The hook is meant to be paired with a model averaging component (e.g.,
    :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
    that runs after the optimizer step.

    Args:
        state (PostLocalSGDState): State information to run post-localSGD.
            Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches
            multiple per-variable tensors. Note that since DDP comm hook only supports single
            process single device mode, only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
                                  start_localSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)
        >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``.
        >>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module.
    """
    # The bucket exposes its gradients as a single flattened 1D tensor.
    grads = bucket.buffer()

    # Phase 1: during the first `start_localSGD_iter` iterations, all-reduce globally.
    if state.iter < state.start_localSGD_iter:
        world_group = state.process_group
        if world_group is None:
            world_group = dist.group.WORLD
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(world_group, grads)

    # Phase 2 with local all-reduce enabled: synchronize within `subgroup` only.
    # By default a separate subgroup is created per node, which results in an
    # intra-node all-reduce at every training step. From this point on, model
    # averaging should run after the optimizer step to globally all-reduce the parameters.
    if state.post_local_gradient_allreduce:
        if state.subgroup is None:
            state.subgroup, _ = dist.new_subgroups()
        return default._allreduce_fut(state.subgroup, grads)

    # Phase 2 with local all-reduce disabled: no gradient synchronization at all,
    # so hand back an already-completed future that carries the untouched gradients.
    completed: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    completed.set_result(grads)
    return completed

View File

@ -0,0 +1,856 @@
# mypy: allow-untyped-defs
import logging
import math
from collections import defaultdict
from typing import Dict
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from . import default_hooks as default
__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"]
logger = logging.getLogger(__name__)
def _orthogonalize(matrices, epsilon=0):
"""
Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
"""
assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
num_matrices = matrices.shape[0]
rank = matrices.shape[2]
dtype = matrices.dtype
if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
_orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
else:
torch.linalg.qr(
matrices,
out=(
matrices,
torch.empty(
num_matrices, rank, rank, device=matrices.device, dtype=dtype
),
),
)
def _orthogonalize_gram_schmidt(matrices, epsilon=0):
"""
Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`,
"""
num_cols = matrices.shape[2]
for i in range(num_cols):
# Normalize the i'th column.
col = matrices[:, :, i : i + 1]
# If no epsilon is added here, division by zero may be caused by vanishing gradients.
# This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
# in the neural network.
if epsilon == 0:
# Note that col ** 2 can underflow/overflow if we use FP16.
# May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
try:
col /= torch.norm(col, dim=1, keepdim=True)
except ZeroDivisionError:
logger.error(
"The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
"as `orthogonalization_epsilon` in PowerSGD state."
)
# Recover the values from NaNs to 0s.
col.fill_(0.0)
else:
col /= torch.norm(col, dim=1, keepdim=True) + epsilon
# Project it on the rest and remove it.
if i + 1 < num_cols:
rest = matrices[:, :, i + 1 :]
rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
def _should_compress(
num_rows, num_cols, matrix_approximation_rank, min_compression_rate
):
"""
Recommend if tensor given is worth compressing.
Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
including statistics describing the expected savings from compression. We consider a tensor worth
compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
uncompressed size = ``num_rows`` * ``num_cols``,
and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where:
compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above);
uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and,
compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
""" # noqa: B950
uncompressed_size = num_rows * num_cols
compressed_size = (num_rows + num_cols) * matrix_approximation_rank
return (
compressed_size * min_compression_rate < uncompressed_size,
uncompressed_size,
compressed_size,
)
def _report_compression_stats(bucket, state):
"""Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
if bucket.is_last() and state.iter >= state.next_stats_report:
stats = state.compression_stats()
logger.info(
"Compression stats: iter %s, total before compression %s, total after compression %s, "
"rate %s",
state.iter,
stats[1],
stats[2],
stats[0],
)
state.next_stats_report = state.iter + state.compression_stats_logging_frequency
class PowerSGDState:
    r"""
    Store both the algorithm's hyperparameters and internal state for all gradients during training.

    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
    For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.

    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach or will never reach and yield loss in accuracy.

        1.2. The increase of ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not be further improved beyond a certain ``matrix_approximation_rank`` threshold.

    To tune ``matrix_approximation_rank``, we suggest to start from 1 and increase by factors of 2 (like an exponential grid search, 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.

    2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even if a relatively small ``matrix_approximation_rank`` is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.

    To tune ``start_powerSGD_iter``, we suggest to start with 10% of total training steps, and increase it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps.

    3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.

    Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts.

    4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.

    5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket, however this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.

    .. warning ::
        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
        and this can conflict with any tensor memorized before the rebuild process.
    """  # noqa: B950

    __slots__ = [
        "process_group",
        # The fields below are the hyperparameters that often need to be tuned by the user.
        "matrix_approximation_rank",
        "start_powerSGD_iter",
        # The fields below are the hyperparameters that seldom need be tuned by the user.
        "min_compression_rate",
        "orthogonalization_epsilon",
        # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
        "use_error_feedback",
        "warm_start",
        "batch_tensors_with_same_shape",
        # The fields below are internal state.
        "rng",
        "error_dict",
        "p_memory_dict",
        "q_memory_dict",
        "iter",
        # The fields below are for recording compression stats.
        "total_numel_before_compression",
        "total_numel_after_compression",
        "compression_stats_logging_frequency",
        "next_stats_report",
    ]

    def __init__(
        self,
        process_group,
        matrix_approximation_rank=1,
        start_powerSGD_iter=1_000,
        min_compression_rate=2,
        use_error_feedback=True,
        warm_start=True,
        orthogonalization_epsilon=0,
        random_seed=0,
        compression_stats_logging_frequency=10_000,
        batch_tensors_with_same_shape: bool = False,
    ):
        """
        Log the configuration, validate it, and initialize the internal state.

        Args:
            process_group: Group the hooks communicate over.
            matrix_approximation_rank: Rank of the low-rank tensors P and Q.
            start_powerSGD_iter: Step until which vanilla allreduce is used.
            min_compression_rate: Minimum compression rate for a tensor to be compressed.
            use_error_feedback: Whether to memorize and re-apply the local compression error.
            warm_start: Whether to reuse P(s) and Q(s) from the previous iteration.
            orthogonalization_epsilon: Value added to column norms during orthogonalization.
            random_seed: Seed of the RNG that produces the per-step seeds for initializing Q.
            compression_stats_logging_frequency: Number of iterations between two
                compression stats reports; values below 1 are clamped to 1.
            batch_tensors_with_same_shape: Whether same-shaped tensors are compressed as a batch.

        Raises:
            ValueError: If ``use_error_feedback`` or ``warm_start`` is enabled while
                ``start_powerSGD_iter`` is not greater than 1.
        """
        logger.info(
            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
            matrix_approximation_rank,
            start_powerSGD_iter,
            min_compression_rate,
            orthogonalization_epsilon,
            use_error_feedback,
            warm_start,
            random_seed,
            compression_stats_logging_frequency,
            batch_tensors_with_same_shape,
        )

        self.process_group = process_group
        self.matrix_approximation_rank = matrix_approximation_rank
        # Deferring PowerSGD compression until step 'start_powerSGD_iter' can have two advantages:
        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
        # even if the matrix approximation rank is increased to a large value.
        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
        # (or a more conservative compression such as FP16 compression) with PowerSGD.
        # 2) There is an internal optimization of rebuilding buckets process in DDP,
        # in order to save the memory space.
        # This step takes place after the first iteration.
        # However, this means that the shape of input bucketized tensors is subject to change,
        # which will complicate the implementations of error feedback and warm-up.
        # Running vanilla allreduce in the first few iterations can avoid this complexity.
        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
            raise ValueError(
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP."
            )
        self.start_powerSGD_iter = start_powerSGD_iter
        self.min_compression_rate = min_compression_rate
        # Error feedback is usually crucial for both convergence and generalization,
        # because PowerSGD is a biased compressor,
        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
        # This mechanism requires a temporary copy of the input gradients,
        # so it increases the peak memory consumption by the size of the gradient tensor.
        # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank),
        # sometimes it is possible to converge to the optima without error feedback.
        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
        self.use_error_feedback = use_error_feedback
        # Warm-start reuses P(s) and Q(s) from the previous iteration.
        # This can improve the approximation quality and hence improve the accuracy.
        # Additionally, by avoiding the initialization of these low-rank tensors at every step,
        # this can also accelerate training.
        # However, this is at the cost of extra memory.
        self.warm_start = warm_start
        # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients.
        self.orthogonalization_epsilon = orthogonalization_epsilon
        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
        # but in the same order for all the DDP replicas.
        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
        # If the same random projection is used,
        # there will be differences between the gradients that are never synchronized.
        import numpy as np

        self.rng = np.random.RandomState(random_seed)
        # Since there is only a single state instance for all the input buckets,
        # need to maintain a dictionary that maps each bucket index to the local error.
        self.error_dict: Dict[int, torch.Tensor] = {}
        self.p_memory_dict: Dict[int, torch.Tensor] = {}
        self.q_memory_dict: Dict[int, torch.Tensor] = {}
        # Iteration/step in the training loop.
        self.iter = 0
        # Compression stats accumulators
        self.total_numel_before_compression = 0
        self.total_numel_after_compression = 0
        # We'll report compression stats every 'compression_stats_logging_frequency' iterations
        # Note that we always report compression stats at least once.
        self.compression_stats_logging_frequency = max(
            1, compression_stats_logging_frequency
        )
        self.next_stats_report = 0
        # Batching tensors with same shape can increase parallelism in compression / decompression computation.
        # This requires a larger bucket size to make more same-shaped tensor to appear in one bucket, however
        # this may reduce the overlap between computation and communication, and increase the memory footprint
        # due to stacking tensors.
        # Turn on if compression / decompression computation is a bottleneck.
        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape

    def __getstate__(self):
        r"""
        Return a ``Dict[str, Any]`` which will be pickled and saved.

        ``process_group`` is not serializable and excluded from
        a returned state.
        """
        logger.warning(
            "NOTE: Process group is not serializable and excluded from a saved state."
        )
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot != "process_group"
        }

    def __setstate__(self, state):
        r"""
        Take a provided ``state`` and set to this ``PowerSGDState`` instance.

        ``process_group`` is set to default.
        """
        self.process_group = distributed_c10d._get_default_group()
        # Adjacent literals are used instead of a backslash continuation inside the
        # string, so the emitted message does not depend on source indentation.
        logger.warning(
            "NOTE: Process group will be set to a default group (i.e. the world size). "
            "If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
        )
        for slot, value in state.items():
            setattr(self, slot, value)

    def maybe_increase_iter(self, bucket):
        """Track iterations and trigger log message at start of PowerSGD."""
        # Since bucket 0 is the last bucket to allreduce in an iteration.
        # Only increase `iter` when bucket 0 is processed.
        if bucket.is_last():
            self.iter += 1

        if self.iter == self.start_powerSGD_iter:
            logger.info("Start to apply PowerSGD after %s iterations.", self.iter)

    def compression_stats(self):
        r"""
        Return latest compression statistics as tuple.

        Returns tuple of form (compress_rate, numel_before_compression, numel_after_compression) where:

        compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression),
        or 0 when no element has been counted after compression yet;

        numel_before_compression is the total number of elements before compression was applied; and,

        numel_after_compression is the total number of elements after compression was applied.
        """  # noqa: B950
        compress_rate = (
            self.total_numel_before_compression / self.total_numel_after_compression
            if self.total_numel_after_compression > 0
            else 0
        )
        return (
            compress_rate,
            self.total_numel_before_compression,
            self.total_numel_after_compression,
        )
def powerSGD_hook(
    state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Implement PowerSGD algorithm.

    This DDP communication hook implements PowerSGD gradient compression
    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:

    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:

        1.1 The tensors that should be compressed before allreduce, because the compression can give enough saving in bandwidth.

        1.2 Rest of the tensors will be directly allreduced without compression, including all the vector tensors (for biases).

    2. Handles uncompressed tensors:

        2.1. Allocate contiguous memory for those uncompressed tensors, and allreduces all the uncompressed tensors as a batch, without compression;

        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.

    3. Handles the tensors that should be compressed by PowerSGD compression:

        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

        3.2. Computes each P in Ps, which is equal to MQ;

        3.3. Allreduces Ps as a batch;

        3.4. Orthogonalizes each P in Ps;

        3.5. Computes each Q in Qs, which is approximately equal to M^TP;

        3.6. Allreduces Qs as a batch;

        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.

    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
    This not only gives the user more control over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
            and ``min_compression_rate``.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
                                  start_powerSGD_iter=10, min_compression_rate=0.5)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """  # noqa: B950
    process_group = state.process_group
    # Fall back to the world group when the state carries no explicit group.
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    # Used below as the divisor applied to the all-reduced results.
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.index()
    # Stays None unless error feedback is enabled; only read in `decompress` in that case.
    input_tensor_cp = None
    total_length = input_tensor.shape[0]
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logger.info(
                "A zero tensor of length %s that represents local error is created.",
                total_length,
            )
            state.error_dict[bucket_index] = torch.zeros(
                total_length, device=device, dtype=dtype
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()

    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = bucket.gradients()

    # Step I: Divide all the tensors into two groups,
    # one will be compressed before allreduce and the other will be directly allreduced without compression.
    tensors_to_compress, uncompressed_tensors = [], []
    total_Ps_size = 0
    total_Qs_size = 0
    for tensor in tensors:
        # View each gradient as a 2D matrix: the first dimension is kept and the rest is flattened.
        matrix = tensor.view(tensor.shape[0], -1)
        n, m = matrix.shape
        # The effective rank cannot exceed either matrix dimension.
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        # compress_test is (should_compress, uncompressed element count, compressed element count).
        compress_test = _should_compress(
            n, m, matrix_approximation_rank, state.min_compression_rate
        )
        state.total_numel_before_compression += compress_test[1]
        if compress_test[0]:
            tensors_to_compress.append(matrix)
            total_Ps_size += n * matrix_approximation_rank
            total_Qs_size += m * matrix_approximation_rank
            state.total_numel_after_compression += compress_test[2]
        else:
            uncompressed_tensors.append(tensor)
            # An uncompressed tensor counts with its full size in the "after" total.
            state.total_numel_after_compression += compress_test[1]

    _report_compression_stats(bucket, state)

    # Step II: Handle uncompressed tensors.
    # Allocate contiguous memory for these tensors to allreduce efficiently.
    # An empty tensor is used when there is nothing to send uncompressed, so the
    # allreduce below (the head of the future chain) is issued either way.
    uncompressed_tensors_memory = (
        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
        if uncompressed_tensors
        else torch.tensor([], device=device, dtype=dtype)
    )

    # Step III: Handle the tensors that should be compressed.
    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
    need_randomize_qs = False
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        need_randomize_qs = True
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logger.info(
                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
                total_Ps_size,
                total_Qs_size,
            )
        state.p_memory_dict[bucket_index] = torch.empty(
            total_Ps_size, device=device, dtype=dtype
        )
        state.q_memory_dict[bucket_index] = torch.empty(
            total_Qs_size, device=device, dtype=dtype
        )

    # Batch tensors to compress by shape.
    shape_to_tensors = defaultdict(list)
    for tensor in tensors_to_compress:
        shape_to_tensors[tensor.shape].append(tensor)

    # This function decides whether to batch tensors with same shape or not according to the argument,
    # so the following process could share the same code.
    # Every yielded tensor is 3D (batch, n, m). The loop variable `tensors` below is
    # local to the generator and shadows the outer `tensors` list.
    def maybe_batched_tensors_to_compress():
        for tensors in shape_to_tensors.values():
            if state.batch_tensors_with_same_shape:
                batch_size = len(tensors)
                if batch_size == 1:
                    # Use the original tensor to avoid copy.
                    yield tensors[0].unsqueeze(0)
                else:
                    yield torch.stack(tensors)
            else:
                for tensor in tensors:
                    yield tensor.unsqueeze(0)

    # Create Ps and Qs that point to the allocated memory.
    # `tensors_to_compress` is rebound here: it held the 2D matrices above and now
    # collects the 3D (possibly batched) tensors in the same order as `ps` and `qs`.
    tensors_to_compress = []
    ps = []
    qs = []
    # Running offsets into the flat P and Q buffers of this bucket.
    p_idx = 0
    q_idx = 0
    for tensor in maybe_batched_tensors_to_compress():
        batch_size, n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        tensors_to_compress.append(tensor)
        # Each P / Q is a reshaped slice of the flat buffer, so results written
        # through `out=` below land in the state's buffers.
        ps.append(
            state.p_memory_dict[bucket_index][
                p_idx : p_idx + batch_size * n * matrix_approximation_rank
            ].view(batch_size, n, matrix_approximation_rank)
        )
        qs.append(
            state.q_memory_dict[bucket_index][
                q_idx : q_idx + batch_size * m * matrix_approximation_rank
            ].view(batch_size, m, matrix_approximation_rank)
        )
        p_idx += batch_size * n * matrix_approximation_rank
        q_idx += batch_size * m * matrix_approximation_rank

    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
    # The exception is the first iteration when PowerSGD is applied.
    if not need_randomize_qs:
        for q in qs:
            _orthogonalize(q, state.orthogonalization_epsilon)
    else:
        with torch.random.fork_rng(devices=[]):
            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
            # The seed makes sure that the initial random values are the same across all the DDP replicas.
            # This seed should differ at every step.
            # Since it is very slow to fork RNG state across all the CUDA devices,
            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
            torch.manual_seed(state.rng.randint(1_000_000_000))
            for q in qs:
                q.copy_(
                    torch.randn(
                        *q.shape,
                        device="cpu",
                        dtype=dtype,
                    )
                )
                _orthogonalize(q, state.orthogonalization_epsilon)

    # Compute Ps.
    for tensor, q, p in zip(tensors_to_compress, qs, ps):
        torch.bmm(tensor, q, out=p)

    # This allreduce is only applied to uncompressed tensors,
    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
    # However, this somehow requires a separate future chain at this time.
    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
        uncompressed_tensors_memory, group=group_to_use, async_op=True
    ).get_future()

    # First callback: finish the uncompressed tensors, then allreduce the Ps.
    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
        # The future's value is indexed with [0] to get the reduced tensor,
        # which is then divided by the group size in place.
        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
        idx = 0
        # Scatter the reduced flat memory back into the individual gradient tensors.
        for tensor in uncompressed_tensors:
            tensor.copy_(
                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
            )
            idx += tensor.numel()

        # Since these Ps will be orthogonalized later, no need to divide them by world size.
        # The nested allreduce is waited on inside this callback, and element [0] of its
        # result is returned, so the next callback reads the tensor via `fut.value()` directly.
        return (
            dist.all_reduce(
                state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    # Second callback: orthogonalize the reduced Ps, compute the Qs and allreduce them.
    def compute_qs(fut):
        state.p_memory_dict[bucket_index] = fut.value()
        for p in ps:
            _orthogonalize(p, state.orthogonalization_epsilon)

        # Compute Qs.
        for tensor, p, q in zip(tensors_to_compress, ps, qs):
            torch.bmm(tensor.transpose(1, 2), p, out=q)

        # TODO: The above procedure does two matmul+allreduce steps per iteration --
        # one left multiplication and one right multiplication.
        # For warm-start, can take one such step at a time, and alternate between them.

        # Allreduce Qs.
        return (
            dist.all_reduce(
                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
            )
            .get_future()
            .wait()[0]
        )

    # Third callback: rebuild the gradients from P and Q and update the bookkeeping.
    def decompress(fut):
        # Unlike the Ps, the reduced Qs are divided by the group size.
        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)

        # Overwrite each (batched) tensor with its low-rank reconstruction P @ Q^T.
        for p, q, tensor in zip(ps, qs, tensors_to_compress):
            torch.bmm(p, q.transpose(1, 2), out=tensor)

        # Copy batched tensors back to original buffer.
        if state.batch_tensors_with_same_shape:
            for tensor in tensors_to_compress:
                if tensor.shape[0] == 1:
                    # Skip tensor with batch_size == 1 since itself is the original tensor.
                    continue
                # Batches built with `torch.stack` are looked up by their per-matrix
                # shape and copied back element by element.
                original_tensors = shape_to_tensors[tensor.shape[1:]]
                for i, original_tensor in enumerate(original_tensors):
                    original_tensor.copy_(tensor[i])

        # NOTE(review): presumably ensures the decompression above has completed on the
        # device before the error is computed below -- confirm.
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        # Without warm-start the low-rank buffers are not reused, so drop them for all buckets.
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()

        state.maybe_increase_iter(bucket)

        return input_tensor

    # Chain the three callbacks behind the allreduce of the uncompressed tensors;
    # the returned future resolves to `input_tensor` as returned by `decompress`.
    return (
        allreduce_contiguous_uncompressed_tensors_fut.then(
            unpack_uncompressed_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
    )
def batched_powerSGD_hook(
state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
r"""
Implement simplified PowerSGD algorithm.
This DDP communication hook implements a simplified PowerSGD gradient compression
algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
This variant does not compress the gradients layer by layer,
but instead compresses the flattened input tensor that batches all the gradients.
Therefore, it is **faster** than :meth:`powerSGD_hook`,
but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
.. warning ::
Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
Therefore, the user should always consider :meth:`powerSGD_hook` first,
and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.
Once gradient tensors are aggregated across all workers, this hook applies
compression as follows:
1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
3. Computes P, which is equal to MQ;
4. Allreduces P;
5. Orthogonalizes P;
6. Computes Q, which is approximately equal to M^TP;
7. Allreduces Q;
8. Computes M, which is approximately equal to PQ^T.
9. Truncates the input tensor to the original length.
Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
This not only gives the user more control over the tradeoff between speedup and accuracy,
but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
Args:
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode,
only exactly one tensor is stored in this bucket.
Returns:
Future handler of the communication, which updates the gradients in place.
Example::
>>> # xdoctest: +SKIP
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
""" # noqa: B950
process_group = state.process_group
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
# The input tensor is a flattened 1D tensor.
input_tensor = bucket.buffer()
# Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
if state.iter < state.start_powerSGD_iter:
state.maybe_increase_iter(bucket)
return default._allreduce_fut(group_to_use, input_tensor)
# Apply PowerSGD after `start_powerSGD_iter` iterations.
device = input_tensor.device
total_length = input_tensor.shape[0]
state.total_numel_before_compression += total_length
# View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
square_side_length = math.ceil(math.sqrt(total_length))
state.total_numel_after_compression += (
square_side_length * state.matrix_approximation_rank * 2
)
padded_total_length = square_side_length**2
input_tensor.resize_(padded_total_length)
input_tensor[total_length:padded_total_length].fill_(0)
_report_compression_stats(bucket, state)
# Incorporate the error from the previous state into the gradients.
bucket_index = bucket.index()
input_tensor_cp = None
if state.use_error_feedback:
if bucket_index in state.error_dict:
input_tensor.add_(state.error_dict[bucket_index])
else:
logger.info(
"A zero tensor of length %s that represents local error is created.",
padded_total_length,
)
state.error_dict[bucket_index] = torch.zeros(
padded_total_length, device=device, dtype=input_tensor.dtype
)
# Keep a copy of the input tensor,
# so that we can compute the local error caused by compression later,
# by comparing this copy and the input tensor updated after decompression.
input_tensor_cp = torch.clone(input_tensor).detach()
matrix = input_tensor.view(square_side_length, square_side_length)
# Reuse P and Q from the previous iteration if possible.
# The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
if not state.warm_start or bucket_index not in state.p_memory_dict:
# If warm-start is disabled, low-rank tensors will be initialized at every step.
# Only log this if warm-start to avoid spamming.
if state.warm_start:
logger.info(
"Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
square_side_length,
state.matrix_approximation_rank,
)
def create_low_rank_tensor(fill_random_values, rng):
"""Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
if fill_random_values:
with torch.random.fork_rng(devices=[]):
# Fork this RNG to avoid changing the seed globally and affecting the random sampling
# anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# This seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(rng.randint(1_000_000_000))
return torch.randn(
square_side_length,
state.matrix_approximation_rank,
device="cpu",
dtype=input_tensor.dtype,
).to(device)
else:
return torch.empty(
square_side_length,
state.matrix_approximation_rank,
device=device,
dtype=input_tensor.dtype,
)
state.p_memory_dict[bucket_index] = create_low_rank_tensor(
fill_random_values=False, rng=state.rng
)
state.q_memory_dict[bucket_index] = create_low_rank_tensor(
fill_random_values=True, rng=state.rng
)
_orthogonalize(state.q_memory_dict[bucket_index])
torch.matmul(
matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
)
allreduce_p_fut = dist.all_reduce(
state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
).get_future()
def compute_q(fut):
state.p_memory_dict[bucket_index] = fut.value()[0]
_orthogonalize(state.p_memory_dict[bucket_index])
torch.matmul(
matrix.t(),
state.p_memory_dict[bucket_index],
out=state.q_memory_dict[bucket_index],
)
# TODO: The above procedure does two matmul+allreduce steps per iteration --
# one left multiplication and one right multiplication.
# For warm-start, can take one such step at a time, and alternate between them.
return (
dist.all_reduce(
state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
)
.get_future()
.wait()[0]
)
def decompress(fut):
state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
torch.matmul(
state.p_memory_dict[bucket_index],
state.q_memory_dict[bucket_index].t(),
out=matrix,
)
if state.use_error_feedback:
# Memorize the local errors.
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
# Removing this seemingly unnecessary sync somehow may cause failures.
# See: https://github.com/pytorch/pytorch/pull/54838
if torch.cuda.is_available():
torch.cuda.synchronize(device)
if not state.warm_start:
state.p_memory_dict.clear()
state.q_memory_dict.clear()
ret = input_tensor.resize_(total_length)
state.maybe_increase_iter(bucket)
return ret
return allreduce_p_fut.then(compute_q).then(decompress)

View File

@ -0,0 +1,218 @@
# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch import nn
def _quantize_per_tensor_cuda(x, scale, zero_point):
y = torch.round(x / scale) + zero_point
y = torch.clamp(y, 0, 255).to(torch.uint8)
return y
def _dequantize_per_tensor_cuda(y, scale, zero_point):
x = scale * (y.to(torch.float32) - zero_point)
return x
def _quantize_per_channel_cuda(x, scale, zero_point):
y = torch.zeros(x.size(), device=x.device)
for i in range(x.size()[0]):
y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
y = torch.clamp(y, 0, 255).to(torch.uint8)
return y
def _dequantize_per_channel_cuda(y, scale, zero_point):
y = y.to(torch.float32).cuda(y.device)
x = torch.zeros_like(y, device=y.device)
for i in range(x.size()[0]):
x[i, :] = scale[i] * (y[i, :] - zero_point[i])
return x
def _get_allgather_out_list(all_gather_in_list, world_size):
out_list = [
torch.zeros_like(
all_gather_in_list,
device=all_gather_in_list.device,
dtype=all_gather_in_list.dtype,
)
for _ in range(world_size)
]
return out_list
def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` protocol.

    Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to the quantization. Once every worker has that
    information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate it to all workers. The final ``then`` callback,
    ``dequantize_and_aggregate``, dequantizes and aggregates each quantized
    gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses ``allgather`` protocol which is considerably slower than
        ``allreduce`` protocol. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    if process_group is None:
        group_to_use = dist.group.WORLD
        rank = dist.get_rank()
    else:
        group_to_use = process_group
        rank = process_group.rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Observe the local gradients to derive this rank's quantization parameters.
    observer = torch.ao.quantization.MinMaxObserver().cuda(tensor.device)
    observer(tensor)
    local_scale, local_zero_point = observer.calculate_qparams()
    s_and_z = torch.FloatTensor([local_scale, local_zero_point]).cuda(tensor.device)

    # Filled in place by the first all_gather; read again when dequantizing.
    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    qparams_fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Scale and zeros gathered from all workers.
        gathered_s_and_z = fut.wait()[0]
        own_scale = gathered_s_and_z[rank][0]
        own_zero_point = gathered_s_and_z[rank][1]
        # Every worker quantizes its own ``GradBucket`` tensor.
        quantized_tensor = _quantize_per_tensor_cuda(tensor, own_scale, own_zero_point)
        # Allgather quantized tensors.
        gather_fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()
        return gather_fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]
        total = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient
        # tensors locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            total += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )
        return total / world_size

    return qparams_fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future[torch.Tensor]:
    """
    Apply ``torch.quantize_per_channel`` logic to DDP using ``allgather`` protocol.

    Compared to per-tensor, the main motivation of per-channel is
    for considerably large tensors such as a tensor that contains 6 million
    elements quantizing per a bucket size of 512 (or 128) elements may significantly
    increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements, zero-padding the tail up to the next multiple of
    ``bucket_size``. Then, workers allgather the scales and zero points of their own
    ``GradBucket`` prior to the quantization. After all workers have that information,
    the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's
    own gradient tensor, and uses ``allgather`` to communicate these across all workers.
    The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and
    aggregates each quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses ``allgather`` protocol which is considerably slower than
        ``allreduce`` protocol. It works only with flattened grads.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Pad only up to the next multiple of ``bucket_size``. A buffer whose length
    # is already a multiple needs no padding; an extra all-zero channel would
    # only be quantized and all-gathered for nothing. An empty buffer still gets
    # one full channel so that ``view(-1, bucket_size)`` stays well-defined.
    num_elements = len(tensor)
    pad_length = (-num_elements) % bucket_size if num_elements else bucket_size

    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, pad_length),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device
    )
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    # Filled in place by the first all_gather; read again when dequantizing.
    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        # Drop the zero padding again before averaging.
        return (
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        )

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)

View File

@ -0,0 +1,349 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type
import torch
import torch.distributed as dist
__all__ = ["JoinHook", "Joinable", "Join"]
class JoinHook:
    r"""
    A join hook: the two entry points used by the join context manager.

    The main hook is called repeatedly while there exists a non-joined
    process; the post-hook is called once all processes have joined.

    To implement a join hook for the generic join context manager, define a
    class that inherits from :class:`JoinHook` and override ``main_hook()``
    and ``post_hook()`` as appropriate. Both default to doing nothing.
    """

    def main_hook(self) -> None:
        r"""Shadow the collective communications of one training iteration.

        Called while there exists a non-joined process. One training
        iteration means one forward pass, backward pass, and optimizer step.
        """

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Run once after all processes have joined.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last to
                join; ``False`` otherwise.
        """
class Joinable(ABC):
    r"""
    Abstract base class for classes that can take part in a join context.

    A joinable class (inheriting from :class:`Joinable`) should implement
    :meth:`join_hook`, which returns a :class:`JoinHook` instance, in
    addition to :meth:`join_device` and :meth:`join_process_group`, which
    return device and process group information, respectively.
    """

    @abstractmethod
    def __init__(self) -> None:
        super().__init__()
        # Start with join-related logic disabled.
        self._join_config = _JoinConfig.construct_disabled_join_config()

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return the :class:`JoinHook` instance for this :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...

    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""Return the device on which the join context manager performs its collective communications."""
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""Return the process group for the collective communications needed by the join context manager itself."""
        ...
class _JoinConfig(NamedTuple):
r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
enable: bool
throw_on_early_termination: bool
is_first_joinable: bool
@staticmethod
def construct_disabled_join_config():
r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
e.g. if the caller is not in a join context manager.
"""
return _JoinConfig(
enable=False, throw_on_early_termination=False, is_first_joinable=False
)
class Join:
    r"""
    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

    These hooks should shadow the
    collective communications of non-joined processes to prevent hanging and
    erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
    for details about the hook definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own per-
        iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that the ``join_process_group`` of all
        participating :class:`Joinable` objects is the same. If there are
        multiple :class:`Joinable` objects, then the ``join_device`` of the
        first is used. The process group and device information is used for
        checking for non-joined processes and for notifying processes to
        throw an exception if ``throw_on_early_termination`` is enabled, both
        of which use an all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.

        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).

        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).

        kwargs: forwarded unchanged to the ``join_hook()`` of every joinable.

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>>     # All ranks reach here without hanging/erroring
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        self._join_hooks = [
            joinable.join_hook(**kwargs) for joinable in self._joinables
        ]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        # Only the first joinable is flagged as first.
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable,
            )
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError(
                    "Using join context manager with multiple process groups"
                )
            if device is None:
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ):
        r"""
        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.

        Does nothing if the context manager is disabled or if an exception is
        propagating out of the ``with`` body.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True``.
        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised

        all_procs_joined = False
        is_last_joiner = True

        i = 0
        WARN_THRESHOLD = 1000
        # Emit the skew warning at most once per exit. A local flag is used
        # instead of ``warnings.simplefilter("once")`` so that the process-wide
        # warning filters configured by the application are left untouched.
        warned_about_skew = False

        while not all_procs_joined:
            if i > WARN_THRESHOLD and not warned_about_skew:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    f"fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
                warned_about_skew = True
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    self._notify_procs_to_terminate()

                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()

                # Having shadowed at least one iteration, this rank was not
                # among the last to join.
                is_last_joiner = False
                i += 1

        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
        # This (joined) rank contributes zero to the all-reduce.
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
        """
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notifies the join context manager that the calling process has not yet joined.

        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
        (i.e. if one process has already joined) and throws an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and the all-reduced
                termination flag is non-zero.
        """
        assert hasattr(joinable, "_join_config"), (
            f"Check that the {type(joinable)} constructor calls the "
            "``Joinable`` constructor"
        )

        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work

View File

@ -0,0 +1,124 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils
__all__ = ["ModelAverager", "PeriodicModelAverager"]
class ModelAverager(ABC):
    r"""Base class for all model averagers.

    Keeps the process group to average over and a ``step`` counter that
    starts at ``0``.

    Args:
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
    """

    def __init__(self, process_group=None):
        if process_group is None:
            process_group = dist.group.WORLD
        self.process_group = process_group
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        """Average ``params``; concrete averagers must override this."""
        raise NotImplementedError
class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.DistributedDataParallel` (DDP)
    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
                      Usually the period should be greater than ``1`` to reduce the communication cost.
                      Otherwise, only DDP needs to be used.
        warmup_steps (int): The number of warm-up steps. During this stage,
                            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).cuda()
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> state = post_localSGD.PostLocalSGDState(
        >>>    process_group=None, subgroup=None, start_localSGD_iter=100
        >>> )
        >>> model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging every 4 steps.
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> for step in range(0, 200):
        >>>    optimizer.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    optimizer.step()
        >>>    # Will average model parameters globally every 4 steps. Thus,
        >>>    # inter-node communication only occurs every 4 iterations after
        >>>    # the initial ``warmup_steps`` period.
        >>>    averager.average_parameters(model.parameters())
    """

    def __init__(self, period, warmup_steps=0, process_group=None):
        super().__init__(process_group)

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps

        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        if period == 1:
            # Legal, but pointless: warn instead of failing.
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(
        self,
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
        ],
    ):
        """
        Averages parameters or parameter groups of an optimizer, then advances ``step``.

        Averaging only happens when ``step`` is no less than ``warmup_steps``
        and ``step - warmup_steps`` is divisible by ``period``. ``step`` is
        increased by 1 on every call, i.e. at each iteration in the training loop.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.
        """
        steps_since_warmup = self.step - self.warmup_steps
        if steps_since_warmup >= 0 and steps_since_warmup % self.period == 0:
            utils.average_parameters_or_parameter_groups(params, self.process_group)
        self.step += 1

View File

@ -0,0 +1,180 @@
# mypy: allow-untyped-defs
# Copyright 2022 Cruise LLC
import logging
import warnings
from collections import OrderedDict
from typing import Dict, Iterable, Union
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as utils
logger = logging.getLogger(__name__)
class HierarchicalModelAverager(averagers.ModelAverager):
r"""
Runs hierarchical model averaging (`hierarchical SGD <https://arxiv.org/pdf/2010.12998.pdf>`_).
Process groups of different sizes are organized in a hierarchy, and they average parameters
by using different periods concurrently after the warm-up stage.
This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
that supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_, which essentially only supports
a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
Similarly, the process groups within this class do not have such an intra-machine process
subgroup, which should be embedded by the post-local SGD communication hook instead.
Args:
period_group_size_dict: An ordered dict mapping keys of model averaging period to
process group size, used for initializing process groups of
different sizes in a hierarchy to average parameters concurrently.
Particularly, at each iteration, there will be at most a single
process group that runs averaging -- the period of such group should
have the largest period which the current step can be divided by.
For example, if the dict has three keys: 2, 4, and 8,
then this means totally three process groups will be created to
average parameters every 2, 4, and 8 iterations, respectively.
At the 4th iteration, only the second process group will run
averaging, because the first process group should be a
subset of the second process group, and no need to execute the first
process group redundantly.
On the other hand, the third process group can only be triggered
every 8 iterations, so it will not be triggered at the 4th iteration.
warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
process_group (ProcessGroup, optional): The overall process group containing all the processes that runs model averaging.
If ``None``, the default process group, which is created
by :func:`torch.distributed.init_process_group`, will be used.
(default: ``None``)
Example::
>>> # xdoctest: +SKIP('undefined rank')
>>> from collections import OrderedDict
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>> PostLocalSGDState,
>>> post_localSGD_hook,
>>> )
>>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
>>> import torch.nn as nn
>>>
>>> dist.init_process_group("nccl", rank=rank, world_size=16)
>>> torch.cuda.set_device(rank)
>>> module = nn.Linear(1, 1, bias=False).to(rank)
>>> model = nn.parallel.DistributedDataParallel(
>>> module, device_ids=[rank], output_device=rank
>>> )
>>> # Register a post-localSGD communication hook.
>>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
>>> subgroup, _ = dist.new_subgroups()
>>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Average parameters among each group of 8 processes every 4 iterations, and among all
>>> # the 16 processes every 16 iterations.
>>> averager = hierarchicalSGD.HierarchicalModelAverager(
>>> period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
>>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
>>> # After 100 steps, run model averaging at two levels.
>>> for step in range(0, 200):
>>> optimizer.zero_grad()
>>> loss = loss_fn(output, labels)
>>> loss.backward()
>>> optimizer.step()
>>> # Average parameters after ``optimizer.step()``.
>>> # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
>>> averager.average_parameters(model.parameters())
.. warning ::
The last group size in the dict must be the size of the provided ``process_group``,
which indicates model averaging at the highest level of the hierarchy.
If ``process_group`` is not provided, then the last group size should be equal to the world size.
.. warning ::
`HierarchicalModelAverager` is experimental and subject to change.
"""
def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
super().__init__(process_group)
if not period_group_size_dict:
raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
self._periods = list(period_group_size_dict.keys())
if self._periods[0] <= 0:
raise ValueError(
"The minimum period in arg ``period_group_size_dict`` must be a positive value."
)
elif self._periods[-1] == 1:
warnings.warn(
"When the maximum period in arg ``period_group_size_dict`` is 1, "
"no need to use model averaging because the communication cost "
"of all-reducing parameters will be no less than the cost of all-reducing gradients "
"by DistributedDataParallel in the backward pass. Therefore, only "
"DistributedDataParallel should be used for this case."
)
overall_group_size = dist.get_world_size(group=self.process_group)
if list(period_group_size_dict.values())[-1] != overall_group_size:
raise ValueError(
f"The last value in arg ``period_process_group_dict`` {list(period_group_size_dict.values())[-1]} "
f"must be equal to the size of arg ``process_group`` {overall_group_size}."
)
self.period_process_group_dict = OrderedDict()
logger.info("Model averaging hierarchy:")
for period, group_size in period_group_size_dict.items():
logger.info(
"\tEach group that has %s processes average parameters every %s iterations, "
"if no higher-level averaging.",
group_size,
period,
)
if group_size != overall_group_size:
self.period_process_group_dict[period], _ = dist.new_subgroups(
group_size=group_size, group=self.process_group
)
else:
self.period_process_group_dict[period] = self.process_group
if warmup_steps < 0:
raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
self.warmup_steps = warmup_steps
def _find_process_group(self):
"""
Return a process group as the value of an ``period_process_group_dict`` entry.
If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
then the returned process group is the one corresponding to the largest period,
since this process group will be used for averaging parameters at this ``step``.
Returns ``None`` if not found.
"""
for period in reversed(self._periods):
if self.step % period == 0:
return self.period_process_group_dict[period]
return None
def average_parameters(
    self,
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
    ],
):
    """
    Average parameters or optimizer parameter groups for the current ``step``.

    Nothing is averaged while ``step`` is below ``warmup_steps``. After that,
    ``_find_process_group`` is asked for the process group matching the
    current ``step``; if it returns one, the parameters are averaged within
    that group. ``step`` is incremented by 1 on every call, whether or not
    averaging took place.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.
    """
    warmed_up = self.step >= self.warmup_steps
    # The group lookup is only performed once the warm-up phase is over.
    target_group = self._find_process_group() if warmed_up else None
    if target_group is not None:
        utils.average_parameters_or_parameter_groups(params, target_group)
    self.step += 1

View File

@ -0,0 +1,89 @@
# mypy: allow-untyped-defs
# flake8: noqa C101
import itertools
from typing import Dict, Iterable, Iterator, Union
import torch
import torch.distributed as dist
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ProcessGroup
__all__ = [
"average_parameters",
"get_params_to_average",
"average_parameters_or_parameter_groups",
]
def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Average all the given parameters across the ranks of a process group.

    For allreduce efficiency, all the parameters are flattened into a single
    contiguous buffer, so extra memory of the same size as the given
    parameters is required.

    The call is a no-op when the current rank is not a member of the group,
    or when ``params`` yields no parameters.

    Args:
        params: The parameters to average. Consumed exactly once.
        process_group: The group to average over. ``None`` selects
            ``group.WORLD``.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    # The parameters are walked twice (pack, then unpack), so materialize
    # the iterator once instead of duplicating it.
    param_list = list(params)
    # ``torch.cat`` rejects an empty sequence; with nothing to average there
    # is nothing to communicate either.
    # NOTE(review): assumes every rank in the group passes the same parameter
    # set, so that all ranks skip the collective together -- confirm.
    if not param_list:
        return

    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in param_list])
    # Pre-divide so that the summing allreduce below yields the mean.
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in param_list:
        numel = p.numel()
        # ``type_as`` restores the parameter's own dtype; when the dtype
        # already matches, the new ``p.data`` is a view into ``flat_params``.
        p.data = flat_params[offset : offset + numel].view_as(p).type_as(p)
        offset += numel
def get_params_to_average(
    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
):
    """
    Collect the parameters that should take part in averaging.

    Only parameters whose ``grad`` is not ``None`` are kept; input order is
    preserved.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.

    Returns:
        A list of the parameters that have a gradient.

    Raises:
        NotImplementedError: If an element of ``params`` is neither a
            ``torch.nn.Parameter`` nor a ``dict``.
    """
    selected = []
    for entry in params:
        if isinstance(entry, torch.nn.Parameter):
            # model.parameters() input: the entry is itself the candidate.
            candidates = (entry,)
        elif isinstance(entry, dict):
            # optimizer.param_groups input: candidates live under "params".
            candidates = entry["params"]
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(entry)} is not supported"
            )
        selected.extend(
            candidate for candidate in candidates if candidate.grad is not None
        )
    return selected
def average_parameters_or_parameter_groups(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
):
    """
    Average the parameters of a model or the parameter groups of an optimizer.

    The input is first reduced by ``get_params_to_average`` and the resulting
    parameters are then handed to ``average_parameters`` together with
    ``process_group``.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.
        process_group: The process group passed through to ``average_parameters``.
    """
    averageable = get_params_to_average(params)
    average_parameters(iter(averageable), process_group)