I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,157 @@
# mypy: allow-untyped-defs
import logging
import pdb
import sys
import traceback
import typing
import torch
log = logging.getLogger(__name__)
def is_available() -> bool:
    """
    Report whether the distributed package is available in this build.

    When this returns ``False``, ``torch.distributed`` does not expose any
    other APIs. ``torch.distributed`` is currently available on Linux, MacOS
    and Windows. Set ``USE_DISTRIBUTED=1`` when building PyTorch from source to
    enable it; that is currently the default for Linux and Windows, while
    MacOS defaults to ``USE_DISTRIBUTED=0``.
    """
    # Availability is detected purely by the presence of the ``_c10d_init``
    # attribute on the C extension module.
    absent = object()
    return getattr(torch._C, "_c10d_init", absent) is not absent
# Eagerly initialize the c10d bindings when they are present; a falsy result
# from the initializer aborts the import with a RuntimeError.
if is_available() and not torch._C._c10d_init():
raise RuntimeError("Failed to initialize torch.distributed")
# Custom Runtime Errors thrown from the distributed package
# (re-exported here under public names from the private ``torch._C`` aliases).
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError
if is_available():
from torch._C._distributed_c10d import (
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_ControlCollectives,
_DEFAULT_FIRST_BUCKET_BYTES,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_StoreCollectives,
_test_python_store,
_verify_params_across_processes,
Backend as _Backend,
BuiltinCommHookType,
DebugLevel,
FileStore,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup as ProcessGroup,
Reducer,
set_debug_level,
set_debug_level_from_env,
Store,
TCPStore,
Work as _Work,
)
class _DistributedPdb(pdb.Pdb):
    """
    Supports using PDB from inside a multiprocessing child process.

    Usage:
    _DistributedPdb().set_trace()
    """

    def interaction(self, *args, **kwargs):
        """
        Run the regular ``pdb.Pdb.interaction`` with ``sys.stdin`` temporarily
        replaced by a freshly opened ``/dev/stdin``.

        The previous ``sys.stdin`` is always restored, and the temporary
        handle is closed once the interaction ends so that no file object is
        leaked per breakpoint hit.
        """
        _stdin = sys.stdin
        try:
            # The context manager closes the temporary handle; the original
            # implementation left it for the garbage collector.
            with open("/dev/stdin") as fresh_stdin:
                sys.stdin = fresh_stdin
                pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin
# Maps a hash of the calling stack to the number of times that call site has
# invoked ``breakpoint`` with ``skip > 0``.
_breakpoint_cache: typing.Dict[int, typing.Any] = {}


def breakpoint(rank: int = 0, skip: int = 0):
    """
    Set a breakpoint, but only on a single rank.  All other ranks will wait for you to be
    done with the breakpoint before continuing.

    Args:
        rank (int): Which rank to break on.  Default: ``0``
        skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``.
    """
    if skip > 0:
        # Key the counter on the current call stack so that each call site
        # keeps its own skip count. ``traceback.format_exc()`` describes the
        # exception being handled (a constant string when there is none), so
        # it made every call site share a single counter.
        key = hash("".join(traceback.format_stack()))
        counter = _breakpoint_cache.get(key, 0) + 1
        _breakpoint_cache[key] = counter
        if counter <= skip:
            log.warning("Skip the breakpoint, counter=%d", counter)
            return

    if get_rank() == rank:
        # Named ``debugger`` so the ``pdb`` module is not shadowed locally.
        debugger = _DistributedPdb()
        debugger.message(
            "\n!!! ATTENTION !!!\n\n"
            f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
        )
        debugger.set_trace()
    # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
    # and hit the (default) CPU/CUDA implementation of barrier.
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    torch._C._set_meta_in_tls_dispatch_include(False)
    try:
        barrier()
    finally:
        # Restore the TLS flag and drop the dispatch guard even if the
        # barrier raises.
        torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
        del guard
if sys.platform != "win32":
from torch._C._distributed_c10d import HashStore
from .device_mesh import DeviceMesh, init_device_mesh
# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
_all_gather_base,
_coalescing_manager,
_CoalescingManager,
_create_process_group_wrapper,
_get_process_group_name,
_rank_not_in_group,
_reduce_scatter_base,
get_node_local_rank,
)
from .remote_device import _remote_device
from .rendezvous import (
_create_store_from_options,
register_rendezvous_handler,
rendezvous,
)
# Executed once at import time, and only on the branch where the distributed
# package is available.
set_debug_level_from_env()
else:
# This stub is sufficient to get
# python test/test_public_bindings.py -k test_correct_module_names
# working even when USE_DISTRIBUTED=0. Feel free to add more
# stubs as necessary.
# We cannot define stubs directly because they confuse pyre
# Empty placeholder class used when the distributed package is unavailable.
class _ProcessGroupStub:
pass
# Publish the placeholder as ``torch.distributed.ProcessGroup`` by setting the
# attribute on the module object looked up in ``sys.modules``.
sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]

View File

@ -0,0 +1,38 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Any, Protocol, runtime_checkable
import torch
@runtime_checkable
class _Checkpointable(Protocol):  # noqa: PYI046
    """
    Interface for checkpointable objects.
    Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly.
    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.

    The default method bodies only raise ``NotImplementedError``; each error
    message names the exact dunder method that is missing.
    """

    def __create_write_items__(self, fqn: str, object: Any):
        """
        Return a list of WriteItems based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable.__create_write_items__ is not implemented"
        )

    def __create_chunk_list__(self):
        """
        Return a list of `ChunkStorageMetadata` based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable.__create_chunk_list__ is not implemented"
        )

    def __get_tensor_shard__(self, index) -> torch.Tensor:
        """
        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
        """
        raise NotImplementedError(
            "_Checkpointable.__get_tensor_shard__ is not implemented"
        )

View File

@ -0,0 +1,4 @@
from .checkpoint_activation import checkpoint
from .contract import _get_registry, contract
from .fully_shard import fully_shard
from .replicate import replicate

View File

@ -0,0 +1,126 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils.checkpoint import (
_checkpoint_without_reentrant_generator,
_DEFAULT_DETERMINISM_MODE,
)
from .contract import contract
@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
    r"""
    Temporarily switch off the checkpoint hooks on ``module``.

    While the context is active, ``checkpoint.state(module).enable_hook`` is
    ``False`` so that recomputation during backward does not re-enter the
    hooks. The previous value is restored on exit, even on error. If
    ``user_ctx`` is truthy it is entered around the whole region; otherwise a
    ``nullcontext`` is used.
    """
    outer_ctx = user_ctx if user_ctx else nullcontext()
    with outer_ctx:
        previous_flag = checkpoint.state(module).enable_hook
        checkpoint.state(module).enable_hook = False
        try:
            yield
        finally:
            checkpoint.state(module).enable_hook = previous_flag
@contract()
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
    r"""
    Composable activation checkpointing.

    In contrast to functional activation checkpointing APIs, no change to the
    model source code is needed; in contrast to ``nn.Module`` wrapper APIs,
    neither the model structure nor the fully-qualified names are modified.
    The checkpointing logic is installed as a forward pre-hook and a forward
    hook on ``module``, so it can be applied to any model or sub-module.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Accepted keyword arguments are ``use_reentrant`` (must be falsy),
    ``preserve_rng_state``, ``context_fn``, ``determinism_check`` and
    ``debug``; any other keyword raises ``ValueError``.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    if kwargs.pop("use_reentrant", False):
        raise NotImplementedError(
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead."
        )
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    user_context_fns = kwargs.pop("context_fn", None)
    determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
    debug = kwargs.pop("debug", False)

    # Whatever is left over was not recognised above.
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    def forward_pre_hook(
        mod: nn.Module, fwd_args: Tuple[Any, ...], fwd_kwargs: Dict[str, Any]
    ) -> None:
        if not checkpoint.state(mod).enable_hook:
            return

        def context_fns():
            # The second context always disables the hooks so that the
            # recomputation does not trigger them again.
            if user_context_fns is None:
                return nullcontext(), _no_hook(mod)
            ctx1, ctx2 = user_context_fns()
            return ctx1, _no_hook(mod, ctx2)

        generator = _checkpoint_without_reentrant_generator(
            mod,
            preserve_rng_state,
            context_fns,
            determinism_check,
            debug,
            *fwd_args,
            **fwd_kwargs,
        )
        checkpoint.state(mod)._ac_generator = generator
        # Advance the generator to its first yield before forward runs.
        next(checkpoint.state(mod)._ac_generator)

    def forward_hook(mod: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(mod).enable_hook:
            exhausted = False
            try:
                next(checkpoint.state(mod)._ac_generator)
            except StopIteration:
                exhausted = True
            if not exhausted:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        #  Ensure that we no longer hold on to the generator. always_call=True helps ensure we
        # clear this even in the case of exception in fwd pass.
        checkpoint.state(mod)._ac_generator = None

    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module

View File

@ -0,0 +1,224 @@
# mypy: allow-untyped-defs
import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.distributed.utils import _get_root_modules
def generate_state_key(string="__composable_api_state_key"):
    """Return ``string`` followed by an underscore and a freshly generated UUID4."""
    unique_suffix = uuid.uuid4()
    return "{}_{}".format(string, unique_suffix)
# Randomised keys under which the per-module state dict and the per-module
# registry dict are stored in ``module.__dict__``. They are generated once
# when this module is imported.
STATE_KEY = generate_state_key()
REGISTRY_KEY = generate_state_key()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
# Empty marker object; one instance per applied API is stored in a module's
# registry dict, keyed by the API function's name.
class RegistryItem:
pass
def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance or sequence
    of :class:`nn.Module` instances.

    The decorator verifies that the decorated function does not modify
    fully-qualified names (FQNs) for parameters, buffers, or modules. The
    decorated function can return different module instances than the input
    modules; the FQN invariant will be enforced following the input order.

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: class instantiated without arguments on every call of the
            decorated function; the instance is shared by all modules passed
            in that call.

    The decorated function raises ``AssertionError`` when the same API is
    applied twice to a module, when the stored state/registry is not a dict,
    or when the number of modules changes, and ``RuntimeError`` when FQNs or
    their order change.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(
            module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs
        ) -> Optional[nn.Module]:
            inp_module = module
            if isinstance(module, nn.Module):
                modules = [module]
            else:
                # If the user passes a sequence of modules, then we assume that
                # we only need to insert the state object on the root modules
                # (i.e. those without a parent) among the passed-in modules.
                modules = _get_root_modules(list(module))
            state = state_cls()  # shared across all modules
            registry_item = RegistryItem()  # shared across all modules

            # `func` is allowed to return different module instances than the
            # input modules as long as FQNs are preserved following the input
            # module order
            all_orig_named_params: List[Dict[str, nn.Parameter]] = []
            all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_orig_named_modules: List[Dict[str, nn.Module]] = []

            for module in modules:
                default_all_state: Dict[Callable, _State] = OrderedDict()
                default_registry: Dict[str, RegistryItem] = OrderedDict()
                all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                    STATE_KEY, default_all_state
                )
                if not isinstance(all_state, dict):
                    raise AssertionError(
                        f"Distributed composable API states corrupted: {all_state}"
                    )
                registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                    REGISTRY_KEY, default_registry
                )
                if not isinstance(registry, dict):
                    raise AssertionError(
                        f"Distributed composable API registry corrupted: {registry}"
                    )
                if func in all_state or func.__name__ in registry:
                    raise AssertionError(
                        "Each distinct composable distributed API can only be applied to a "
                        f"module once. {func.__name__} has already been applied to the "
                        f"following module:\n{module}"
                    )
                all_state.setdefault(func, state)
                registry.setdefault(func.__name__, registry_item)

                # Snapshot the FQNs before `func` runs so they can be compared
                # with the ones observed afterwards.
                all_orig_named_params.append(OrderedDict(module.named_parameters()))
                all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
                all_orig_named_modules.append(OrderedDict(module.named_modules()))

            updated = func(inp_module, *args, **kwargs)
            if updated is None:
                updated = inp_module
            if isinstance(updated, nn.Module):
                updated_modules = [updated]
            else:
                # NOTE(review): this branch derives the root modules from the
                # *input* rather than from `updated`; confirm that this is the
                # intended behavior when `func` returns a sequence.
                updated_modules = _get_root_modules(list(inp_module))

            all_new_named_params: List[Dict[str, nn.Parameter]] = []
            all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
            all_new_named_modules: List[Dict[str, nn.Module]] = []
            for module in updated_modules:
                all_new_named_params.append(OrderedDict(module.named_parameters()))
                all_new_named_buffers.append(OrderedDict(module.named_buffers()))
                all_new_named_modules.append(OrderedDict(module.named_modules()))

            num_orig_modules = len(all_orig_named_modules)
            num_new_modules = len(all_new_named_modules)
            if num_orig_modules != num_new_modules:
                # The first fragment previously lacked its trailing newline,
                # which glued "input modules" and "Inputs:" together.
                raise AssertionError(
                    f"{func.__name__} should return the same number of modules as input modules\n"
                    f"Inputs: {num_orig_modules} modules\n"
                    f"Outputs: {num_new_modules} modules"
                )

            def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
                """Raise ``RuntimeError`` unless both FQN lists are equal, order included."""
                if orig_fqns == new_fqns:
                    return

                orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
                orig_only = orig_fqn_set - new_fqn_set
                new_only = new_fqn_set - orig_fqn_set
                if orig_only or new_only:
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify FQNs.\n"
                        f"FQNs only in original: {orig_only}\n"
                        f"FQNs only in new: {new_only}"
                    )
                else:
                    # Same names, different order: report the full ordered
                    # lists. The set differences are empty in this branch and
                    # would carry no information.
                    raise RuntimeError(
                        f"{check_key}"
                        "Composable distributed API implementations cannot modify "
                        "the order of FQNs.\n"
                        f"Original FQNs: {orig_fqns}\n"
                        f"New FQNs: {new_fqns}"
                    )

            for orig_named_params, new_named_params in zip(
                all_orig_named_params, all_new_named_params
            ):
                check_fqn(
                    list(orig_named_params.keys()),
                    list(new_named_params.keys()),
                    "Checking parameters: ",
                )
            for orig_named_buffers, new_named_buffers in zip(
                all_orig_named_buffers, all_new_named_buffers
            ):
                check_fqn(
                    list(orig_named_buffers.keys()),
                    list(new_named_buffers.keys()),
                    "Checking buffers: ",
                )
            for orig_named_modules, new_named_modules in zip(
                all_orig_named_modules, all_new_named_modules
            ):
                check_fqn(
                    list(orig_named_modules.keys()),
                    list(new_named_modules.keys()),
                    "Checking modules: ",
                )

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state stored for ``func`` on ``module``, or ``None`` if there is none."""
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Look up the registry of composable APIs applied to ``module``.

    The registry is an ``OrderedDict`` indexed by API name. ``None`` is
    returned when no API has been applied to the module.
    """
    try:
        return getattr(module, REGISTRY_KEY)
    except AttributeError:
        return None

View File

@ -0,0 +1,2 @@
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method

View File

@ -0,0 +1,80 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    Configuration of FSDP's mixed precision.

    In contrast to autocast, mixed precision is applied per module rather
    than per op: low-precision activations are saved for backward, and
    high-to-low-precision casts happen only at module boundaries.

    Module-level mixed precision suits FSDP because the high-precision
    sharded parameters stay in memory anyway, so no extra memory is needed to
    hold a high-precision parameter copy for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): dtype of the unsharded parameter,
            and therefore of forward/backward computation and of the parameter
            all-gather. With ``None`` the unsharded parameter keeps the
            original dtype. The optimizer step uses the sharded parameter in
            the original dtype. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): dtype used for gradient
            reduction (i.e. reduce-scatter or all-reduce). If it is ``None``
            while ``param_dtype`` is not, the reduction uses the compute
            dtype. It allows running gradient reduction in full precision
            while computing in low precision. If gradient reduction is also
            disabled via :meth:`set_requires_gradient_sync`, FSDP accumulates
            gradients using ``reduce_dtype``. (Default: ``None``)
        output_dtype (Optional[torch.dtype]): dtype to which floating-point
            forward outputs are cast. Useful when different modules have
            different mixed precision policies. (Default: ``None``)
        cast_forward_inputs (bool): whether FSDP should cast the forward's
            floating-point input tensors to ``param_dtype``.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

    def __post_init__(self):
        # Gradients are computed in `param_dtype`, so a matching
        # `reduce_dtype` requires no extra cast and is normalised to `None`.
        dtypes_match = self.param_dtype == self.reduce_dtype
        if not dtypes_match:
            return
        # The dataclass is frozen, hence the write through `object`.
        object.__setattr__(self, "reduce_dtype", None)
# Field-less dataclass that serves as the common base of the offload policies
# declared in this file.
@dataclass
class OffloadPolicy:
"""This base class represents the policy of no offloading."""
# Offload policy with a single option, ``pin_memory``, which defaults to True.
@dataclass
class CPUOffloadPolicy(OffloadPolicy):
"""
This offload policy offloads parameters, gradients, and optimizer states to
CPU. Sharded parameters are copied host-to-device before all-gather. The
all-gathered parameters are freed according to ``reshard_after_forward``.
Sharded gradients are copied device-to-host in backward, and the optimizer
step runs on CPU with CPU optimizer states.
Attributes:
pin_memory (bool): Whether to pin sharded parameter and gradient
memory. Pinning memory allows H2D/D2H copying without blocking the
CPU and in turn, overlap with compute, but pinned memory cannot be
used by other processes. Set this to ``False`` if you have
insufficient CPU memory. (Default: ``True``)
"""
pin_memory: bool = True

View File

@ -0,0 +1,477 @@
# mypy: allow-untyped-decorators
from typing import cast, List, NamedTuple, Optional, Tuple, Union
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,
_to_dtype_if_needed,
)
from ._fsdp_param import FSDPParam, ShardedState
# Everything `foreach_all_gather` hands over to `foreach_all_gather_copy_out`.
class AllGatherResult(NamedTuple):
# Flat output buffer of the all-gather
all_gather_output: torch.Tensor
# Event recorded on the all-gather stream after the collective was issued
all_gather_event: Optional[torch.cuda.Event]
# Return value of `dist.all_gather_into_tensor`; the copy-out step waits on
# it when it is a `Work` instance
all_gather_work: Optional[dist.distributed_c10d.Work]
# For each parameter, the all-gather input dtype for each input
param_all_gather_input_dtypes: List[List[torch.dtype]]
# For each parameter, the all-gather input numel for each input
param_all_gather_input_numels: List[List[int]]
# 1D flattened version of `param_all_gather_input_numels` saved to avoid
# CPU overhead from recomputing
all_gather_input_split_sizes: List[int]
# Open the "fsdp" op namespace as a FRAGMENT library; the ops defined on it
# are invoked in this file as `torch.ops.fsdp.<name>`.
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
# Schema of `all_gather_copy_in`: it takes the flattened all-gather inputs and
# their split sizes and returns a pair of tensors. The per-backend
# implementations are registered separately with `torch.library.impl`.
lib.define(
"""
all_gather_copy_in(
Tensor[] all_gather_inputs,
SymInt[] inp_split_sizes,
SymInt all_gather_input_numel,
SymInt world_size,
SymInt rank,
ScalarType dtype,
Device device
) -> (Tensor, Tensor)
"""
)
@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Meta implementation of ``fsdp::all_gather_copy_in``.

    Allocates a 1D output of ``all_gather_input_numel * world_size`` elements
    on the ``"meta"`` device and returns ``(input_view, output)``, where
    ``input_view`` is the slice of the output belonging to ``rank``. No data
    is copied; ``all_gather_inputs``, ``inp_split_sizes`` and ``device`` are
    not used.
    """
    output_numel = all_gather_input_numel * world_size
    gathered = torch.empty((output_numel,), dtype=dtype, device="meta")
    rank_offset = all_gather_input_numel * rank
    rank_slice = gathered.narrow(0, rank_offset, all_gather_input_numel)
    return rank_slice, gathered
@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    CUDA/CPU implementation of ``fsdp::all_gather_copy_in``.

    Allocates a 1D output of ``all_gather_input_numel * world_size`` elements
    on ``device``, takes the slice belonging to ``rank`` as the all-gather
    input, and copies ``all_gather_inputs`` into that slice (split according
    to ``inp_split_sizes``) with a single foreach copy under ``no_grad``.

    Returns:
        ``(all_gather_input, all_gather_output)`` where the first is a view
        into the second.
    """
    output_numel = all_gather_input_numel * world_size
    gathered = torch.empty((output_numel,), dtype=dtype, device=device)
    rank_offset = all_gather_input_numel * rank
    rank_slice = gathered.narrow(0, rank_offset, all_gather_input_numel)
    # One destination view per input tensor, laid out back to back.
    destinations = torch.split(rank_slice, inp_split_sizes)
    with torch.no_grad():
        torch._foreach_copy_(destinations, all_gather_inputs)
    return rank_slice, gathered
# Schema of `fsdp::split_with_sizes_copy`: `out` is a keyword-only list of
# tensors that is written in place, and the op returns nothing.
lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)
# One Python implementation is registered for the Meta, CUDA and CPU keys; it
# forwards its arguments unchanged to `torch.split_with_sizes_copy`.
@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
dim: int,
out: List[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)
# Schema of `fsdp::chunk_cat`: `out` is a keyword-only tensor that is written
# in place, and the op returns nothing.
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)
# One Python implementation is registered for the Meta, CUDA and CPU keys; it
# forwards its arguments unchanged to `torch._chunk_cat`.
@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)
# Issue one all-gather covering all of `fsdp_params`.
# The per-parameter inputs are flattened and copied into a single buffer via
# `torch.ops.fsdp.all_gather_copy_in`, the collective is launched with
# `dist.all_gather_into_tensor`, and the output buffer is returned together
# with the event, the work handle and the metadata needed by
# `foreach_all_gather_copy_out`.
# NOTE(review): the return annotation is Optional, but every path visible
# here returns an `AllGatherResult`.
@torch.no_grad()
def foreach_all_gather(
fsdp_params: List[FSDPParam],
group: dist.ProcessGroup,
async_op: bool,
all_gather_copy_in_stream: torch.cuda.Stream,
all_gather_stream: torch.cuda.Stream,
device: torch.device,
) -> Optional[AllGatherResult]:
world_size, rank = group.size(), group.rank()
# Copy-in work is issued under the dedicated copy-in stream.
with torch.cuda.stream(all_gather_copy_in_stream):
param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
(
param_all_gather_input_dtypes,
param_all_gather_input_numels,
dtype,
) = _get_all_gather_input_metadatas(param_all_gather_inputs)
# `_get_all_gather_input_metadatas` falls back to uint8 when the input dtypes
# differ; in that case every input is reinterpreted as uint8 before flattening.
if dtype == torch.uint8:
all_gather_inputs = [
t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
]
else:
all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts]
inp_split_sizes = [t.numel() for t in all_gather_inputs]
all_gather_input_numel = sum(inp_split_sizes)
all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
all_gather_inputs,
inp_split_sizes,
all_gather_input_numel,
world_size,
rank,
dtype,
device,
)
# Drop the local reference to the per-parameter inputs once they are copied in.
del param_all_gather_inputs
# The all-gather stream must wait for the copy-in before the collective runs.
all_gather_stream.wait_stream(all_gather_copy_in_stream)
with torch.cuda.stream(all_gather_stream):
all_gather_work = dist.all_gather_into_tensor(
output_tensor=all_gather_output,
input_tensor=all_gather_input,
group=group,
async_op=async_op,
)
all_gather_event = all_gather_stream.record_event()
# `inp_split_sizes` is stored as `all_gather_input_split_sizes` in the result.
return AllGatherResult(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
inp_split_sizes,
)
@torch.no_grad()
def _get_param_all_gather_inputs(
    fsdp_params: List[FSDPParam],
) -> List[List[torch.Tensor]]:
    """
    Collect the all-gather inputs of every parameter in ``fsdp_params``.

    Under compiled autograd each parameter's ``all_gather_inputs`` is used
    directly. Otherwise, parameters eligible for the fast path are cast/copied
    together with one foreach copy into a single flat buffer, and the
    remaining parameters fall back to ``all_gather_inputs``.

    Returns:
        One list of input tensors per parameter, in the order of
        ``fsdp_params``.
    """
    if ca.compiled_autograd_enabled:
        return [param.all_gather_inputs for param in fsdp_params]

    # Intentionally try to run a fast-path that bypasses abstractions for the
    # common FSDP case of bf16/fp32 mixed precision in order to use foreach
    # copy for lower CPU overhead and more efficient copying in eager
    def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
        if fsdp_param.param_dtype is None:
            return False
        if fsdp_param.offload_to_cpu:
            return False
        return not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")

    gathered_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
    fast_indices: List[int] = []
    fast_sources: List[torch.Tensor] = []
    fast_numels: List[int] = []

    # 1st pass: for foreach-copy parameters, get inputs and metadata for the
    # foreach copy, and for the others, actually get their all-gather inputs
    for index, fsdp_param in enumerate(fsdp_params):
        if not use_foreach_copy(fsdp_param):
            gathered_inputs[index] = fsdp_param.all_gather_inputs
            continue
        fast_indices.append(index)
        if fsdp_param.sharded_state == ShardedState.SHARDED:
            source = fsdp_param._sharded_param_data
        else:
            source = cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
        fast_sources.append(source)
        fast_numels.append(source.numel())

    # 2nd pass: use foreach copy to compute the remaining all-gather inputs
    if fast_sources:
        # dtype and device are taken from the first fast-path parameter.
        first_fast_param = fsdp_params[fast_indices[0]]
        param_dtype = first_fast_param.param_dtype
        device = first_fast_param.device
        flat_buffer = torch.empty(
            (sum(fast_numels),), device=device, dtype=param_dtype
        )
        pieces = torch.split(flat_buffer, fast_numels)
        torch._foreach_copy_(pieces, fast_sources)
        for index, piece in zip(fast_indices, pieces):
            gathered_inputs[index] = [piece]

    return gathered_inputs
# Consume an `AllGatherResult`: wait for the all-gather, make sure every
# parameter has its all-gather output tensors initialised and allocated, then
# split the flat all-gather output into those tensors with
# `torch.ops.fsdp.split_with_sizes_copy`.
@torch.no_grad()
def foreach_all_gather_copy_out(
all_gather_result: AllGatherResult,
fsdp_params: List[FSDPParam],
group: dist.ProcessGroup,
) -> None:
(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_input_split_sizes,
) = all_gather_result
# Both waits are conditional: the event makes the current stream wait, the
# work handle is waited on directly.
if all_gather_event is not None: # sync op
torch.cuda.current_stream().wait_event(all_gather_event)
if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op
all_gather_work.wait()
world_size, device = group.size(), all_gather_output.device
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
):
if ca.compiled_autograd_enabled:
fsdp_param.init_all_gather_outputs(
all_gather_input_numels,
all_gather_input_dtypes,
world_size,
device,
# NOTE: Under compile, make sure we always recreate all_gather_outputs
# per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
force_recreate=True,
)
else:
fsdp_param.init_all_gather_outputs(
all_gather_input_numels, all_gather_input_dtypes, world_size, device
) # no-op after 1st call
fsdp_param.alloc_all_gather_outputs()
# View the flat output as one row per rank so the split runs along dim=1.
all_gather_output = all_gather_output.view(world_size, -1)
# Lazily iterate over every output tensor of every parameter, in order.
gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
# When the gathered buffer is uint8, the destinations are viewed as uint8 too
# so that source and destination dtypes agree.
if all_gather_output.dtype == torch.uint8:
out = [t.view(world_size, -1).view(torch.uint8) for t in gen]
else:
out = [t.view(world_size, -1) for t in gen]
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
# Reduce the unsharded gradients of `fsdp_params` into sharded gradients.
# Steps visible in the body: check that all gradients share one dtype, copy
# them into one flat reduce-scatter input, reduce-scatter it, optionally
# all-reduce the result when `all_reduce_group` is given (HSDP), then view
# the reduced buffer per parameter and either assign or accumulate it as the
# parameter's sharded gradient.
# Returns a 4-tuple: the reduce-scatter input, the reduce-scatter event, the
# post-reduce event, and the partial reduce output (only non-None on the HSDP
# early-return path where `all_reduce_grads` is false).
@torch.no_grad()
def foreach_reduce(
fsdp_params: List[FSDPParam],
unsharded_grads: List[torch.Tensor],
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_stream: torch.cuda.Stream,
orig_dtype: torch.dtype,
reduce_dtype: Optional[torch.dtype],
device: torch.device,
reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]],
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
all_reduce_stream: torch.cuda.Stream,
all_reduce_grads: bool,
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
"""
``unsharded_grads`` owns the references to the gradients computed by
autograd, so clearing the list frees the gradients.
"""
grad_dtypes = {grad.dtype for grad in unsharded_grads}
if len(grad_dtypes) != 1:
# Check this at runtime since it could be a real runtime error if e.g.
# fp8 weights do not produce the correct higher precision gradients
_raise_assert_with_print(
f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
)
grad_dtype = unsharded_grads[0].dtype
# Fall back to the gradient dtype when no reduce dtype was requested.
reduce_dtype = reduce_dtype or grad_dtype
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
reduce_scatter_group, all_reduce_group, reduce_dtype
)
world_size = reduce_scatter_group.size()
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
reduce_scatter_input = torch.empty(
(reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
)
foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
current_stream = torch.cuda.current_stream()
# Only after the copy-in finishes can we free the gradients
unsharded_grads.clear()
reduce_scatter_stream.wait_stream(current_stream)
with torch.cuda.stream(reduce_scatter_stream):
reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
_div_if_needed(reduce_scatter_input, predivide_factor)
# Default reduce op: AVG when no pre-division factor was computed, else SUM.
if reduce_scatter_reduce_op is None:
if predivide_factor is None:
reduce_scatter_reduce_op = ReduceOp.AVG
else:
reduce_scatter_reduce_op = ReduceOp.SUM
dist.reduce_scatter_tensor(
output=reduce_output,
input=reduce_scatter_input,
group=reduce_scatter_group,
op=reduce_scatter_reduce_op,
)
reduce_scatter_event = reduce_scatter_stream.record_event()
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP
# Accumulations must run in the reduce-scatter stream
if not all_reduce_grads:
if partial_reduce_output is not None:
partial_reduce_output += reduce_output
else:
partial_reduce_output = reduce_output
# Early return: no all-reduce and no sharded-gradient update on this path.
return (
reduce_scatter_input,
reduce_scatter_event,
post_reduce_stream.record_event(),
partial_reduce_output,
)
if partial_reduce_output is not None:
reduce_output += partial_reduce_output
# From here on the post-reduce work is associated with the all-reduce stream.
post_reduce_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with torch.cuda.stream(all_reduce_stream):
dist.all_reduce(
reduce_output,
group=all_reduce_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
with torch.cuda.stream(post_reduce_stream):
_div_if_needed(reduce_output, postdivide_factor)
reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
# View out and accumulate sharded gradients
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, fsdp_param in zip(
padded_unsharded_sizes, fsdp_params
):
new_sharded_grad = torch.as_strided(
reduce_output,
size=fsdp_param.sharded_size,
stride=fsdp_param.contiguous_sharded_stride,
storage_offset=flat_grad_offset,
)
to_accumulate_grad = fsdp_param.sharded_param.grad is not None
if fsdp_param.offload_to_cpu:
# Only overlap the D2H copy (copying to pinned memory) if not
# accumulating gradients since the CPU add kernel depends on
# the copy result and we cannot run the add as a callback
non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
# Since the GPU sharded gradient is allocated in the RS stream,
# we can free it here by not keeping a ref without waiting for
# the D2H copy since future RS-stream ops run after the copy
new_sharded_grad = new_sharded_grad.to(
torch.device("cpu"), non_blocking=non_blocking
)
if non_blocking:
# Record an event on which to block the CPU thread to
# ensure that the D2H copy finishes before the optimizer
# NOTE(review): `post_reduce_stream` is `all_reduce_stream` in the HSDP
# path, yet the event is recorded on `reduce_scatter_stream`; the copy
# appears to be issued under `post_reduce_stream` - confirm this is intended.
fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
if to_accumulate_grad:
assert isinstance(fsdp_param.sharded_param.grad, DTensor)
fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
else:
new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
new_sharded_grad
)
fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
# Outside compiled autograd, invoke the parameter's post-accumulate-grad
# hooks (if any) by hand.
if not ca.compiled_autograd_enabled:
for hook in (
getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
or {}
).values():
hook(fsdp_param.sharded_param)
# Advance the offset by this parameter's padded per-rank element count.
padded_sharded_numel = padded_unsharded_size.numel() // world_size
flat_grad_offset += padded_sharded_numel
post_reduce_event = post_reduce_stream.record_event()
# The RS output is allocated in the RS stream and used in the default
# stream (for optimizer). To ensure its memory is not reused for later
# RSs, we do not need extra synchronization since the sharded parameters
# hold refs through the end of backward.
return reduce_scatter_input, reduce_scatter_event, post_reduce_event, None
def foreach_reduce_scatter_copy_in(
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_input: torch.Tensor,
    world_size: int,
) -> None:
    """
    Copy the unsharded gradients into the reduce-scatter input buffer.

    The buffer is viewed as ``(world_size, -1)`` and passed as ``out=`` to the
    ``fsdp.chunk_cat`` op, which is asked for ``world_size`` chunks along dim 0.
    """
    per_rank_rows = reduce_scatter_input.view(world_size, -1)
    torch.ops.fsdp.chunk_cat(
        unsharded_grads,
        dim=0,
        num_chunks=world_size,
        out=per_rank_rows,
    )
def _get_all_gather_input_metadatas(
param_all_gather_inputs: List[List[torch.Tensor]],
) -> Tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
param_all_gather_input_dtypes: List[List[torch.dtype]] = []
param_all_gather_input_numels: List[List[int]] = []
all_gather_dtype = param_all_gather_inputs[0][0].dtype
for all_gather_inputs in param_all_gather_inputs:
input_dtypes: List[torch.dtype] = []
input_numels: List[int] = []
for all_gather_input in all_gather_inputs:
if all_gather_input.dtype != all_gather_dtype:
all_gather_dtype = torch.uint8
input_dtypes.append(all_gather_input.dtype)
input_numels.append(all_gather_input.numel())
param_all_gather_input_dtypes.append(input_dtypes)
param_all_gather_input_numels.append(input_numels)
return (
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_dtype,
)
def _get_gradient_divide_factors(
reduce_scatter_group: dist.ProcessGroup,
all_reduce_group: Optional[dist.ProcessGroup],
reduce_dtype: torch.dtype,
) -> Union[Tuple[None, None], Tuple[float, float]]:
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
# use NCCL's built-in division to avoid separate div kernels
if reduce_dtype in (torch.float32, torch.bfloat16):
return None, None
data_parallel_size = reduce_scatter_group.size()
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
# overflow/underflow. For N data parallel workers, each worker computes
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
factor: int = 1
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
factor *= 2
factor = float(factor)
return (factor, data_parallel_size / factor)
def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
if div_factor is not None and div_factor > 1:
tensor.div_(div_factor)

View File

@ -0,0 +1,152 @@
# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@dataclass
class DataParallelMeshInfo:
    """
    Base mesh description for data parallelism: a mesh plus the optional
    indices of its shard and replicate dims. At least one index must be set.
    """

    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        mesh_dims = (self.shard_mesh_dim, self.replicate_mesh_dim)
        if all(mesh_dim is None for mesh_dim in mesh_dims):
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )
@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh info that requires a shard dim and caches that dim's size, process
    group, and this rank's index within the group.
    """

    def __post_init__(self):
        super().__post_init__()
        shard_dim = self.shard_mesh_dim
        if shard_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        mesh = self.mesh
        self.shard_mesh_size: int = mesh.size(shard_dim)
        self.shard_process_group = mesh.get_group(shard_dim)
        self.shard_mesh_rank: int = self.shard_process_group.rank()
@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """
    Mesh info that requires a replicate dim and caches that dim's size, process
    group, and this rank's index within the group.
    """

    def __post_init__(self):
        super().__post_init__()
        replicate_dim = self.replicate_mesh_dim
        if replicate_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        mesh = self.mesh
        self.replicate_mesh_size: int = mesh.size(replicate_dim)
        self.replicate_process_group = mesh.get_group(replicate_dim)
        self.replicate_mesh_rank: int = self.replicate_process_group.rank()
@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    # Combines the shard-dim and replicate-dim mesh info; both dims must be
    # non-None because each parent's `__post_init__` checks its own dim.
    def __post_init__(self):
        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
        # (cooperative `super()` chain following the MRO)
        super().__post_init__()
class TrainingState(Enum):
    """Describes the training state of one FSDP state / parameter group."""

    # NOTE: member values come from `auto()`, so they follow declaration order.

    # Transition to forward starting pre-forward until post-forward
    FORWARD = auto()
    # Transition to pre-backward when unsharding in backward
    PRE_BACKWARD = auto()
    # Transition to post-backward when resharding and reducing gradients
    POST_BACKWARD = auto()
    # Idle before/after forward or before pre-backward/after post-backward
    IDLE = auto()
def _raise_assert_with_print(*args: Any, **kwargs: Any):
print(f"[Rank {dist.get_rank()}] ", end="")
print(*args, **kwargs)
traceback.print_stack()
raise AssertionError(*args, **kwargs)
def _is_composable_with_fsdp(module: nn.Module) -> bool:
registry = _get_registry(module)
if registry is None:
return True
# Registry keys by function name
return "replicate" not in registry
def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:])
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> List[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))
return chunks
def _get_dim0_chunked_size(
chunk: torch.Tensor, unchunked_size: torch.Size
) -> torch.Size:
if chunk.numel() > 0:
return chunk.size()
# For 0 numel, we need to preserve trailing dims for DTensor APIs
return cast(torch.Size, torch.Size([0]) + unchunked_size[1:])
def _from_local_no_grad(
    local_tensor: torch.Tensor,
    sharding_spec: DTensorSpec,
) -> DTensor:
    """
    This method is similar to ``DTensor.from_local()`` except that in eager mode
    it avoids some CPU overhead by avoiding default args and not being differentiable.
    """
    if ca.compiled_autograd_enabled:
        # Under compiled autograd, go through the public constructor
        return DTensor.from_local(
            local_tensor,
            sharding_spec.mesh,
            sharding_spec.placements,
            shape=sharding_spec.shape,
            stride=sharding_spec.stride,
        )
    # Use the local tensor directly instead of constructing a new tensor
    # variable, e.g. with `view_as()`, since this is not differentiable
    return DTensor(
        local_tensor,
        sharding_spec,
        requires_grad=local_tensor.requires_grad,
    )
def _to_dtype_if_needed(
tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
return tensor
def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
if (
not isinstance(x, torch.Tensor)
or not torch.is_floating_point(x)
or x.dtype == dtype
):
return x
return x.to(dtype)

View File

@ -0,0 +1,168 @@
import itertools
from typing import List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state
def _get_post_forward_mesh_info(
reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
shard_mesh_size = mesh_info.shard_mesh_size
if not isinstance(reshard_after_forward, (bool, int)):
raise ValueError(
"reshard_after_forward should be a bool or an int representing the "
f"group size to reshard to, not {reshard_after_forward}"
)
# NOTE: `isinstance(False, int)` returns `True`.
if not isinstance(reshard_after_forward, bool) and isinstance(
reshard_after_forward, int
):
if (
reshard_after_forward < 1
or reshard_after_forward > shard_mesh_size
or shard_mesh_size % reshard_after_forward != 0
):
raise ValueError(
"If passing reshard_after_forward as an int, it should be a "
f"factor of {shard_mesh_size}, not {reshard_after_forward}"
)
elif reshard_after_forward == 1:
reshard_after_forward = False
elif reshard_after_forward == shard_mesh_size:
reshard_after_forward = True
post_forward_mesh_info = None
if reshard_after_forward is True:
post_forward_mesh_info = mesh_info
elif reshard_after_forward is not False: # int case
# For HSDP, we can flatten the two replicate dims into the 0th dim
post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
post_forward_mesh = DeviceMesh(
mesh_info.mesh.device_type, post_forward_mesh_tensor
)
post_forward_mesh_info = HSDPMeshInfo(
post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
)
return post_forward_mesh_info
def _init_default_fully_shard_mesh() -> DeviceMesh:
    """Default to global CUDA mesh if possible else global CPU mesh."""
    c10d = dist.distributed_c10d
    # Initialize the default process group on demand
    if not c10d.is_initialized():
        c10d.init_process_group()
    world_size = c10d._get_default_group().size()
    if torch.cuda.is_available():
        device_type = "cuda"
    else:
        device_type = "cpu"
    return init_device_mesh(device_type, mesh_shape=(world_size,))
def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
if mesh.device_type == "cpu":
return torch.device("cpu")
device_handle = _get_device_handle(mesh.device_type)
return torch.device(mesh.device_type, device_handle.current_device())
def _get_managed_modules(root_modules: Tuple[nn.Module, ...]) -> List[nn.Module]:
    """
    Collect, in DFS post-order, the modules reachable from ``root_modules``
    that should be managed, skipping non-composable modules and non-root
    modules that already have ``fully_shard`` applied.
    """
    managed: List[nn.Module] = []
    roots = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
    seen: Set[nn.Module] = set()

    def _collect(module: nn.Module) -> None:
        """
        Runs a DFS to collect managed modules, not recursing into modules with
        a non-composable API or ``fully_shard`` already applied.
        """
        if not _is_composable_with_fsdp(module):
            return
        is_nested_fully_shard = (
            module not in roots and _get_module_fsdp_state(module) is not None
        )
        if is_nested_fully_shard:
            return  # nested `fully_shard` module
        seen.add(module)
        for child in module.children():
            if child in seen:
                continue
            _collect(child)
        # Append after the children so that the result is in post-order
        managed.append(module)

    for root in root_modules:
        _collect(root)
    return managed
def _verify_managed_param(name: str, param: nn.Parameter) -> None:
"""
Verify if the parameter is accepted by fully_shard. The only restriction now
is that the parameter cannot be a scalar tensor (param.numel == 0) since we
need at least one dim to shard.
"""
if len(param.shape) == 0:
raise ValueError(
"fully_shard doesn't support salar parameters. "
f"Change {name} to a 1D tensor with numel equal to 1."
)
def _get_managed_states(
    modules: List[nn.Module],
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
    """
    Gather the directly-owned parameters and buffers of ``modules`` in
    iteration order, each appearing once, verifying every new parameter.
    """
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    # Track visited parameters/buffers to avoid visiting shared parameters and
    # buffers multiple times
    seen_params: Set[nn.Parameter] = set()
    seen_buffers: Set[torch.Tensor] = set()
    for module in modules:
        for name, param in module.named_parameters(recurse=False):
            if param in seen_params:
                continue
            _verify_managed_param(name, param)
            params.append(param)
            seen_params.add(param)
        for buffer in module.buffers(recurse=False):
            if buffer in seen_buffers:
                continue
            buffers.append(buffer)
            seen_buffers.add(buffer)
    return params, buffers
def _move_states_to_device(
params: List[nn.Parameter],
buffers: List[torch.Tensor],
device: torch.device,
) -> None:
"""
We have FSDP move states to device for simpler and faster initialization
since FSDP almost always uses CUDA for training. We move parameters/buffers
rather than modules since modules to support ignoring parameters/buffers in
the future.
"""
# Follow the logic in `nn.Module._apply`
for tensor in itertools.chain(params, buffers):
if tensor.device == device or tensor.device.type == "meta":
# Keep meta-device tensors on meta device for deferred init
continue
if isinstance(tensor, DTensor):
if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
raise ValueError(
"Requires DTensor to have mesh of the same type as the FSDP mesh "
f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
)
raise AssertionError(
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
)
tensor_ = tensor
if is_traceable_wrapper_subclass(tensor_):
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
tensor_on_device = nn.Parameter(tensor.to(device))
torch.utils.swap_tensors(tensor, tensor_on_device)
else:
tensor.data = tensor.to(device)

View File

@ -0,0 +1,754 @@
# mypy: allow-untyped-defs
import itertools
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, cast, List, Optional, Sequence, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import _mesh_resources
from torch.distributed.tensor.placement_types import _StridedShard, Placement
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
_chunk_with_empty,
_from_local_no_grad,
_get_dim0_chunked_size,
_raise_assert_with_print,
_to_dtype_if_needed,
FSDPMeshInfo,
HSDPMeshInfo,
)
"""
[Note: FSDP tensors]
FSDP considers the following tensors:
- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
on the module when applying FSDP
- Sharded parameter: sharding the original parameter on dim-0 as a DTensor
over the main mesh
- All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather,
derived from the sharded parameter
- All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from
all-gathering the all-gather inputs
- Unsharded parameter: parameter used for forward/backward computation, derived
from the all-gather output; autograd leaf
We define these tensors to describe the general framework that can accommodate
extensions, where:
- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
- unsharded-parameter = post-all-gather-transform(all-gather-outputs)
For the default ``torch.Tensor`` case, there is only one all-gather input, and
it shares the same underlying tensor data as the sharded parameter, meaning
that they can be thought of as the same tensors. The same applies for the
all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions,
these equivalences may no longer hold due to the pre/post-all-gather
transforms, and some may have multiple all-gather inputs/outputs (e.g.
quantized data and scales).
[Note: FSDP and autograd]
FSDP dynamically frees and allocates the unsharded parameter. Since autograd
can pack a reference to it or a view to save for backward, we use storage
resizing to implement the freeing/allocation since that preserves the aliasing.
This implies that we construct the unsharded parameter object once and write to
it in-place thereafter. For the default ``torch.Tensor`` original parameter
case, the all-gather output and unsharded parameter share the same
data, so we use storage resizing on the all-gather output.
"""
# Extend the existing "fsdp" op namespace (FRAGMENT) with an in-place `set_` op
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
# Schema: `tensor` is annotated as mutated (`Tensor(a!)`); the op returns nothing
lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")


# The same Python implementation is registered for the Meta, CUDA, and CPU keys
@torch.library.impl(lib, "set_", "Meta")
@torch.library.impl(lib, "set_", "CUDA")
@torch.library.impl(lib, "set_", "CPU")
def set_(tensor, data):
    # Delegates to `Tensor.set_`, mutating `tensor` in place
    tensor.set_(data)
"""
[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op
(i.e. they show up as a mutation op in the middle of the AOT joint graph).
Reason:
Traceable FSDP2 compiled autograd BWD graph have the following traits:
(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
So this doesn't work at all for Traceable FSDP2.
The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
This avoids the problem above, because from AOTAutograd point-of-view there are no mutations
that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process.
Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner
(i.e. effectively "re-functionalizing" those ops). Put it in another way, we avoid functionalization for those two ops just to
make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
This requires a custom FX pass but we believe it's not hard to write and maintain.
Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors?
A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use,
so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input.
This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
nn.Parameter in order to see the result of .set_.
"""
@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
    # Functionalize-key implementation of `fsdp.set_`: instead of
    # functionalizing the mutation, it unwraps both functional tensors and
    # re-dispatches the op on the inner tensors with the Functionalize key
    # excluded.
    torch._sync(tensor)
    torch._sync(data)
    # AOTDispatcher needs to know if any inputs had their storages mutated.
    # (Why? It sometimes detaches inputs before sending them into the graph,
    # when it sees that they do not need to have any gradients computed)
    torch._functionalize_set_storage_changed(tensor)
    tensor_inner = torch._from_functional_tensor(tensor)
    data_inner = torch._from_functional_tensor(data)
    with torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
    ):
        torch.ops.fsdp.set_.default(tensor_inner, data_inner)


# Mark the op as side-effectful for FX (the note above warns that these mutable
# ops must not be dead-code-eliminated)
torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
class ShardedState(Enum):
    """
    - ``SHARDED``: The sharded parameter is registered to the module. It is the
    only contributor to parameter memory.
    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
    smaller world size. Since this data should not be used for computation,
    we do not register it to the module. Users should reshard the module
    before any in-place modifications. Both it and the sharded parameter
    contribute to parameter memory.
    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
    it and the sharded parameter contribute to parameter memory.
    """

    # NOTE: member values come from `auto()`, so they follow declaration order.
    SHARDED = auto()
    SHARDED_POST_FORWARD = auto()
    UNSHARDED = auto()
@dataclass
class ParamModuleInfo:
    """
    For a parameter, this stores the module and the parameter name to be able
    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
    the parameter via ``getattr(module, param_name)``. We additionally save
    shared modules and shared parameter names to update them accordingly.
    """

    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
    module: nn.Module
    param_name: str
    # Other modules/names referring to the same parameter; both default to
    # fresh empty lists per instance (presumably index-aligned with each
    # other -- confirm against the code that populates them)
    shared_modules: List[nn.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)
@dataclass
class ExtensionsData:
    """Per-parameter scratch data carried from pre- to post-all-gather."""

    # User-defined metadata passed from pre to post-all-gather
    all_gather_metadata: Optional[Any] = None
    # Save the all-gather input sizes to unflatten the all-gather outputs to ND
    all_gather_input_sizes: Sequence[torch.Size] = ()  # ND

    def clear(self):
        """Reset both fields to their defaults."""
        self.all_gather_metadata, self.all_gather_input_sizes = None, ()
class FSDPParam:
"""
This class manages a parameter with FSDP or FSDP variants applied,
implementing dim-0 per-parameter sharding.
"""
orig_dtype: torch.dtype
param_dtype: Optional[torch.dtype]
reduce_dtype: Optional[torch.dtype]
_orig_size: torch.Size # ND
sharded_size: torch.Size # ND
contiguous_sharded_stride: Tuple[int, ...]
padded_sharded_param_size: torch.Size # ND
sharded_post_forward_size: torch.Size # ND
contiguous_sharded_post_forward_stride: Tuple[int, ...]
_sharded_param_data: torch.Tensor # 1D
sharded_param: nn.Parameter # ND
_sharded_post_forward_param_data: Optional[torch.Tensor] # 1D
_sharded_post_forward_param: Optional[nn.Parameter] # ND
_unsharded_param: nn.Parameter # ND
unsharded_accumulated_grad: Optional[torch.Tensor] # ND
_sharding_spec: DTensorSpec
# DTensor attributes (only defined for DTensor `param`):
_tp_spec: DTensorSpec
all_gather_outputs: List[torch.Tensor] # 1D
# All-gather extension attributes
_extensions_data: ExtensionsData
_unsharded_inner_tensors: List[torch.Tensor]
def __init__(
    self,
    param: nn.Parameter,
    module_info: ParamModuleInfo,
    mesh_info: FSDPMeshInfo,
    post_forward_mesh_info: Optional[FSDPMeshInfo],
    device: torch.device,
    mp_policy: MixedPrecisionPolicy,
    offload_policy: OffloadPolicy,
):
    # Initializes the sharded parameter, optional post-forward metadata, and
    # extension state, then registers a load-state-dict post-hook on the
    # owning module. `mp_policy` is accepted but not read in this constructor.
    self._module_info: ParamModuleInfo = module_info
    self.mesh_info = mesh_info
    self.post_forward_mesh_info = post_forward_mesh_info
    self.device = device
    # CPU offload is enabled purely by the policy's type
    self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
    # `pin_memory` is only read from the policy when offloading (short-circuit)
    self.pin_memory = (
        self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
    )
    self.grad_offload_event: Optional[torch.cuda.Event] = None
    # Must run before `_init_extensions()`, which reads the sharded-parameter
    # state set up here
    self._init_sharded_param(param, device)
    if self.post_forward_mesh_info:
        self._init_sharded_post_forward_param_metadata(param)
    self._init_extensions()
    self.all_gather_outputs: List[torch.Tensor] = []
    self.unsharded_accumulated_grad = None
    self._param_fqn: Optional[str] = None  # prefixed from root module
    # TODO: Remove this padding logic once DTensor pads the local tensor:
    # https://github.com/pytorch/pytorch/issues/113045
    # The hook ignores its arguments and calls `reset_sharded_param()` after
    # every state-dict load on the owning module
    self._post_load_hook_handle = (
        module_info.module.register_load_state_dict_post_hook(
            lambda *args, **kwargs: self.reset_sharded_param()
        )
    )
@torch.no_grad()
def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
    # Builds this rank's dim-0 shard of `param`: computes the sharding spec,
    # chunks the (local) data across the shard mesh, copies this rank's chunk
    # into a zero-initialized padded buffer, wraps it as the sharded
    # `nn.Parameter`, and sets it on the module(s).
    #
    # Raises:
    #   AssertionError: if `param` is on neither `device` nor the meta device,
    #     or the DP/TP meshes do not share a (non-None) root mesh.
    #   NotImplementedError: for TP with more than one placement, or uneven
    #     strided FSDP+TP sharding on dim 0.
    if param.device != device and param.device.type != "meta":
        raise AssertionError(
            f"Expects the parameter to already be moved to device {device} but got {param.device}"
        )
    # TODO: Replace the sharded DTensor parameter construction logic with
    # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
    # TODO: Simplify the following sharded parameter padding logic after
    # https://github.com/pytorch/pytorch/issues/113045
    self.is_dtensor = isinstance(param, DTensor)
    if self.is_dtensor:
        # DTensor input: combine the FSDP (DP) mesh with the parameter's own
        # (TP) mesh into one SPMD mesh and placement tuple
        self._tp_spec = cast(DTensor, param)._spec
        dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
        dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
        tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
        if dp_global_mesh != tp_global_mesh or (
            dp_global_mesh is None or tp_global_mesh is None
        ):
            raise AssertionError(
                "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
                f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
            )
        name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
        assert dp_mesh.mesh_dim_names is not None, name_dims_error
        assert tp_mesh.mesh_dim_names is not None, name_dims_error
        # The SPMD mesh is the root mesh indexed by DP dim names then TP dim names
        submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
        self._spmd_mesh = dp_global_mesh[submesh_names]
        if len(self._tp_spec.placements) != 1:
            raise NotImplementedError(
                f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
            )
        split_factor = self._tp_spec.num_shards_map[0]
        assert (
            2 <= self._spmd_mesh.ndim <= 3
        ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
        self._spmd_placements: Tuple[Placement, ...]
        # FSDP placement on dim 0 (strided when TP already shards dim 0),
        # followed by the single TP placement
        dp_shard_tp_placement = (
            (
                _StridedShard(0, split_factor=split_factor)
                if split_factor > 1
                else Shard(0)
            ),
            self._tp_spec.placements[0],
        )
        if self._spmd_mesh.ndim == 2:
            self._spmd_placements = dp_shard_tp_placement
        else:
            # 3D mesh: a leading replicate dim is prepended
            assert self.mesh_info.replicate_mesh_dim == 0
            self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
        self._sharding_spec = DTensorSpec(
            self._spmd_mesh,
            self._spmd_placements,
            tensor_meta=self._tp_spec.tensor_meta,
        )
        # NOTE: FSDP+TP does not support uneven sharding for now
        # TODO: enable uneven sharding for FSDP+TP
        if split_factor > 1:  # FSDP has strided sharding on tensor dim 0
            num_shards = self._sharding_spec.num_shards_map[0]
            tensor_size_dim_0 = self._sharding_spec.shape[0]
            if tensor_size_dim_0 % num_shards != 0:
                raise NotImplementedError(
                    "FSDP+TP sharding does not support uneven sharding for now: "
                    f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
                    f"evenly sharded into {num_shards} shards."
                )
        # Shard the DTensor's local tensor, not the DTensor itself
        param_data = cast(DTensor, param)._local_tensor
    else:
        # Plain tensor input: the SPMD mesh is just the FSDP mesh
        self._spmd_mesh = self.mesh_info.mesh
        if isinstance(self.mesh_info, HSDPMeshInfo):
            self._spmd_placements = (Replicate(), Shard(0))
        else:
            self._spmd_placements = (Shard(0),)
        self._sharding_spec = DTensorSpec(
            self._spmd_mesh,
            self._spmd_placements,
            tensor_meta=TensorMeta(
                param.size(),
                param.stride(),
                param.dtype,
            ),
        )
        param_data = param
    self._orig_size = param_data.size()
    self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
    shard_rank = self.mesh_info.shard_mesh_rank
    shard_world_size = self.mesh_info.shard_mesh_size
    # One chunk per shard rank along dim 0
    chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
    sharded_param = chunks[shard_rank]
    self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size())
    self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
    padded_sharded_size = chunks[0].size()  # 0th always padded
    # Zero-initialized buffer of the padded size; this rank's chunk is copied
    # into its leading rows
    padded_sharded_param = param_data.new_zeros(padded_sharded_size)
    self.padded_sharded_param_size = padded_sharded_param.size()
    if sharded_param.numel() > 0:
        padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
    if self.offload_to_cpu and not padded_sharded_param.is_meta:
        padded_sharded_param = padded_sharded_param.cpu()
        if self.pin_memory:
            padded_sharded_param = padded_sharded_param.pin_memory()
    # Flat 1D view of the padded buffer
    self._sharded_param_data = padded_sharded_param.view(-1)
    # The sharded parameter wraps only the unpadded leading rows
    self.sharded_param = nn.Parameter(
        self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)])
    )
    self.sharded_param.requires_grad_(param.requires_grad)
    # Let `param_data` be freed normally when its ref count reaches 0 when
    # the `fully_shard` call returns to allow provided parameters to alias
    self._setattr_on_modules(self.sharded_param)
    self.sharded_state = ShardedState.SHARDED
def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
    """
    Record the size and contiguous stride of this rank's dim-0 chunk of
    ``param`` (its local tensor for a DTensor) on the post-forward mesh.
    """
    post_forward_info = self.post_forward_mesh_info
    assert post_forward_info is not None  # mypy
    if isinstance(param, DTensor):
        local_data = param._local_tensor
    else:
        local_data = param
    per_rank_chunks = _chunk_with_empty(
        local_data, post_forward_info.shard_mesh_size, dim=0
    )
    rank_chunk = per_rank_chunks[post_forward_info.shard_mesh_rank]
    self.sharded_post_forward_size = _get_dim0_chunked_size(
        rank_chunk, local_data.size()
    )
    self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
        self.sharded_post_forward_size
    )
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
self.orig_dtype = self.sharded_param.dtype
# Clamp `param_dtype` to `None` if no casting is required
if param_dtype == self.orig_dtype:
param_dtype = None
self.param_dtype = param_dtype
self.reduce_dtype = reduce_dtype
# None indicates that the mixed precision is not enabled
def _init_extensions(self) -> None:
    """
    Validate and set up all-gather extension state for the sharded local
    tensor: both hooks must be defined together, and extensions require the
    local tensor to already have the padded sharded size.
    """
    inner_tensor = self._sharded_local_tensor
    has_pre_hook = hasattr(inner_tensor, "fsdp_pre_all_gather")
    has_post_hook = hasattr(inner_tensor, "fsdp_post_all_gather")
    only_one_hook_defined = has_pre_hook != has_post_hook
    if only_one_hook_defined:
        raise AssertionError(
            "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
            f"if using all-gather extensions: {inner_tensor}"
        )
    if has_pre_hook:
        is_evenly_sharded = (
            self.padded_sharded_param_size == self._sharded_local_tensor.size()
        )
        if not is_evenly_sharded:
            raise NotImplementedError(
                "FSDP all-gather extensions require even sharding on dim-0.\n"
                f"{self._orig_size} is not divisible by FSDP world size {self.mesh_info.mesh.size()}."
            )
        self._extensions_data = ExtensionsData()
    self._unsharded_inner_tensors: List[torch.Tensor] = []
def init_all_gather_outputs(
self,
all_gather_input_numels: List[int],
all_gather_input_dtypes: List[torch.dtype],
world_size: int,
device: torch.device,
force_recreate: bool = False,
):
if not force_recreate and len(self.all_gather_outputs) > 0:
return # already initialized
self.all_gather_outputs = [
torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
]
def init_unsharded_param(self):
    """
    [Note: Invariants for torch.compile Traceable FSDP2]
    1. Under compile, we always re-populate the content of `self._unsharded_param`
    per AllGather using the slow path.
    2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
    This is to ensure the buffer creation is internal to the graph and
    avoid `self.all_gather_outputs` being captured as a graph input.
    3. Under compile, at the end of `free_unsharded_param()`, we always clean up
    `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
    to avoid them being captured as graph output.
    With these invariants, only these tensors will be inputs to the graph:
    - Sharded parameters
    - Placeholders for the `self._unsharded_param` nn.Parameter
    """
    # Eager fast path: the unsharded parameter object already exists, so it is
    # only refreshed in place (and only when the extension hook is defined)
    if not ca.compiled_autograd_enabled and hasattr(
        self, "_unsharded_param"
    ):  # after the 1st all-gather
        inner_tensor = self._sharded_local_tensor
        if not hasattr(inner_tensor, "fsdp_post_all_gather"):
            return  # already initialized
        for tensor in self._unsharded_inner_tensors:
            alloc_storage(tensor)
        all_gather_outputs = self._unflatten_all_gather_outputs()
        # Passing `out=` asks the extension to write into the existing parameter
        inner_tensor.fsdp_post_all_gather(
            all_gather_outputs,
            self._extensions_data.all_gather_metadata,
            self.param_dtype or self.orig_dtype,
            out=self._unsharded_param,
        )
        self._extensions_data.clear()
        return
    # Slow path: first all-gather in eager mode, or every call under compiled
    # autograd
    inner_tensor = self._sharded_local_tensor
    if not ca.compiled_autograd_enabled and hasattr(
        inner_tensor, "fsdp_post_all_gather"
    ):
        all_gather_outputs = self._unflatten_all_gather_outputs()
        # Without `out=`, the extension returns the unsharded tensor together
        # with its inner tensors
        (
            unsharded_tensor,
            self._unsharded_inner_tensors,
        ) = inner_tensor.fsdp_post_all_gather(
            all_gather_outputs,
            self._extensions_data.all_gather_metadata,
            self.param_dtype or self.orig_dtype,
        )
        self._extensions_data.clear()
    else:
        # For the default path (no post-all-gather), the all-gather output
        # gives the unsharded parameter data directly
        assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
        unsharded_tensor = self.all_gather_outputs[0]
    # View the data with the original size and a contiguous stride
    unsharded_param = torch.as_strided(
        unsharded_tensor,
        self._orig_size,
        self._contiguous_orig_stride,
        storage_offset=0,
    )
    if self.is_dtensor:
        unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
    if hasattr(self, "_unsharded_param"):
        # Only reachable under compiled autograd (eager returned above):
        # swap the existing parameter's data via `fsdp.set_` while preserving
        # its version counter
        assert ca.compiled_autograd_enabled
        with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
            self._unsharded_param
        ):
            torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
    else:
        # First call: construct the unsharded parameter object
        self._unsharded_param = nn.Parameter(
            unsharded_param, requires_grad=self.sharded_param.requires_grad
        )
def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
return tuple(
t.view(-1, *s[1:])
for t, s in zip(
self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
)
)
def to_sharded(self) -> None:
    """
    Transition this parameter to the sharded state.

    Registers ``self.sharded_param`` on the module(s), frees the unsharded
    parameter data, and records ``ShardedState.SHARDED``.
    """
    sharded_param = self.sharded_param
    self._setattr_on_modules(sharded_param)
    self.free_unsharded_param()
    self.sharded_state = ShardedState.SHARDED
def to_sharded_post_forward(self) -> None:
    """
    Transition from the unsharded state to the sharded-post-forward state.

    Keeps this rank's slice of the single all-gather output (cloned so that
    the all-gather output itself can be freed), wraps it as a parameter,
    registers it on the module(s), and frees the unsharded data.

    Raises:
        NotImplementedError: If ``self.is_dtensor`` is true.
    """
    if self.is_dtensor:
        raise NotImplementedError(
            "Resharding to smaller mesh with TP is not supported yet"
        )
    self._assert_in_states(ShardedState.UNSHARDED)
    assert self.post_forward_mesh_info is not None  # mypy
    assert len(self.all_gather_outputs) == 1
    post_forward_mesh_info = self.post_forward_mesh_info
    all_gather_output = self.all_gather_outputs[0]
    shard_world_size = post_forward_mesh_info.shard_mesh_size
    numel = all_gather_output.numel()
    if numel % shard_world_size != 0:
        _raise_assert_with_print(
            f"All-gather output size ({numel}) must be divisible by the shard "
            f"world size ({shard_world_size})"
        )
    shard_rank = post_forward_mesh_info.shard_mesh_rank
    sharded_numel = numel // shard_world_size
    rank_slice = all_gather_output.narrow(
        0, sharded_numel * shard_rank, sharded_numel
    )
    # Clone to be able to free the all-gather output
    self._sharded_post_forward_param_data = rank_slice.clone()
    sharded_post_forward_tensor = torch.as_strided(
        self._sharded_post_forward_param_data,
        size=self.sharded_post_forward_size,
        stride=self.contiguous_sharded_post_forward_stride,
        storage_offset=0,
    )
    post_forward_dtensor = self.to_sharded_post_forward_dtensor(
        sharded_post_forward_tensor
    )
    self._sharded_post_forward_param = nn.Parameter(post_forward_dtensor)
    self._setattr_on_modules(self._sharded_post_forward_param)
    self.free_unsharded_param()
    self.sharded_state = ShardedState.SHARDED_POST_FORWARD
def to_unsharded(self) -> None:
    """
    Transition this parameter to the unsharded state.

    Assumes that the unsharded data has already been allocated and
    all-gathered. Syncs ``requires_grad`` from the sharded parameter onto the
    unsharded parameter and registers the unsharded parameter on the
    module(s).
    """
    set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
    self._setattr_on_modules(self._unsharded_param)
    coming_from_post_forward = (
        self.sharded_state == ShardedState.SHARDED_POST_FORWARD
    )
    if coming_from_post_forward:
        # The data is allocated in the default stream via the post-forward
        # reshard and must be kept alive for the next all-gather copy-in.
        # Since we call this method after the copy-out, the data's lifetime
        # is ensured without further synchronization.
        self._sharded_post_forward_param = None
        self._sharded_post_forward_param_data = None  # free
    self.sharded_state = ShardedState.UNSHARDED
def _setattr_on_modules(self, param: nn.Parameter) -> None:
    """
    Register ``param`` on the owning module and on every sharing module.

    The owning module/name come from ``self._module_info``; shared modules
    and their parameter names are paired up positionally.
    """
    info = self._module_info
    unsafe_setattr_param(info.module, info.param_name, param)
    shared_targets = zip(info.shared_modules, info.shared_param_names)
    for shared_module, shared_param_name in shared_targets:
        unsafe_setattr_param(shared_module, shared_param_name, param)
def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
    """
    Wrap a local tensor (sharded parameter or sharded gradient) as a DTensor.

    Args:
        tensor: Local tensor whose shape must equal ``self.sharded_size``.

    Returns:
        A DTensor built from ``tensor`` using ``self._sharding_spec``.
    """
    shape_matches = tensor.shape == self.sharded_size
    if not shape_matches:
        _raise_assert_with_print(
            f"Expects size {self.sharded_size} but got {tensor.shape}"
        )
    return _from_local_no_grad(tensor, self._sharding_spec)
def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
    """
    Wrap a local post-forward-sharded tensor as a DTensor.

    Args:
        tensor: Local tensor whose shape must equal
            ``self.sharded_post_forward_size``.

    Returns:
        A DTensor on the post-forward mesh with ``(Replicate(), Shard(0))``
        placements, reusing the tensor metadata of ``self._sharding_spec``.
    """
    shape_matches = tensor.shape == self.sharded_post_forward_size
    if not shape_matches:
        _raise_assert_with_print(
            f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
        )
    assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
    # TODO: Prefer this DTensor to be read-only and generalize the
    # placement once we support TP.
    placements = (Replicate(), Shard(0))
    post_forward_sharding_spec = DTensorSpec(
        self.post_forward_mesh_info.mesh,
        placements,
        tensor_meta=self._sharding_spec.tensor_meta,
    )
    return _from_local_no_grad(tensor, post_forward_sharding_spec)
def to_accumulated_grad_if_needed(self) -> None:
# Access `_unsharded_param` to bypass the sharded state check since we
# prefer to reshard before upcasting the gradient to save memory
if (
self.reduce_dtype is None
or self._unsharded_param.grad is None
or self._unsharded_param.grad.dtype == self.reduce_dtype
):
return
unsharded_grad = self._unsharded_param.grad
self._unsharded_param.grad = None
self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)
def accumulate_unsharded_grad_if_needed(self) -> None:
    """
    Fold the unsharded parameter's gradient into the accumulated gradient.

    Only acts when both ``unsharded_accumulated_grad`` and
    ``unsharded_param.grad`` are set; the latter is cleared afterwards.
    """
    if self.unsharded_accumulated_grad is None:
        return
    if self.unsharded_param.grad is None:
        return
    self.unsharded_accumulated_grad += self.unsharded_param.grad
    self.unsharded_param.grad = None
def alloc_all_gather_outputs(self) -> None:
    """Allocate storage for every tensor in ``self.all_gather_outputs``."""
    for all_gather_output in self.all_gather_outputs:
        alloc_storage(all_gather_output)
def free_unsharded_param(self) -> None:
    """
    Free the storage backing the unsharded parameter.

    Covers both the all-gather outputs and the unsharded inner tensors. When
    compiled autograd is enabled, the two lists are additionally reset to
    empty lists.
    """
    tensors_to_free = itertools.chain(
        self.all_gather_outputs, self._unsharded_inner_tensors
    )
    for tensor in tensors_to_free:
        free_storage(tensor)
    if not ca.compiled_autograd_enabled:
        return
    self.all_gather_outputs = []
    self._unsharded_inner_tensors = []
@property
def all_gather_inputs(self) -> List[torch.Tensor]:  # 1D
    """
    Return the flattened tensors to feed into the all-gather.

    In the sharded state, a local tensor that defines
    ``fsdp_pre_all_gather`` supplies the inputs (and its metadata and input
    sizes are recorded on ``self._extensions_data``) unless compiled autograd
    is enabled; otherwise the sharded parameter data is used, converted via
    ``_to_dtype_if_needed``. In the sharded-post-forward state, the
    post-forward sharded data is used.

    Raises:
        NotImplementedError: In the sharded-post-forward state when the
            local tensor defines ``fsdp_pre_all_gather`` and compiled
            autograd is disabled.
    """
    self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
    if self.sharded_state == ShardedState.SHARDED:
        if not ca.compiled_autograd_enabled and hasattr(
            self._sharded_local_tensor, "fsdp_pre_all_gather"
        ):
            sharded_local_tensor = self._sharded_local_tensor
            if self.offload_to_cpu:
                sharded_local_tensor = sharded_local_tensor.to(
                    self.device, non_blocking=True
                )
            pre_all_gather_result = sharded_local_tensor.fsdp_pre_all_gather(
                self.mesh_info.mesh
            )
            (
                all_gather_inputs,
                self._extensions_data.all_gather_metadata,
            ) = pre_all_gather_result
            input_sizes = []
            for inp in all_gather_inputs:
                input_sizes.append(inp.size())
            self._extensions_data.all_gather_input_sizes = input_sizes
            return [inp.view(-1) for inp in all_gather_inputs]
        sharded_param_data = self._sharded_param_data
        if self.offload_to_cpu:
            sharded_param_data = sharded_param_data.to(
                self.device, non_blocking=True
            )
        return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
    elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
        if not ca.compiled_autograd_enabled and hasattr(
            self._sharded_local_tensor, "fsdp_pre_all_gather"
        ):
            raise NotImplementedError
        post_forward_data = cast(torch.Tensor, self._sharded_post_forward_param_data)
        return [_to_dtype_if_needed(post_forward_data, self.param_dtype)]
    return [torch.empty(0)]  # mypy
@property
def unsharded_param(self) -> nn.Parameter:  # ND
    """Return the unsharded parameter after checking the unsharded state."""
    required_state = ShardedState.UNSHARDED
    self._assert_in_states(required_state)
    return self._unsharded_param
@property
def unsharded_grad_data(self) -> torch.Tensor:
    """
    Return the inner tensor of the unsharded parameter's gradient.

    Asserts that the gradient is set before unwrapping it with
    ``_get_grad_inner_tensor``.
    """
    unsharded_grad = self.unsharded_param.grad
    assert (
        unsharded_grad is not None
    ), "Expects unsharded_param.grad to not be None"
    return self._get_grad_inner_tensor(unsharded_grad)
@property
def unsharded_accumulated_grad_data(self) -> torch.Tensor:
    """
    Return the inner tensor of the accumulated unsharded gradient.

    Asserts that the accumulated gradient is set before unwrapping it with
    ``_get_grad_inner_tensor``.
    """
    accumulated_grad = self.unsharded_accumulated_grad
    assert (
        accumulated_grad is not None
    ), "Expects unsharded_accumulated_grad to not be None"
    return self._get_grad_inner_tensor(accumulated_grad)
def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
if self.is_dtensor:
if isinstance(grad, AsyncCollectiveTensor):
grad = grad.wait()
assert isinstance(grad, DTensor), f"{type(grad)}"
if any(pl.is_partial() for pl in grad.placements):
placements = [
Replicate() if pl.is_partial() else pl for pl in grad.placements
]
grad = grad.redistribute(placements=placements)
grad = grad._local_tensor
return grad
@property
def _sharded_local_tensor(self) -> torch.Tensor:
    """Return the local tensor held by the sharded (DTensor) parameter."""
    sharded_dtensor = cast(DTensor, self.sharded_param)
    return sharded_dtensor._local_tensor
def _assert_in_states(self, *states: ShardedState) -> None:
if self.sharded_state not in states:
_raise_assert_with_print(
f"Expects to be in one of {states}, not {self.sharded_state}"
)
def reset_sharded_param(self):
    """
    Re-sync the saved sharded parameter with the one on the module.

    For ops like `nn.Module._apply` or `load_state_dict(assign=True)` that
    change the sharded parameter tensor, the sharded local tensor may need to
    be re-padded and the reference re-saved. Returns early without touching
    the local tensor when it is on the meta device.

    Raises:
        AssertionError: If the module's parameter object changed while
            ``torch.__future__.get_swap_module_params_on_conversion()`` is
            true.
    """
    module_info = self._module_info
    new_param = getattr(module_info.module, module_info.param_name)
    param_was_replaced = new_param is not self.sharded_param
    if param_was_replaced:
        if torch.__future__.get_swap_module_params_on_conversion():
            raise AssertionError(
                f"Expects swap_tensors to preserve object but got {new_param} "
                f"instead of {self.sharded_param}"
            )
        self.sharded_param = new_param
    local_tensor = new_param._local_tensor
    if local_tensor.is_meta:
        return
    padded_sharded_size = self.padded_sharded_param_size
    needs_padding = local_tensor.size() != padded_sharded_size
    if needs_padding:
        # Copy the existing rows into a zero-filled tensor of the padded size
        padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
        padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor)
        local_tensor = padded_local_tensor
    if self.pin_memory and not local_tensor.is_pinned():
        local_tensor = local_tensor.cpu().pin_memory()
    self._sharded_param_data = local_tensor.view(-1)
    assert isinstance(self.sharded_param, DTensor)  # mypy
    unpadded_rows = self.sharded_size[0]
    self.sharded_param._local_tensor = local_tensor[:unpadded_rows]
def __repr__(self):
return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"
def alloc_storage(tensor: torch.Tensor) -> None:
size = tensor.numel() * tensor.itemsize
if (storage := tensor.untyped_storage()).size() != size:
storage.resize_(size)
def free_storage(tensor: torch.Tensor) -> None:
    """
    Resize the tensor's untyped storage to zero bytes.

    The storage is left alone when its size is already zero.
    """
    storage = tensor.untyped_storage()
    if storage.size() == 0:
        return
    storage.resize_(0)
# NOTE: These bypass `nn.Module.__setattr__` checks, which incur non-trivial
# CPU overhead, if the module did not override it. For FSDP, we know we do not
# need those checks when transitioning between sharded/unsharded parameters.
def unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    """
    Set ``param`` on ``module`` under ``param_name``.

    When the module uses the stock ``nn.Module.__setattr__``, the parameter
    is written straight into ``module._parameters``; a module that overrides
    ``__setattr__`` goes through a regular ``setattr`` instead.
    """
    bound_setattr = module.__setattr__
    uses_default_setattr = (
        getattr(bound_setattr, "__func__", None) is nn.Module.__setattr__
    )
    if not uses_default_setattr:
        # Slow path: respect the module's custom `__setattr__`
        setattr(module, param_name, param)
        return
    module._parameters[param_name] = param
def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    """
    Make ``dst_tensor.requires_grad`` match ``src_tensor.requires_grad``.

    `requires_grad_` is only called if needed to avoid the Python <> C++
    context switch overhead.
    """
    wanted = src_tensor.requires_grad
    if dst_tensor.requires_grad == wanted:
        return
    dst_tensor.requires_grad_(wanted)

View File

@ -0,0 +1,614 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
AllGatherResult,
foreach_all_gather,
foreach_all_gather_copy_out,
foreach_reduce,
)
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
# Module logger, registered under the "torch.distributed._composable.fsdp" name
logger = logging.getLogger("torch.distributed._composable.fsdp")
# Type alias: maps a module to the removable handle of a hook registered on it
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict
"""
[Note: Overlapping all-gather copy-in and all-gather]
For implicit forward prefetching, we want to overlap the next copy-in with the
current all-gather. We do so using a separate copy-in stream. However, since
we have the all-gather input as a view into the output, we must make sure to
copy into different memory from the current all-gather's output. Thus, we keep
a reference to the current all-gather's output and have the next FSDP parameter
group free it after its copy-in. Finally, we have the last FSDP state flush the
reference to avoid holding onto memory after forward.
"""
class FSDPCommContext:
    """This has the communication state shared across FSDP states/parameter groups."""

    def lazy_init(self):
        """
        Create the CUDA streams and reset the shared collective state.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("FSDP requires CUDA for streams")
        # Setting the all-gather/reduce-scatter streams to be higher priority
        # can help avoid some issues where their copies in/out are delayed and
        # block computation (this is different from high-pri NCCL streams)
        high_priority = -1
        # All-gather state and copy-in stream allow overlapping the next
        # copy-in with the current all-gather in forward; copy-in overlaps with
        # reduce-scatter in backward without the separate copy-in stream
        self.all_gather_copy_in_stream = torch.cuda.Stream(priority=high_priority)
        # All-gather stream allows overlapping next all-gather with current
        # forward compute
        self.all_gather_stream = torch.cuda.Stream(priority=high_priority)
        # Reduce-scatter stream gives separate execution "thread" for post-
        # backward logic like pre/post-gradient division and reduce-scatter
        self.reduce_scatter_stream = torch.cuda.Stream(priority=high_priority)
        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
        # since collectives use different network resources and can overlap
        # in the typical intra-node sharding / inter-node replication case
        self.all_reduce_stream = torch.cuda.Stream()
        # All-gather/reduce-scatter states keep references to collective
        # tensors produced in one stream and used in another and accompanying
        # CUDA events for synchronization
        self.all_gather_state: Optional[AllGatherState] = None
        self.reduce_scatter_state: Optional[ReduceScatterState] = None
        # Post-forward order for explicit backward prefetching
        self.post_forward_order: List[FSDPParamGroup] = []  # will cause ref cycles

    def get_all_gather_streams(
        self, training_state: TrainingState
    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
        """
        Return the ``(copy-in, all-gather)`` stream pair for a training state.

        The dedicated streams are returned for the forward and pre-backward
        states; in any other state the current CUDA stream is returned for
        both roles.
        """
        uses_dedicated_streams = training_state in (
            TrainingState.FORWARD,
            TrainingState.PRE_BACKWARD,
        )
        if not uses_dedicated_streams:
            current_stream = torch.cuda.current_stream()
            return current_stream, current_stream
        # Use separate streams for implicit prefetching
        return self.all_gather_copy_in_stream, self.all_gather_stream
# See [Note: Overlapping all-gather copy-in and all-gather]
class AllGatherState(NamedTuple):
    """Pairs an all-gather result with the event recorded for its copy-out."""

    # Result object produced by the all-gather
    all_gather_result: AllGatherResult
    # all-gather copy-out
    event: torch.cuda.Event
class ReduceScatterState(NamedTuple):
    """Pairs a reduce-scatter input tensor with its reduce-scatter event."""

    # Input tensor handed to the reduce-scatter
    reduce_scatter_input: torch.Tensor
    # reduce-scatter event
    event: torch.cuda.Event
class FSDPParamGroup:
"""This class represents a parameter group to communicate together."""
_orig_dtype: torch.dtype
_reduce_dtype: Optional[torch.dtype]
def __init__(
    self,
    params: List[nn.Parameter],
    modules: Tuple[nn.Module, ...],
    mesh_info: FSDPMeshInfo,
    post_forward_mesh_info: Optional[FSDPMeshInfo],
    device: torch.device,
    mp_policy: MixedPrecisionPolicy,
    offload_policy: OffloadPolicy,
):
    """
    Build one ``FSDPParam`` per parameter and initialize the group state.

    Args:
        params: Parameters managed by this group.
        modules: Modules searched for the owners of ``params``.
        mesh_info: Mesh info stored on the group and passed to each
            ``FSDPParam``.
        post_forward_mesh_info: Optional mesh info for the post-forward
            sharded state.
        device: Device stored on the group and passed to each ``FSDPParam``.
        mp_policy: Mixed precision policy.
        offload_policy: Offload policy passed to each ``FSDPParam``.
    """
    self.modules = modules  # permit ref cycle because 1:1 lifetime
    param_module_infos = _get_param_module_infos(params, modules)
    fsdp_params = []
    for param, module_info in zip(params, param_module_infos):
        fsdp_params.append(
            FSDPParam(
                param,
                module_info,
                mesh_info,
                post_forward_mesh_info,
                device,
                mp_policy,
                offload_policy,
            )
        )
    self.fsdp_params = fsdp_params
    self.mesh_info = mesh_info
    self.post_forward_mesh_info = post_forward_mesh_info
    self.device = device
    self.mp_policy = mp_policy
    self._training_state = TrainingState.IDLE
    # Group's sharded state always matches its parameters' sharded states
    self._sharded_state = ShardedState.SHARDED
    self._module_fqn: Optional[str] = None  # prefixed from root module
    # Only consider resetting sharded parameters once in lazy init since it
    # can incur nontrivial overhead to reset them
    self._reset_sharded_params: bool = False

    # - Hook state
    self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
    self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}

    # - Communication and communication/computation overlap
    self.comm_ctx = FSDPCommContext()
    # Group's indices in the shared post-forward order
    self._post_forward_indices: List[int] = []
    # Whether to reduce gradients at all (whether for FSDP or HSDP)
    self.reduce_grads: bool = True
    # Whether to all-reduce gradients for HSDP; only used if
    # `self.reduce_grads` is true, in which case setting this to false
    # means reduce-scatter but no all-reduce
    self.all_reduce_grads: bool = True
    # Whether to reshard parameters after backward (only useful for
    # gradient accumulation)
    self.reshard_after_backward: bool = True
    # Optional custom reduce-scatter reduce op (e.g. to divide by a
    # factor other than the shard world size)
    self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None

    # - CUDA events for stream synchronization
    # Holds the all-gather output buffer, sync objects, and metadata
    self._all_gather_result: Optional[AllGatherResult] = None
    # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of
    # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which
    # should be waited on at the end of backward
    self._post_reduce_event: Optional[torch.cuda.Event] = None
    # Holds the reshard-after-forward CUDA event when resharding to a
    # different world size, which should be waited on in the next unshard
    self._reshard_after_forward_event: Optional[torch.cuda.Event] = None

    # Only for HSDP, if accumulating gradients without all-reduce, save the
    # partial reduce output (only reduce-scattered but not all-reduced)
    self._partial_reduce_output: Optional[torch.Tensor] = None
# Initialization #
def _init_mp_dtypes(self) -> None:
for fsdp_param in self.fsdp_params:
fsdp_param.init_dtype_attrs(self.mp_policy)
orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params}
if len(orig_dtypes) != 1:
# This can be relaxed if we copy-out for the reduce-scatter
raise AssertionError(
f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
)
self._orig_dtype = next(iter(orig_dtypes))
reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params}
if len(reduce_dtypes) != 1:
# This can be relaxed if we issue one reduce-scatter per reduce
# dtype (but we would need a way for users to specify multiple
# reduce dtypes)
raise AssertionError(
f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
)
self._reduce_dtype = next(iter(reduce_dtypes))
def lazy_init(self):
    """
    Finish initialization that is deferred until after construction.

    Resets the sharded parameters at most once (only while sharded), rejects
    parameters still on the meta device, then initializes the mixed
    precision dtypes and registers the state dict hooks.

    Raises:
        RuntimeError: If any sharded parameter is still on the meta device.
    """
    # Lazy init should be idempotent
    # Users may change or register parameters after construction time.
    # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
    # other parameters (e.g. loaded from the state dict).
    should_reset = self.is_sharded and not self._reset_sharded_params
    if should_reset:
        for fsdp_param in self.fsdp_params:
            fsdp_param.reset_sharded_param()
        self._reset_sharded_params = True
    param_names_on_meta = []
    for fsdp_param in self.fsdp_params:
        if fsdp_param.sharded_param.device.type == "meta":
            param_names_on_meta.append(fsdp_param._param_fqn)
    if param_names_on_meta:
        raise RuntimeError(
            "FSDP parameters should be materialized from meta device before training, "
            f"but the following were still on meta device: {param_names_on_meta}\n"
            "For example, call module.to_empty(device) to materialize to device and "
            "call module.reset_parameters() on each module to initialize values."
        )
    # Initialize mixed precision attributes lazily in case the user changes
    # the parameter dtypes after construction time but before forward
    self._init_mp_dtypes()
    self._register_state_dict_hooks()
# Runtime #
def unshard(self, async_op: bool = False):
    """
    Launch the all-gather for this group's parameters.

    Does nothing when an all-gather result is already pending or when the
    group is already unsharded. The result is stored on
    ``self._all_gather_result`` for ``wait_for_unshard`` to consume.

    Args:
        async_op: Forwarded to ``foreach_all_gather``.
    """
    already_pending = self._all_gather_result is not None
    if already_pending or self.is_unsharded:
        return
    reshard_event = self._reshard_after_forward_event
    if reshard_event is not None:
        # Resharded parameter data is allocated in the default stream and
        # used in the all-gather streams
        self._wait_all_gather_streams_on_event(reshard_event)
        self._reshard_after_forward_event = None
    with record_function(self._with_fqn("FSDP::all_gather")):
        fsdp_params = self.fsdp_params
        process_group = self._all_gather_process_group
        copy_in_stream, all_gather_stream = self.comm_ctx.get_all_gather_streams(
            self._training_state
        )
        self._all_gather_result = foreach_all_gather(
            fsdp_params,
            process_group,
            async_op,
            copy_in_stream,
            all_gather_stream,
            self.device,
        )
def wait_for_unshard(self):
    """
    Copy out the pending all-gather result and move to the unsharded state.

    1. In forward with implicit prefetching, to overlap the current copy-out
    with the next all-gather, we save a reference to the current all-gather
    result to free after the next copy-out.
    2. Otherwise (explicit prefetching or in backward), we free the
    all-gather result immediately after the current copy-out since we can
    already overlap the current copy-out with the previous reduce-scatter.
    """
    all_gather_result = self._all_gather_result
    if not all_gather_result:
        return  # no preceding unshard
    in_forward = self._training_state == TrainingState.FORWARD
    if in_forward:  # implicit prefetch
        prev_all_gather_state = self.comm_ctx.all_gather_state
        if prev_all_gather_state:
            self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
            self.comm_ctx.all_gather_state = None  # free the all-gather result
    with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
        foreach_all_gather_copy_out(
            all_gather_result,
            self.fsdp_params,
            self._all_gather_process_group,
        )
    for fsdp_param in self.fsdp_params:
        fsdp_param.init_unsharded_param()
    self._to_unsharded()
    all_gather_copy_out_event = torch.cuda.Event()
    all_gather_copy_out_event.record()
    if in_forward:
        # Keep the result referenced until a later call releases it
        self.comm_ctx.all_gather_state = AllGatherState(
            self._all_gather_result, all_gather_copy_out_event
        )
    else:
        self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
    self._all_gather_result = None  # free unless saved in `all_gather_state`
def _wait_all_gather_streams_on_event(self, event: torch.cuda.Event):
# Calling `unshard` before lazy init means streams are not initialized
if hasattr(self.comm_ctx, "all_gather_copy_in_stream"):
self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
if hasattr(self.comm_ctx, "all_gather_stream"):
self.comm_ctx.all_gather_stream.wait_event(event)
def reshard(self):
    """
    Reshard the group's parameters according to the training state.

    In forward: no-op without reshard-after-forward; with a distinct
    post-forward mesh, moves to the sharded-post-forward state and records a
    CUDA event for the next unshard to wait on. In every other case, moves
    to the sharded state.
    """
    in_forward = self._training_state == TrainingState.FORWARD
    if in_forward and not self._reshard_after_forward:
        return
    if in_forward and self._use_post_forward_mesh:
        self._to_sharded_post_forward()
        self._reshard_after_forward_event = torch.cuda.Event()
        self._reshard_after_forward_event.record()
        return
    self._to_sharded()
def pre_forward(
    self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Unshard the group before forward and hook up the post-backward logic.

    Returns:
        The ``(args, kwargs)`` pair returned by
        ``_register_post_backward_hook``.
    """
    label = self._with_fqn("FSDP::pre_forward")
    if not ca.compiled_autograd_enabled:
        logger.debug("%s", label)
    with record_function(self._with_fqn("FSDP::pre_forward")):
        self._training_state = TrainingState.FORWARD
        self.unshard()
        self.wait_for_unshard()
        hooked_args, hooked_kwargs = self._register_post_backward_hook(args, kwargs)
        return hooked_args, hooked_kwargs
def post_forward(self, module: nn.Module, input: Any, output: Any):
    """
    Reshard after forward, record the post-forward order, and go idle.

    Returns:
        ``output`` unchanged.
    """
    label = self._with_fqn("FSDP::post_forward")
    if not ca.compiled_autograd_enabled:
        logger.debug("%s", label)
    with record_function(self._with_fqn("FSDP::post_forward")):
        self.reshard()
        self._record_post_forward()
        self._training_state = TrainingState.IDLE
        return output
def _record_post_forward(self) -> None:
# Since a group has one pre-backward unshard for each forward call
# before the backward, we record each usage (with multiplicity)
post_forward_index = len(self.comm_ctx.post_forward_order)
self.comm_ctx.post_forward_order.append(self)
self._post_forward_indices.append(post_forward_index)
def pre_backward(self, default_prefetch: bool, *unused: Any):
    """
    Unshard the group before backward and optionally prefetch the next one.

    Returns immediately when the group is already in the pre-backward state.

    Args:
        default_prefetch: Whether to run ``_backward_prefetch`` (skipped
            when compiled autograd is enabled).
    """
    if self._training_state == TrainingState.PRE_BACKWARD:
        return
    label = self._with_fqn("FSDP::pre_backward")
    if not ca.compiled_autograd_enabled:
        logger.debug("%s", label)
    with record_function(self._with_fqn("FSDP::pre_backward")):
        self._training_state = TrainingState.PRE_BACKWARD
        self.unshard()  # no-op if prefetched
        self.wait_for_unshard()
        if not default_prefetch:
            return
        if not ca.compiled_autograd_enabled:
            self._backward_prefetch()
def post_backward(self, *unused: Any):
    """
    Run the group's post-backward logic: accumulate, reshard, and reduce.

    When ``self.reduce_grads`` is false, gradients are only accumulated (and
    the group optionally resharded). Otherwise the gradients are collected,
    the group optionally resharded, and the gradients passed to
    ``foreach_reduce``; the resulting reduce-scatter input and event are
    saved on the shared communication context.
    """
    if not ca.compiled_autograd_enabled:
        logger.debug("%s", self._with_fqn("FSDP::post_backward"))
    self._training_state = TrainingState.POST_BACKWARD
    with record_function(self._with_fqn("FSDP::post_backward_accumulate")):
        for fsdp_param in self.fsdp_params:
            fsdp_param.accumulate_unsharded_grad_if_needed()
    with record_function(self._with_fqn("FSDP::post_backward_reshard")):
        if not self.reduce_grads:
            if self.reshard_after_backward:
                self.reshard()
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_accumulated_grad_if_needed()
            return
        # Save the autograd-computed gradients before resharding to only
        # access the unsharded parameters when their data is present
        fsdp_params_with_grad: List[FSDPParam] = []
        unsharded_grads: List[torch.Tensor] = []
        for fsdp_param in self.fsdp_params:
            # May have an accumulated gradient of the reduce dtype if the
            # previous backward did not reduce-scatter
            if fsdp_param.unsharded_accumulated_grad is not None:
                grad_data = fsdp_param.unsharded_accumulated_grad_data
                fsdp_params_with_grad.append(fsdp_param)
                unsharded_grads.append(grad_data)
                fsdp_param.unsharded_accumulated_grad = None
            elif fsdp_param.unsharded_param.grad is not None:
                grad_data = fsdp_param.unsharded_grad_data
                fsdp_params_with_grad.append(fsdp_param)
                unsharded_grads.append(grad_data)
                fsdp_param.unsharded_param.grad = None
        if self.reshard_after_backward:
            self.reshard()
    if len(fsdp_params_with_grad) == 0:
        return
    with record_function(self._with_fqn("FSDP::post_backward_reduce")):
        prev_reduce_scatter_state = self.comm_ctx.reduce_scatter_state
        if prev_reduce_scatter_state is not None:
            # Wait on the saved reduce-scatter event before dropping the
            # reference to the saved state
            torch.cuda.current_stream().wait_event(
                self.comm_ctx.reduce_scatter_state.event
            )
            self.comm_ctx.reduce_scatter_state = None
        all_reduce_group = self._all_reduce_process_group if self._is_hsdp else None
        (
            reduce_scatter_input,
            reduce_scatter_event,
            self._post_reduce_event,
            self._partial_reduce_output,
        ) = foreach_reduce(
            fsdp_params_with_grad,
            unsharded_grads,
            self._reduce_scatter_process_group,
            self.comm_ctx.reduce_scatter_stream,
            self._orig_dtype,
            self._reduce_dtype,
            self.device,
            self.reduce_scatter_reduce_op,
            all_reduce_group,
            self.comm_ctx.all_reduce_stream,
            self.all_reduce_grads,
            self._partial_reduce_output,
        )
        self.comm_ctx.reduce_scatter_state = ReduceScatterState(
            reduce_scatter_input, reduce_scatter_event
        )
def finalize_backward(self):
if self._post_reduce_event is not None:
torch.cuda.current_stream().wait_event(self._post_reduce_event)
self._post_reduce_event = None
for fsdp_param in self.fsdp_params:
if fsdp_param.grad_offload_event is not None:
fsdp_param.grad_offload_event.synchronize()
fsdp_param.grad_offload_event = None
self._post_forward_indices.clear()
def _backward_prefetch(self) -> None:
    """
    Prefetch the unshard of the group that precedes this one in forward.

    Only acts in the pre-backward state. Pops this group's most recent
    post-forward index and, if a previous entry exists in the shared
    post-forward order, prefetches that group's unshard.
    """
    if self._training_state != TrainingState.PRE_BACKWARD:
        return
    if not self._post_forward_indices:
        # Can be cleared if running multiple `backward`s
        return
    curr_index = self._post_forward_indices.pop()
    target_index = curr_index - 1
    if target_index < 0:
        return
    # Prefetch naively using the reverse post-forward order, which may
    # have mistargeted prefetches if not all modules used in forward
    # are used in this backward
    target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
    self._prefetch_unshard(target_fsdp_param_group, "backward")
@staticmethod
def _prefetch_unshard(
target_fsdp_param_group: "FSDPParamGroup", pass_type: str
) -> None:
if pass_type == "backward":
training_state = TrainingState.PRE_BACKWARD
elif pass_type == "forward":
training_state = TrainingState.FORWARD
else:
raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
f"FSDP::{pass_type}_prefetch for {target_fqn}"
), target_fsdp_param_group.use_training_state(training_state):
target_fsdp_param_group.unshard()
# Utilities #
def _to_sharded(self):
    """Move every parameter to the sharded state unless already sharded."""
    if self.is_sharded:
        return
    for fsdp_param in self.fsdp_params:
        fsdp_param.to_sharded()
    self._sharded_state = ShardedState.SHARDED
def _to_sharded_post_forward(self):
    """Move every parameter to the sharded-post-forward state unless already there."""
    if self.is_sharded_post_forward:
        return
    for fsdp_param in self.fsdp_params:
        fsdp_param.to_sharded_post_forward()
    self._sharded_state = ShardedState.SHARDED_POST_FORWARD
def _to_unsharded(self):
    """Move every parameter to the unsharded state unless already unsharded."""
    if self.is_unsharded:
        return
    for fsdp_param in self.fsdp_params:
        fsdp_param.to_unsharded()
    self._sharded_state = ShardedState.UNSHARDED
@property
def is_sharded(self) -> bool:
    """Whether the group's sharded state is ``ShardedState.SHARDED``."""
    current_state = self._sharded_state
    return current_state == ShardedState.SHARDED
@property
def is_sharded_post_forward(self) -> bool:
    """Whether the group's sharded state is ``ShardedState.SHARDED_POST_FORWARD``."""
    current_state = self._sharded_state
    return current_state == ShardedState.SHARDED_POST_FORWARD
@property
def is_unsharded(self) -> bool:
    """Whether the group's sharded state is ``ShardedState.UNSHARDED``."""
    current_state = self._sharded_state
    return current_state == ShardedState.UNSHARDED
@contextlib.contextmanager
def use_training_state(self, training_state: TrainingState):
old_training_state = self._training_state
self._training_state = training_state
try:
yield
finally:
self._training_state = old_training_state
# Hook Registration #
def _register_post_backward_hook(
    self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Route the gradient-requiring input tensors through the post-backward hook.

    Every tensor in ``args``/``kwargs`` that requires gradient is passed
    through ``RegisterPostBackwardFunction`` and substituted back in place.
    The inputs are returned untouched when compiled autograd is enabled,
    when grad mode is disabled, or when no input tensor requires gradient.

    Returns:
        The ``(args, kwargs)`` pair, possibly with replaced tensors.
    """
    # Compile relies on `root_post_backward_callback` to call each
    # `FSDPParamGroup.post_backward`
    if ca.compiled_autograd_enabled or not torch.is_grad_enabled():
        return args, kwargs
    args_list, args_spec = tree_flatten(args)
    kwargs_list, kwargs_spec = tree_flatten(kwargs)
    num_args = len(args_list)
    args_kwargs_list = list(args_list) + list(kwargs_list)
    inp_tensor_indices: List[int] = [
        idx
        for idx, obj in enumerate(args_kwargs_list)
        if torch.is_tensor(obj) and obj.requires_grad
    ]
    if len(inp_tensor_indices) == 0:
        return args, kwargs  # no tensors that require gradients
    inp_tensors = [args_kwargs_list[idx] for idx in inp_tensor_indices]
    hooked_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
    for idx, hooked_tensor in zip(inp_tensor_indices, hooked_tensors):
        args_kwargs_list[idx] = hooked_tensor
    # Split the flat list back into its positional and keyword parts
    args = tree_unflatten(args_kwargs_list[:num_args], args_spec)
    kwargs = tree_unflatten(args_kwargs_list[num_args:], kwargs_spec)
    return args, kwargs
def _register_state_dict_hooks(self) -> None:
    """
    Register state dict pre-save and pre-load hooks that reshard the group.

    Hooks are registered once on every module that directly owns one of the
    group's parameters; a later call returns early when hooks already exist.
    """
    pre_save_handles = self._module_to_pre_save_state_dict_hook_handle
    pre_load_handles = self._module_to_pre_load_state_dict_hook_handle
    num_pre_save_hooks = len(pre_save_handles)
    num_pre_load_hooks = len(pre_load_handles)
    assert (
        num_pre_save_hooks == num_pre_load_hooks
    ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
    if num_pre_save_hooks > 0:
        return  # already registered
    modules_with_fsdp_params: Set[nn.Module] = set()
    for fsdp_param in self.fsdp_params:
        modules_with_fsdp_params.add(fsdp_param._module_info.module)

    def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
        self._to_sharded()

    for module in modules_with_fsdp_params:
        pre_save_handles[module] = module.register_state_dict_pre_hook(
            to_sharded_hook
        )
        pre_load_handles[module] = module._register_load_state_dict_pre_hook(
            to_sharded_hook
        )
# Properties #
@property
def _reshard_after_forward(self) -> bool:
return self.post_forward_mesh_info is not None
@property
def _use_post_forward_mesh(self) -> bool:
return (
self._reshard_after_forward
and self.mesh_info != self.post_forward_mesh_info
)
@property
def _is_hsdp(self) -> bool:
    """Whether the group's mesh info is an ``HSDPMeshInfo`` instance."""
    mesh_info = self.mesh_info
    return isinstance(mesh_info, HSDPMeshInfo)
@property
def _all_gather_process_group(self) -> dist.ProcessGroup:
    """
    Return the shard process group to all-gather over.

    Taken from the post-forward mesh info while the group is in the
    sharded-post-forward state and from the group's mesh info otherwise.
    """
    if self.is_sharded_post_forward:
        mesh_info = cast(FSDPMeshInfo, self.post_forward_mesh_info)
    else:
        mesh_info = self.mesh_info
    assert isinstance(mesh_info, FSDPMeshInfo)
    return mesh_info.shard_process_group
@property
def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
    """Return the shard process group of the group's mesh info."""
    mesh_info = self.mesh_info
    assert isinstance(self.mesh_info, FSDPMeshInfo)
    return mesh_info.shard_process_group
@property
def _all_reduce_process_group(self) -> dist.ProcessGroup:
    """Return the replicate process group of the group's HSDP mesh info."""
    mesh_info = self.mesh_info
    assert isinstance(self.mesh_info, HSDPMeshInfo)
    return mesh_info.replicate_process_group
def _with_fqn(self, label: str) -> str:
if self._module_fqn:
return f"{label} ({self._module_fqn})"
return label
def __repr__(self):
return f"FSDPParamGroup(fqn={self._module_fqn})"
def _get_param_module_infos(
params: List[nn.Parameter], modules: Tuple[nn.Module, ...]
) -> List[ParamModuleInfo]:
"""
Shared parameter: lin1.weight = lin2.weight
Shared module: mlp.lin1 = mlp.lin2
We do not remove duplicates when traversing both modules and parameters to
find shared modules' parameters and shared parameters within a module.
"""
params_set = set(params)
param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
for module in modules:
for _, submodule in module.named_modules(remove_duplicate=False):
for param_name, param in _named_parameters_with_duplicates(
submodule, recurse=False
):
if param in params_set:
if param not in param_to_module_info:
param_to_module_info[param] = ParamModuleInfo(
submodule, param_name
)
else:
param_to_module_info[param].shared_modules.append(submodule)
param_to_module_info[param].shared_param_names.append(
param_name
)
if len(param_to_module_info) != len(params):
raise AssertionError(f"Some parameters are not in the module tree of {module}")
return [param_to_module_info[param] for param in params]
class RegisterPostBackwardFunction(torch.autograd.Function):
    """Identity autograd function that runs a group's post-backward in backward."""

    @staticmethod
    def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
        """Save ``param_group`` on ``ctx`` and return ``inputs`` unchanged."""
        # All tensors in `inputs` should require gradient
        ctx.param_group = param_group
        return inputs

    @staticmethod
    def backward(ctx, *grads: torch.Tensor):
        """Call the saved group's ``post_backward`` and pass ``grads`` through."""
        ctx.param_group.post_backward()
        # `None` is the gradient slot for the non-tensor `param_group` input
        return (None, *grads)

View File

@ -0,0 +1,383 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
)
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._logging import warning_once
from torch.autograd import Variable
from torch.autograd.graph import _MultiHandle
from torch.distributed._composable_state import (
_get_module_state,
_insert_module_state,
_State,
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map
from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
if TYPE_CHECKING:
from ._fsdp_param import FSDPParam
logger = logging.getLogger("torch.distributed._composable.fsdp")
class FSDPStateContext:
    """This has state shared across FSDP states."""

    def __init__(self) -> None:
        # Every FSDP state found in the root state's module tree
        self.all_states: List[FSDPState] = []
        # The state that runs the once-per-forward logic for this iteration.
        # It can differ from the overall root chosen by lazy initialization
        # when only a submodule runs forward (e.g. encoder-only for eval)
        self.iter_forward_root: Optional[FSDPState] = None
        # Optional user-provided event recorded after the optimizer step that
        # the all-gather streams wait on in the root pre-forward
        self.post_optim_event: Optional[torch.cuda.Event] = None
        # Guards against queueing the final callback more than once per backward
        self.post_backward_final_callback_queued: bool = False
        # Whether this backward's final callback should finalize backward
        self.is_last_backward: bool = True
def disable_if_config_true(func):
@functools.wraps(func)
def fsdp_hook_wrapper(*args, **kwargs):
if torch._dynamo.config.skip_fsdp_hooks:
return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
else:
return func(*args, **kwargs)
return fsdp_hook_wrapper
class FSDPState(_State):
def __init__(self) -> None:
    """Set per-state defaults; module-specific setup is done separately in ``init``."""
    super().__init__()
    # Stays `None` until lazy init decides whether this state is the root
    self._is_root: Optional[bool] = None
    self._training_state: TrainingState = TrainingState.IDLE
    self._fsdp_param_group: Optional[FSDPParamGroup] = None
    self._state_ctx = FSDPStateContext()
    self._comm_ctx = FSDPCommContext()
    self._modules_to_run_forward: Set[nn.Module] = set()
    self._states_to_forward_prefetch: List[FSDPState] = []
    self._states_to_backward_prefetch: List[FSDPState] = []
# Define a separate init since `__init__` is called in the contract
def init(
    self,
    modules: Tuple[nn.Module, ...],
    device: torch.device,
    mp_policy: MixedPrecisionPolicy,
) -> None:
    """Attach this state to ``modules`` and register the forward hooks.

    A single module gets a plain forward pre-hook and forward hook. Several
    modules get group hooks, and both handle attributes then refer to the
    same handle object.
    """
    for module in modules:
        _insert_module_state(module, self)
    self._modules = modules
    self._device = device
    self._mp_policy = mp_policy
    if len(modules) != 1:
        group_handle = _register_group_forward_hooks(
            modules,
            self._pre_forward,
            self._post_forward,
            self._modules_to_run_forward,
        )
        self._pre_forward_hook_handle = group_handle
        self._post_forward_hook_handle = group_handle
        return
    sole_module = modules[0]
    self._pre_forward_hook_handle = sole_module.register_forward_pre_hook(
        self._pre_forward, prepend=True, with_kwargs=True
    )
    self._post_forward_hook_handle = sole_module.register_forward_hook(
        self._post_forward, prepend=False
    )
def _root_pre_forward(
    self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Run the once-per-forward root logic and return the (possibly moved) inputs.

    Lazy initialization always runs first. If another state is already this
    iteration's forward root, the inputs are returned unchanged. Otherwise
    this state becomes the iteration's forward root, the all-gather streams
    are made to wait on the optimizer step, and for a CUDA device the inputs
    are passed through ``_to_kwargs``.
    """
    self._lazy_init()
    if self._state_ctx.iter_forward_root is not None:
        return args, kwargs
    if not ca.compiled_autograd_enabled:
        logger.debug("FSDP::root_pre_forward")
    self._state_ctx.iter_forward_root = self
    with torch.profiler.record_function("FSDP::root_pre_forward"):
        # Wait for optimizer before implicitly prefetched all-gathers
        if (event := self._state_ctx.post_optim_event) is not None:
            self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
            self._comm_ctx.all_gather_stream.wait_event(event)
            # The user-provided event is consumed: it is only waited on once
            self._state_ctx.post_optim_event = None
        else:
            # NOTE(review): `torch.cuda.current_stream()` is called regardless
            # of `self._device.type` -- confirm non-CUDA devices never reach here
            current_stream = torch.cuda.current_stream()
            self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
            self._comm_ctx.all_gather_stream.wait_stream(current_stream)
        if self._device.type == "cuda":
            with torch.profiler.record_function("FSDP::inputs_to_device"):
                args_tuple, kwargs_tuple = _to_kwargs(
                    args, kwargs, self._device, False
                )  # same as DDP
            # `_to_kwargs` returns sequences; only the first entry of each is used
            args, kwargs = args_tuple[0], kwargs_tuple[0]
    return args, kwargs
def _lazy_init(self) -> None:
    """
    Lazy initialization represents when all modules' parallelisms have
    finalized (e.g. FSDP has been applied to all desired modules). This
    means that we can determine which state is the root, and we do so by
    the 1st state to run forward.

    Raises:
        RuntimeError: If this state manages more than one module (a root
            must be a single module), or if a non-root state in the module
            tree was already lazily initialized.
    """
    if self._is_root is not None:
        return  # no-op: already initialized
    # Validate before mutating: marking the state as root first would make
    # every later call a silent no-op even though initialization never ran
    if len(self._modules) > 1:
        raise RuntimeError(
            f"FSDP requires a single root module but got {self._modules}"
        )
    self._is_root = True
    root_module = self._modules[0]
    visited_states: Set[FSDPState] = set()
    for module_name, module in root_module.named_modules():
        if (state := _get_module_fsdp_state(module)) is None:
            continue
        if module is not root_module:
            # A state shared by several modules is visited more than once;
            # only its first visit may find `_is_root` unset
            if state not in visited_states and state._is_root is not None:
                raise RuntimeError(
                    "FSDP state has already been lazily initialized for "
                    f"{module_name}\nFSDP requires running forward through "
                    "the root module first"
                )
            state._is_root = False
        self._state_ctx.all_states.append(state)
        visited_states.add(state)
    if self._fsdp_param_group:
        # For the root, do not reshard after forward since for training,
        # the parameters would be freed and all-gathered immediately
        self._fsdp_param_group.post_forward_mesh_info = None
    self._init_fqns()
    self._init_shared_state()
    # Run parameter group lazy inits after initializing FQNs for improved
    # error messages
    for state in self._state_ctx.all_states:
        if state._fsdp_param_group:
            state._fsdp_param_group.lazy_init()
def _init_shared_state(self) -> None:
    """Lazily initialize this state's comm context and share both contexts.

    Every state in ``all_states`` (and its parameter group, if any) ends up
    pointing at this state's state and communication contexts.
    """
    shared_state_ctx = self._state_ctx
    shared_comm_ctx = self._comm_ctx
    shared_comm_ctx.lazy_init()
    for state in shared_state_ctx.all_states:
        state._state_ctx = shared_state_ctx
        state._comm_ctx = shared_comm_ctx
        param_group = state._fsdp_param_group
        if param_group:
            param_group.comm_ctx = shared_comm_ctx
def _init_fqns(self) -> None:
    """Sets module and parameter FQN attributes for debugging."""
    assert self._is_root
    root_module = self._modules[0]
    param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
    module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
    # Index every managed sharded parameter and module by object
    for state in self._state_ctx.all_states:
        param_group = state._fsdp_param_group
        if not param_group:
            continue
        for fsdp_param in param_group.fsdp_params:
            param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
        for managed_module in param_group.modules:
            module_to_fsdp_param_group[managed_module] = param_group
    for param_name, param in root_module.named_parameters():
        if param in param_to_fsdp_param:
            param_to_fsdp_param[param]._param_fqn = param_name
    for module_name, module in root_module.named_modules():
        if module not in module_to_fsdp_param_group:
            continue
        param_group = module_to_fsdp_param_group[module]
        existing_fqn = param_group._module_fqn
        if existing_fqn is None:
            param_group._module_fqn = module_name
            continue
        # A group that manages several modules gets a comma-separated FQN list
        assert isinstance(existing_fqn, str), f"{existing_fqn}"
        param_group._module_fqn = existing_fqn + f", {module_name}"
@disable_if_config_true
def _pre_forward(
    self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Forward pre-hook: run root logic, cast inputs, and unshard parameters.

    Returns the (possibly transformed) ``args`` and ``kwargs``. Does nothing
    when the state is already in ``PRE_BACKWARD``.
    """
    # When composing with module-hook-based activation checkpointing, the
    # pre-backward hook is responsible for the unshard
    if self._training_state == TrainingState.PRE_BACKWARD:
        return args, kwargs
    self._training_state = TrainingState.FORWARD
    args, kwargs = self._root_pre_forward(module, args, kwargs)
    # Input casting needs both the policy flag and a configured param dtype
    if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
        with torch.profiler.record_function("FSDP::cast_forward_inputs"):
            cast_fn = functools.partial(
                _cast_fp_tensor, self._mp_policy.param_dtype
            )
            args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
    if self._fsdp_param_group:
        args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
    # Explicit forward prefetching is issued after this state's own pre-forward
    for fsdp_state in self._states_to_forward_prefetch:
        if (target_param_group := fsdp_state._fsdp_param_group) is not None:
            FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
    return args, kwargs
@disable_if_config_true
def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
    """Forward hook: run post-forward logic and register pre-backward hooks.

    Returns the (possibly transformed) ``output``. Returns it unchanged when
    the state is already in ``PRE_BACKWARD``.
    """
    # When composing with module-hook-based activation checkpointing, the
    # post-backward hook is responsible for the reshard
    if self._training_state == TrainingState.PRE_BACKWARD:
        return output
    if self._fsdp_param_group:
        output = self._fsdp_param_group.post_forward(module, input, output)
    output = self._register_pre_backward_hook(output)
    self._training_state = TrainingState.IDLE
    # Only this iteration's forward root runs the end-of-forward cleanup
    if self._state_ctx.iter_forward_root is self:
        if all_gather_state := self._comm_ctx.all_gather_state:
            # Free the last all-gather result if needed; refer to
            # [Note: Overlapping all-gather copy-in and all-gather]
            self._comm_ctx.all_gather_copy_in_stream.wait_event(
                all_gather_state.event
            )
            self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
            self._comm_ctx.all_gather_state = None  # free the all-gather result
        # Reset so that the next forward elects its own root
        self._state_ctx.iter_forward_root = None
    if self._mp_policy.output_dtype is not None:
        with torch.profiler.record_function("FSDP::cast_forward_outputs"):
            output = tree_map(
                functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
                output,
            )
    return output
def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
    """Tensor hook run before backward; returns ``grad`` unchanged.

    Moves the state to ``PRE_BACKWARD``, makes sure the root final callback
    is queued, runs the parameter group's pre-backward, and issues any
    explicitly configured backward prefetches.
    """
    self._training_state = TrainingState.PRE_BACKWARD
    self._register_root_post_backward_final_callback()
    if self._fsdp_param_group:
        # Default prefetching is used only when no explicit targets were set
        use_default_prefetch = not self._states_to_backward_prefetch
        self._fsdp_param_group.pre_backward(use_default_prefetch)
    for prefetch_state in self._states_to_backward_prefetch:
        target_param_group = prefetch_state._fsdp_param_group
        if target_param_group is not None:
            FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
    return grad
def _root_post_backward_final_callback(self) -> None:
    """Callback queued on the autograd engine to finish backward for all states.

    Every state is reset to ``IDLE``. States are finalized, and the shared
    communication bookkeeping is cleared, only when this is the last backward.
    """
    if not ca.compiled_autograd_enabled:
        logger.debug("FSDP::root_post_backward")
    with torch.profiler.record_function("FSDP::root_post_backward_callback"):
        for state in self._state_ctx.all_states:
            if state._fsdp_param_group and state._fsdp_param_group.is_unsharded:
                # Run post-backward in case forward inputs did not require
                # gradient so the autograd backward did not run
                state._fsdp_param_group.post_backward()
            state._training_state = TrainingState.IDLE
            if state._fsdp_param_group:
                state._fsdp_param_group._training_state = TrainingState.IDLE
            if self._state_ctx.is_last_backward:
                state._finalize_backward()
        if self._state_ctx.is_last_backward:
            self._comm_ctx.post_forward_order.clear()
            if self._comm_ctx.reduce_scatter_state is not None:
                # The current stream waits on the recorded reduce-scatter
                # event before the state is dropped
                torch.cuda.current_stream().wait_event(
                    self._comm_ctx.reduce_scatter_state.event
                )
                self._comm_ctx.reduce_scatter_state = None
        # Allow the callback to be queued again for the next backward
        self._state_ctx.post_backward_final_callback_queued = False
def _finalize_backward(self) -> None:
    """Warn about grouped modules that skipped forward, then finalize the param group."""
    pending_modules = self._modules_to_run_forward
    if pending_modules:
        msg = (
            f"{len(pending_modules)} of the {len(self._modules)} "
            f"modules passed to fully_shard did not run forward before backward, "
            "which is error-prone since FSDP post-forward/pre-backward logic "
            "will not run for these modules. We recommend passing only modules "
            "that run forward together. Modules that did not run forward: "
            f"{list(pending_modules)}"
        )
        warning_once(logger, msg, stacklevel=2)
        # Emptied so that the next forward re-arms the group pre-hook
        pending_modules.clear()
    if self._fsdp_param_group:
        self._fsdp_param_group.finalize_backward()
def _register_pre_backward_hook(self, output: Any) -> Any:
if not torch.is_grad_enabled():
return output
flat_outputs, _ = tree_flatten(output)
for t in flat_outputs:
if torch.is_tensor(t) and t.requires_grad:
t.register_hook(self._pre_backward)
return output
def _register_root_post_backward_final_callback(self):
    """Queue the root post-backward final callback, at most once per backward."""
    state_ctx = self._state_ctx
    if not state_ctx.post_backward_final_callback_queued:
        # The flag is set before queueing so that re-entrant calls are no-ops
        state_ctx.post_backward_final_callback_queued = True
        Variable._execution_engine.queue_callback(
            self._root_post_backward_final_callback
        )
def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
    """Return the state registered on ``module`` if it is an ``FSDPState``, else ``None``."""
    candidate = _get_module_state(module)
    return candidate if isinstance(candidate, FSDPState) else None
def _register_group_forward_hooks(
    modules: Sequence[nn.Module],
    pre_hook: Callable,
    post_hook: Callable,
    modules_to_run: Set[nn.Module],
):
    """
    Registers group forward pre and post-hooks. The pre-hook runs upon the
    first module pre-forward, and the post-hook runs upon the last. If at least
    one module does not run forward, then the post-hook does not run.

    Args:
        modules: Modules that form the group.
        pre_hook: Hook to run once, before the first module's forward.
        post_hook: Hook to run once, after the last module's forward.
        modules_to_run: Caller-owned set that is mutated in place to track
            which modules of the group have not yet finished forward.

    Returns:
        A ``_MultiHandle`` wrapping all registered pre- and post-hook handles.
    """
    modules_set = set(modules)

    @disable_if_config_true
    @functools.wraps(pre_hook)
    def wrapped_pre_hook(*args: Any, **kwargs: Any):
        # An empty tracking set means no module of the group is mid-forward.
        # Implicitly returns `None` for every module after the first.
        if len(modules_to_run) == 0:  # first to run
            modules_to_run.update(modules_set)
            return pre_hook(*args, **kwargs)

    # NOTE(review): this decorator wraps the factory below, not the
    # `wrapped_post_hook` that it returns -- confirm that this is intended
    @disable_if_config_true
    def get_wrapped_post_hook(module: nn.Module):
        @functools.wraps(post_hook)
        def wrapped_post_hook(*args: Any, **kwargs: Any):
            modules_to_run.discard(module)
            # Only the last module to finish triggers the group post-hook
            if len(modules_to_run) == 0:
                return post_hook(*args, **kwargs)

        return wrapped_post_hook

    pre_handles = [
        module.register_forward_pre_hook(
            wrapped_pre_hook, prepend=True, with_kwargs=True
        )
        for module in modules
    ]
    post_handles = [
        module.register_forward_hook(
            get_wrapped_post_hook(module), prepend=False, always_call=True
        )
        for module in modules
    ]
    return _MultiHandle(tuple(pre_handles + post_handles))

View File

@ -0,0 +1,446 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union
import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _get_root_modules
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_init import (
_get_device_from_mesh,
_get_managed_modules,
_get_managed_states,
_get_post_forward_mesh_info,
_init_default_fully_shard_mesh,
_move_states_to_device,
)
from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState
# Maps an original module class to its dynamically created `FSDP<ClassName>`
# subclass so that `fully_shard` builds that subclass only once per class
cls_to_fsdp_cls: Dict[Type, Type] = {}
# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)  # type: ignore[operator]
def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
    """
    Shard module parameters across data parallel workers.

    This function applies fully sharded data parallelism (FSDP) or a variant to
    ``module``, a technique for memory savings at the cost of communication.
    Parameters are sharded across ``mesh``, and in turn, so are their gradients
    and optimizer states.

    The sharded parameters are all-gathered to construct the unsharded
    parameters for forward or backward computation. The unsharded parameters
    are freed after computation to save memory. The gradients are reduced
    across the mesh and divided by the mesh size for data parallelism. The
    optimizer step runs on the sharded parameters.

    Each call to ``fully_shard`` constructs one communication group that
    includes the parameters in ``module.parameters()`` except those already
    assigned to a group from a nested call. Each group's parameters and its
    gradients are communicated together in one collective, respectively.
    Constructing multiple groups across the model (e.g. "layer by layer")
    allows for peak memory savings and communication/computation overlap.

    Implementation-wise, the sharded parameters are represented as
    :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
    represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
    parameters, and a module forward hook frees them. Similar backward hooks
    gather parameters and later free parameters/reduce gradients.

    Args:
        module (Union[nn.Module, List[nn.Module]]): The module or modules to
            shard with FSDP and group together for communication.
        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
            sharding and device. If 1D, then parameters are fully sharded
            across the 1D mesh (FSDP). If 2D, then parameters are sharded
            across the 1st dim and replicated across the 0th dim (HSDP). The
            mesh's device type gives the device type used for communication;
            if a CUDA or CUDA-like device type, then we use the current device.
            If ``None``, a default mesh is initialized.
        reshard_after_forward (Union[bool, int]): This controls the parameter
            behavior after forward and can trade off memory and communication:

            - If ``True``, then this reshards parameters after forward and
              all-gathers in backward.
            - If ``False``, then this keeps the unsharded parameters in memory
              after forward and avoids the all-gather in backward.
            - If an ``int``, then this represents the world size to reshard to
              after forward. It should be a non-trivial divisor of the ``mesh``
              shard dim size (i.e. excluding 1 and the dim size itself). A choice
              may be the intra-node size (e.g. ``torch.cuda.device_count()``).
              This allows the all-gather in backward to be over a smaller world
              size at the cost of higher memory usage than setting to ``True``.
            - The root FSDP state has its value specially set to ``False`` as a
              heuristic since its parameters would typically be immediately
              all-gathered for backward.
            - After forward, the parameters registered to the module depend on
              this: The registered parameters are the sharded parameters if
              ``True``; unsharded parameters if ``False``; and the parameters
              resharded to the smaller mesh otherwise. To modify the parameters
              between forward and backward, the registered parameters must be the
              sharded parameters. For ``False`` or an ``int``, this can be done
              by manually resharding via :meth:`reshard`.
        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
            policy, which offers parameter/reduction mixed precision for this
            module. See :class:`MixedPrecisionPolicy` for details.
        offload_policy (OffloadPolicy): This controls the offloading policy,
            which offers parameter/gradient/optimizer state offloading. See
            :class:`OffloadPolicy` and its subclasses for details.

    Returns:
        The ``module`` argument as passed in. Each root module has its class
        swapped in place for a dynamically created ``FSDP<ClassName>``
        subclass that also inherits from :class:`FSDPModule`.

    Raises:
        ValueError: If ``module`` is an ``nn.ModuleList`` or ``nn.ModuleDict``,
            or if ``mesh`` is neither 1D nor 2D.
    """
    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
        raise ValueError(
            f"fully_shard does not support containers that do not implement forward: {module}"
        )
    mesh = mesh or _init_default_fully_shard_mesh()
    if mesh.ndim not in (1, 2):
        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
    elif mesh.ndim == 1:
        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
    else:
        # HSDP: mesh dim 0 replicates and mesh dim 1 shards
        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
    device = _get_device_from_mesh(mesh)
    post_forward_mesh_info = _get_post_forward_mesh_info(
        reshard_after_forward, mesh_info
    )

    # Keep the original argument since `module` is rebound by the loop below
    arg_module = module
    modules = (
        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
    )
    # One state object is shared by all modules of the group
    state = fully_shard.state(modules[0])
    state.init(modules, device, mp_policy)

    managed_modules = _get_managed_modules(modules)
    params, buffers = _get_managed_states(managed_modules)
    _move_states_to_device(params, buffers, device)
    # A parameter group is created only if there are parameters to manage
    if params:
        state._fsdp_param_group = FSDPParamGroup(
            params,
            modules,
            mesh_info,
            post_forward_mesh_info,
            device,
            mp_policy,
            offload_policy,
        )

    # For Dynamo
    for managed_module in managed_modules:
        managed_module._is_fsdp_managed_module = True  # type: ignore[assignment]
        managed_module._fsdp_use_orig_params = True  # type: ignore[assignment]

    # Place FSDP leftmost for highest priority in the method resolution order
    for module in modules:
        cls = module.__class__
        new_cls = cls_to_fsdp_cls.get(cls, None)
        if not new_cls:
            dct = {"__deepcopy__": unimplemented_deepcopy}
            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
            cls_to_fsdp_cls[cls] = new_cls
        module.__class__ = new_cls
    return arg_module
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    """Always raise ``AssertionError``; all arguments are ignored."""
    message = "FSDP does not support deepcopy. Please use state dict for serialization."
    raise AssertionError(message)
class FSDPModule:
def __new__(cls, *args, **kwargs):
"""
Override ``__new__`` to remove the FSDP class and directly construct
the original class for cases like indexing into a container module.
"""
# Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
# and index 1 is the `FSDPModule` class itself
orig_cls = cls.__mro__[2]
self = orig_cls.__new__(orig_cls, *args, **kwargs)
self.__init__(*args, **kwargs)
return self
def reshard(self) -> None:
    """
    Reshards the module's parameters, registering the sharded parameters
    to the module and freeing the unsharded parameters if needed. This
    method is *not* recursive.
    """
    param_group = self._get_fsdp_state()._fsdp_param_group
    if param_group:
        param_group.reshard()
def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
    """
    Unshards the module's parameters by allocating memory and all-gathering
    the parameters. This method is *not* recursive.

    Args:
        async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
            that has a :meth:`wait` method to wait on the unshard op. If
            ``False``, then returns ``None`` and waits on the handle inside
            this function.

    .. warning:: This method is experimental and subject to change.

    .. note:: If ``async_op=True``, then the user does not have to call
        :meth:`wait` on the returned handle if waiting on the unshard op
        in the module's pre-forward is tolerable. FSDP will wait on the
        pending unshard op in the pre-forward automatically.
    """
    param_group = self._get_fsdp_state()._fsdp_param_group
    if param_group is not None:
        param_group.lazy_init()
        param_group.unshard(async_op=async_op)
    # A handle is built even without a param group; its wait is then a no-op
    handle = UnshardHandle(param_group)
    if not async_op:
        handle.wait()
        return None
    return handle
def set_is_last_backward(self, is_last_backward: bool) -> None:
    """
    Sets whether the next backward is the last one, meaning that FSDP
    should wait for gradient reduction to finish and clear internal data
    structures used for explicit prefetching.
    """
    self._get_fsdp_state()._state_ctx.is_last_backward = is_last_backward
def set_requires_gradient_sync(
    self, requires_gradient_sync: bool, *, recurse: bool = True
) -> None:
    """
    Sets if the module should sync gradients. This can be used to implement
    gradient accumulation without communication. For HSDP, this controls
    both reduce-scatter and all-reduce together.

    Args:
        requires_gradient_sync (bool): Whether to reduce gradients for the
            module's parameters.
        recurse (bool): Whether to set for all submodules or just the
            passed-in module.
    """
    root = cast(nn.Module, self)
    targets = root.modules() if recurse else (root,)
    for submodule in targets:
        if not isinstance(submodule, FSDPModule):
            continue
        param_group = submodule._get_fsdp_state()._fsdp_param_group
        if param_group:
            param_group.reduce_grads = requires_gradient_sync
            param_group.all_reduce_grads = requires_gradient_sync
def set_requires_all_reduce(
    self, requires_all_reduce: bool, *, recurse: bool = True
) -> None:
    """
    Sets if the module should all-reduce gradients. This can be used to
    implement gradient accumulation with only reduce-scatter but not
    all-reduce for HSDP.

    Args:
        requires_all_reduce (bool): Value assigned to each parameter group's
            ``all_reduce_grads`` flag.
        recurse (bool): Whether to set for all submodules or just the
            passed-in module.
    """
    root = cast(nn.Module, self)
    targets = root.modules() if recurse else (root,)
    for submodule in targets:
        if not isinstance(submodule, FSDPModule):
            continue
        param_group = submodule._get_fsdp_state()._fsdp_param_group
        if param_group:
            param_group.all_reduce_grads = requires_all_reduce
def set_reshard_after_backward(
    self, reshard_after_backward: bool, *, recurse: bool = True
) -> None:
    """
    Sets if the module should reshard parameters after backward. This can
    be used during gradient accumulation to trade off higher memory for
    reduced communication.

    Args:
        reshard_after_backward (bool): Whether to reshard parameters after
            backward.
        recurse (bool): Whether to set for all submodules or just the
            passed-in module.
    """
    root = cast(nn.Module, self)
    targets = root.modules() if recurse else (root,)
    for submodule in targets:
        if not isinstance(submodule, FSDPModule):
            continue
        param_group = submodule._get_fsdp_state()._fsdp_param_group
        if param_group:
            param_group.reshard_after_backward = reshard_after_backward
def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
    """
    Sets the FSDP modules for which this FSDP module should explicitly
    prefetch all-gathers in forward. The prefetching runs after this
    module's all-gather copy-out.

    Passing a singleton list containing the next FSDP module gives the same
    all-gather overlap behavior as the default overlap behavior, except the
    prefetched all-gather is issued earlier from the CPU. Passing a list
    with at least length two is required for more aggressive overlap and
    will use more reserved memory.

    Args:
        modules (List[FSDPModule]): FSDP modules to prefetch.

    Raises:
        ValueError: If any entry of ``modules`` is not an ``FSDPModule``.
    """
    _assert_all_fsdp_modules(modules)
    prefetch_states = [target._get_fsdp_state() for target in modules]
    self._get_fsdp_state()._states_to_forward_prefetch = prefetch_states
def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
    """
    Sets the FSDP modules for which this FSDP module should explicitly
    prefetch all-gathers in backward. This overrides the default backward
    prefetching implementation that prefetches the next FSDP module based on
    the reverse post-forward order.

    Passing a singleton list containing the previous FSDP module gives the
    same all-gather overlap behavior as the default overlap behavior.
    Passing a list with at least length two is required for more aggressive
    overlap and will use more reserved memory.

    Args:
        modules (List[FSDPModule]): FSDP modules to prefetch.

    Raises:
        ValueError: If any entry of ``modules`` is not an ``FSDPModule``.
    """
    _assert_all_fsdp_modules(modules)
    prefetch_states = [target._get_fsdp_state() for target in modules]
    self._get_fsdp_state()._states_to_backward_prefetch = prefetch_states
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
    """
    Sets a post-optimizer-step event for the root FSDP module to wait the
    all-gather streams on.

    By default, the root FSDP module waits the all-gather streams on the
    current stream to ensure that the optimizer step has finished before
    all-gathering. However, this may introduce false dependencies if
    there is unrelated computation after the optimizer step. This API
    allows the user to provide their own event to wait on. After the root
    waits on the event, the event is discarded, so this API should be
    called with a new event each iteration.

    Args:
        event (torch.cuda.Event): Event recorded after the optimizer step
            to wait all-gather streams on.
    """
    state_ctx = self._get_fsdp_state()._state_ctx
    state_ctx.post_optim_event = event
def set_reduce_scatter_divide_factor(self, factor: float) -> None:
    """
    Sets a custom divide factor for the reduce-scatter. This becomes a
    custom reduce op using NCCL's PreMulSum, which allows multiplying by
    the factor before reduction.

    This is a no-op if the module has no FSDP parameter group.

    Args:
        factor (float): Custom divide factor.
    """
    param_group = self._get_fsdp_state()._fsdp_param_group
    if param_group is None:
        return
    # Dividing by `factor` is expressed as pre-multiplying by its reciprocal
    mul_factor = 1.0 / float(factor)
    param_group.reduce_scatter_reduce_op = torch.distributed._make_nccl_premul_sum(
        mul_factor
    )
def _get_fsdp_state(self) -> FSDPState:
    """Return this module's FSDP state.

    Raises:
        AssertionError: If no FSDP state is found on the module.
    """
    state = _get_module_fsdp_state(cast(nn.Module, self))
    if state is not None:
        return state
    raise AssertionError(f"No FSDP state found on {self}")
def _apply(self, *args: Any, **kwargs: Any) -> Any:
    """Reshard, delegate to the parent ``_apply``, then reset the sharded parameters.

    Returns whatever the parent ``_apply`` returned.
    """
    # Reshard to ensure that sharded parameters are registered
    self.reshard()
    ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
    state = self._get_fsdp_state()
    # Nothing to reset when this module manages no parameters
    if not (fsdp_param_group := state._fsdp_param_group):
        return ret
    # TODO: Remove this padding logic once DTensor pads the local tensor:
    # https://github.com/pytorch/pytorch/issues/113045
    with torch.no_grad():
        for fsdp_param in fsdp_param_group.fsdp_params:
            fsdp_param.reset_sharded_param()
    return ret
class UnshardHandle:
    """
    A handle to wait on the unshard op.

    Args:
        fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to
            unshard. This should be ``None`` iff the FSDP module does not
            manage any parameters, meaning the unshard is a no-op.
    """

    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
        self._fsdp_param_group = fsdp_param_group

    def wait(self):
        """
        Waits on the unshard op.

        This ensures that the current stream can use the unsharded parameters,
        which are now registered to the module. Calling this again after a
        completed wait is a no-op.
        """
        param_group = self._fsdp_param_group
        if param_group is None:
            return
        param_group.wait_for_unshard()
        # Drop the reference so that the handle does not keep the group alive
        self._fsdp_param_group = None
def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
    Registers a method on ``module`` to be a forward method for FSDP.

    FSDP only knows to run its pre-forward and post-forward hooks on the
    default :meth:`nn.Module.forward` method. This function patches a user
    specified method to run the pre/post-forward hooks before/after the method,
    respectively. If ``module`` is not an :class:`FSDPModule`, then this is a
    no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
        method_name (str): Name of the forward method.

    Raises:
        ValueError: If ``module`` has no attribute named ``method_name``.
    """
    if not isinstance(module, FSDPModule):
        # No-op so that callers can use this with or without FSDP applied
        return
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module)} does not have a method {method_name}")
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        state = self._get_fsdp_state()
        hooked_args, hooked_kwargs = state._pre_forward(self, args, kwargs)
        result = orig_method(*hooked_args, **hooked_kwargs)
        return state._post_forward(self, hooked_args, result)

    # Bind through the descriptor protocol so that the wrapper is an
    # instance method of `module`
    bound_method = wrapped_method.__get__(module, type(module))  # type:ignore[attr-defined]
    setattr(module, method_name, bound_method)
def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
for module in modules:
if not isinstance(module, FSDPModule):
raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")

View File

@ -0,0 +1,131 @@
# mypy: allow-untyped-decorators
from typing import Callable, Iterable, Optional, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hook,
_register_pre_forward_hook,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy
@contract(state_cls=_FSDPState)
@deprecated(
    "`torch.distributed._composable.fully_shard` is being deprecated. "
    "You can continue to use the wrapper based FSDP. "
    "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
    "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.",
    category=FutureWarning,
)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.

    The FSDP state attached to ``module`` is initialized step by step, the
    forward and state-dict hooks are registered, and the same ``module``
    object is returned. When ``policy`` is given, ``_auto_wrap`` is invoked
    with this function as the wrapping callable before the core state is set
    up.

    Raises:
        ValueError: if ``policy`` is neither ``None`` nor a ``_Policy``.
    """
    torch._C._log_api_usage_once("torch.distributed.fully_shard")

    # Only the `_Policy`-based auto wrap policy is accepted
    policy_is_acceptable = policy is None or isinstance(policy, _Policy)
    if not policy_is_acceptable:
        raise ValueError(f"Expects a `_Policy` but got {policy}")

    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
    state = _init_process_group_state(state, process_group, strategy, policy)

    if policy is not None:
        # Hybrid strategies forward both process groups held by the state
        # instead of the caller-provided `process_group`
        if strategy in HYBRID_SHARDING_STRATEGIES:
            forwarded_pg = (state.process_group, state._inter_node_pg)
        else:
            forwarded_pg = process_group
        root_kwargs = dict(
            process_group=forwarded_pg,
            strategy=strategy,
            mixed_precision=mixed_precision,
            cpu_offload=cpu_offload,
            ignored_modules=ignored_modules,
            device_id=device_id,
            param_init_fn=param_init_fn,
            sync_module_states=sync_module_states,
            forward_prefetch=forward_prefetch,
            ignored_states=ignored_states,
        )
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )

    effective_strategy = strategy or ShardingStrategy.FULL_SHARD
    state = _init_core_state(
        state,
        effective_strategy,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)

    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    _register_root_pre_forward_hook(state, module)  # prepend last

    # The passed-in module always gets the state, even when it manages no
    # parameters (then it has no handle and is absent from
    # `_fully_sharded_module_to_handle`)
    _insert_module_state(module, state)
    handle_by_module = state._fully_sharded_module_to_handle
    for submodule in module.modules():
        if submodule in handle_by_module and _get_module_state(submodule) is None:
            _insert_module_state(submodule, state)
    return module

View File

@ -0,0 +1,256 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel
from .contract import _get_registry, contract
# Prefix that denotes the root module when parameter names are collected:
# `_ReplicateState._collect_params` uses it as the default `prefix` and only
# appends a "." separator when the current prefix differs from it.
_ROOT_MODULE_PREFIX = ""
class _ReplicateState(_State):
def __init__(self) -> None:
super().__init__()
self.module: nn.Module = nn.ParameterList()
self.has_initialized: bool = False
self._param_list: nn.ParameterList = nn.ParameterList()
# TODO(@fegin): this variable is originally create for testing, we
# should remove this if possible.
self._orig_module = self.module
self._param_names: List[str] = []
self._no_sync: bool = False
self._init_args: Optional[Tuple[Any, ...]] = None
self._init_kwargs: Dict[str, Any] = {}
self._comm_hook_args: List[Any] = []
def _collect_params(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
prefix: str = _ROOT_MODULE_PREFIX,
) -> None:
# skip if managed by fully_sharded API
if _is_fully_sharded(module):
return
# if a module is ignored, all descendants of the module are ignored.
if module in ignored_modules:
return
recurse_prefix = (
f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
)
for n, p in module.named_parameters(recurse=False):
if p not in ignored_params:
self._param_list.append(p)
self._param_names.append(f"{recurse_prefix}{n}")
for name, child_module in module.named_children():
self._collect_params(
child_module,
ignored_modules,
ignored_params,
prefix=f"{recurse_prefix}{name}",
)
def lazy_init(self) -> None:
@torch._disable_dynamo(recursive=True)
def _lazy_init():
assert self._init_args is not None
self.init(*self._init_args, **self._init_kwargs)
self.register_comm_hook()
self._init_args = ()
self._init_kwargs = {}
_lazy_init()
def init(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
**kwargs,
) -> None:
if self.has_initialized:
return
self.has_initialized = True
device_mesh = kwargs.get("device_mesh", None)
self.module = module
ignored_params = {p for m in ignored_modules for p in m.parameters()}
for submodule in module.modules():
if _is_fully_sharded(submodule):
ignored_params.update(submodule.parameters())
from torch.distributed.tensor.parallel.ddp import _localize_dtensor
_localize_dtensor(module, ignored_params=ignored_params)
self._collect_params(module, ignored_modules, ignored_params)
if "device_id" in kwargs:
# replicate() supports a small usability enhancement where
# user can pass in device_id as a Union[int, torch.device] even for
# CPU devices so users don't have to change code for CPU/GPU runs.
# We derive the right device_ids to feed into DDP to support this.
if kwargs["device_id"] is not None:
device_id = kwargs["device_id"]
# Convert to device_ids that DDP expects.
if isinstance(device_id, torch.device) and device_id.type == "cpu":
# CPU modules receive device_ids None
kwargs["device_ids"] = None
else:
# GPU modules expect device_ids=[cuda_device]
kwargs["device_ids"] = [device_id]
else:
kwargs["device_ids"] = None
kwargs.pop("device_id")
self._ddp = DistributedDataParallel(self._param_list, **kwargs)
# Weakref to the DDP instance is currently only used for testing.
replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
def register_comm_hook(self) -> None:
for comm_args, comm_kwargs in self._comm_hook_args:
self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
self._comm_hook_args.clear()
def record_init_args(self, *args, **kwargs) -> None:
self._init_args = args
self._init_kwargs = kwargs
def forward_pre_hook(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
if self._init_args or self._init_kwargs:
self.lazy_init()
self._ddp.require_backward_grad_sync = not self._no_sync
return self._ddp._pre_forward(*args, **kwargs)
def forward_post_hook(
self,
module: nn.Module,
input: Tuple[torch.Tensor],
output: torch.Tensor,
) -> torch.Tensor:
return self._ddp._post_forward(output)
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    """Stand-in for ``__deepcopy__`` that always fails.

    All arguments are accepted and ignored.

    Raises:
        AssertionError: unconditionally.
    """
    reason = "DDP does not support deepcopy. Please use state dict for serialization."
    raise AssertionError(reason)
# Follow the same pattern as FSDP/fully_shard
class DDP:
    """Mixin placed first in the bases of the dynamically built ``DDP<...>`` class."""

    def __new__(cls, *args, **kwargs):
        """
        Construct an instance of the original class instead of the DDP class,
        for cases like indexing into a container module.
        """
        # `__mro__[0]` is the dynamically constructed `DDP<...>` class and
        # `__mro__[1]` is `DDP` itself, so the original class sits at index 2
        resolution_order = cls.__mro__
        orig_cls = resolution_order[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        module_state = replicate.state(self)
        module_state._no_sync = not requires_gradient_sync

    def register_comm_hook(self, *args, **kwargs) -> None:
        """Queue a comm hook; the arguments are stored on the replicate state."""
        pending_hooks = replicate.state(self)._comm_hook_args
        pending_hooks.append((args, kwargs))
@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Hooks are registered on ``module`` and the arguments are recorded on its
    replicate state; the class of ``module`` is then swapped for a dynamically
    created ``DDP<OriginalClass>`` subclass.

    Args:
        module (torch.nn.Module): module to replicate
        ignored_modules (Optional[Iterable[torch.nn.Module]]): modules to
            exclude from replication. Default: ``None`` (nothing ignored).
        **kwargs: recorded together with ``module`` and ``ignored_modules``
            for the deferred initialization. ``device_id`` (if present) must
            be an ``int`` or ``torch.device``; ``device_mesh`` (if present)
            enables the DTensor hooks described below.

    Returns:
        The same ``module`` object.

    Raises:
        RuntimeError: if ``device_id`` has the wrong type, or if ``module`` is
            already managed by ``fully_shard``.

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if _is_fully_sharded(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    # Always hand a real set to the state: `{}` would be an empty dict, which
    # contradicts the `Set[nn.Module]` annotation of `_ReplicateState.init`.
    if ignored_modules is None:
        ignored_modules = set()
    else:
        ignored_modules = set(ignored_modules)

    state = cast(_ReplicateState, replicate.state(module))
    module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
    device_mesh = kwargs.get("device_mesh", None)
    if device_mesh is not None:
        from torch.distributed.device_mesh import _mesh_resources

        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
        # if a root mesh is not the same as device_mesh,
        # meaning the device_mesh is sliced out from the root mesh.
        if root_mesh != device_mesh:
            # TODO: This is a temporary work around to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in DDP class as the module
            # replicate is going to pass is NOT the original module.
            from torch.distributed.tensor.parallel.ddp import (
                _localize_dtensor,
                _reconstruct_dtensor,
            )

            module.register_forward_pre_hook(_reconstruct_dtensor)
            module.register_forward_hook(_localize_dtensor)

    module.register_forward_hook(state.forward_post_hook)  # type: ignore[arg-type]

    state.record_init_args(module, ignored_modules, **kwargs)

    # Place DDP leftmost for highest priority in the method resolution order
    cls = module.__class__
    dct = {"__deepcopy__": unimplemented_deepcopy}
    new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
    module.__class__ = new_cls
    return module
def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Return ``True`` when the registry of ``module`` has a ``"fully_shard"`` entry.

    A module without a registry is reported as not fully sharded.
    """
    registry = _get_registry(module)
    return registry is not None and "fully_shard" in registry

View File

@ -0,0 +1,37 @@
from typing import cast, Dict, Optional
import torch.nn as nn
class _State:
    """Base type for the state objects that are associated with modules.

    The class defines no attributes or methods of its own.
    """

    pass
# Module-level registry mapping each module to its `_State`; written by
# `_insert_module_state` and read by `_get_module_state`.
_module_state_mapping: Dict[nn.Module, _State] = {}
def _insert_module_state(module: nn.Module, state: _State) -> None:
    """Associate ``state`` with ``module`` in the module-level mapping.

    A module may be registered only once; a second insertion for the same
    module trips the assertion.
    """
    already_registered = module in _module_state_mapping
    assert not already_registered, f"Inserting {module} more than once."
    _module_state_mapping[module] = state
def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Return the ``_State`` associated with ``module``.

    If ``module`` is itself a ``_State`` instance, it is cast to ``_State``
    and returned. Otherwise the state registered for it through
    ``_insert_module_state`` is returned, or ``None`` when no state has been
    registered.
    """
    if isinstance(module, _State):
        return cast(_State, module)
    # https://github.com/pytorch/pytorch/issues/107054
    return _module_state_mapping.get(module)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,117 @@
# mypy: allow-untyped-defs
from typing import List, Optional
import torch
import torch.distributed.distributed_c10d as c10d
"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""
def _broadcast(input, src, tag, ranks, group_size):
    """Legacy ``broadcast`` impl that forwards to the native ``_c10d_functional`` op.

    The group name is resolved from ``ranks`` and ``tag``; ``group_size`` is
    accepted but not used.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_broadcast = torch.ops._c10d_functional.broadcast
    return native_broadcast(input, src, resolved_group)
def _all_reduce(input, reduce_op, tag, ranks, group_size):
    """Legacy ``all_reduce`` impl that forwards to the native ``_c10d_functional`` op.

    The group name is resolved from ``ranks`` and ``tag``; ``group_size`` is
    accepted but not used.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_reduce = torch.ops._c10d_functional.all_reduce
    return native_all_reduce(input, reduce_op, resolved_group)
def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size):
    """Legacy ``all_reduce_coalesced`` impl that forwards to the native op.

    The group name is resolved from ``ranks`` and ``tag``; ``group_size`` is
    accepted but not used.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_reduce_coalesced = torch.ops._c10d_functional.all_reduce_coalesced
    return native_all_reduce_coalesced(inputs, reduce_op, resolved_group)
def _all_gather_into_tensor(input, tag, ranks, group_size):
    """Legacy ``all_gather_into_tensor`` impl that forwards to the native op.

    The group name is resolved from ``ranks`` and ``tag`` and passed along
    with ``group_size``.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_gather = torch.ops._c10d_functional.all_gather_into_tensor
    return native_all_gather(input, group_size, resolved_group)
def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size):
    """Legacy ``all_gather_into_tensor_coalesced`` impl that forwards to the native op.

    The group name is resolved from ``ranks`` and ``tag`` and passed along
    with ``group_size``.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_gather_coalesced = (
        torch.ops._c10d_functional.all_gather_into_tensor_coalesced
    )
    return native_all_gather_coalesced(input, group_size, resolved_group)
def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``reduce_scatter_tensor`` impl that forwards to the native op.

    The group name is resolved from ``ranks`` and ``tag`` and passed along
    with ``reduce_op`` and ``group_size``.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_reduce_scatter = torch.ops._c10d_functional.reduce_scatter_tensor
    return native_reduce_scatter(input, reduce_op, group_size, resolved_group)
def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``reduce_scatter_tensor_coalesced`` impl that forwards to the native op.

    The group name is resolved from ``ranks`` and ``tag`` and passed along
    with ``reduce_op`` and ``group_size``.
    """
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_reduce_scatter_coalesced = (
        torch.ops._c10d_functional.reduce_scatter_tensor_coalesced
    )
    return native_reduce_scatter_coalesced(
        inputs, reduce_op, group_size, resolved_group
    )
def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``all_to_all_single`` impl that forwards to the native op.

    When both split-size lists are ``None``, the first dimension of ``input``
    is floor-divided by ``group_size`` and that size is used for every rank,
    for both input and output. Passing only one of the two lists trips the
    assertion.
    """
    output_missing = output_split_sizes is None
    input_missing = input_split_sizes is None
    if output_missing or input_missing:
        assert output_missing and input_missing, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        even_split = [input.shape[0] // group_size] * group_size
        output_split_sizes = even_split
        input_split_sizes = even_split

    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_to_all = torch.ops._c10d_functional.all_to_all_single
    return native_all_to_all(
        input, output_split_sizes, input_split_sizes, resolved_group
    )
def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Legacy ``wait_tensor`` impl that forwards to the native ``_c10d_functional`` op."""
    native_wait = torch.ops._c10d_functional.wait_tensor
    return native_wait(tensor)

View File

@ -0,0 +1 @@
from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter

View File

@ -0,0 +1,32 @@
from typing import Sequence
import torch
from torch.distributed._shard.metadata import ShardMetadata
# Shared deprecation notice pointing users from ShardedTensor to DTensor.
DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
def narrow_tensor_by_index(
    tensor: torch.Tensor,
    offsets: Sequence[int],
    sizes: Sequence[int],
) -> torch.Tensor:
    """
    Narrow ``tensor`` dimension by dimension according to ``offsets`` and ``sizes``.

    Dimension ``i`` is narrowed to ``sizes[i]`` elements starting at
    ``offsets[i]``, but only when ``sizes[i]`` is smaller than the size of
    that dimension in ``tensor``; other dimensions are left as they are.
    """
    result = tensor
    for dim, (start, length) in enumerate(zip(offsets, sizes)):
        if length >= tensor.size(dim):
            # Nothing to cut along this dimension.
            continue
        result = result.narrow(dim, start, length)
    return result
def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Narrow ``tensor`` to the region given by ``metadata.shard_offsets`` and
    ``metadata.shard_sizes``.
    """
    offsets = metadata.shard_offsets
    sizes = metadata.shard_sizes
    return narrow_tensor_by_index(tensor, offsets, sizes)

View File

@ -0,0 +1,306 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor
from .sharder import Sharder
from .sharding_plan import ShardingPlan
from .sharding_spec import ChunkShardingSpec, ShardingSpec
def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Shard ``tensor`` according to ``sharding_spec``.

    Before sharding, ``src_rank`` and ``sharding_spec`` are gathered from all
    ranks of the process group and checked to be identical everywhere.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to shard; must be contiguous.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank passed on to
            ``sharding_spec.shard``. Default: 0.
        process_group (ProcessGroup, optional): The process group to work on.
            If None, the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    Raises:
        ValueError: if ``tensor`` is not contiguous, or if ``src_rank`` or
            ``sharding_spec`` differ between ranks.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError("input tensor is not a contiguous Tensor")

    pg = process_group
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None for _ in range(world_size)]
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        src_rank_mismatch = src_rank != entry[0]  # type: ignore[index]
        if src_rank_mismatch:
            raise ValueError(
                f"src_rank={src_rank} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with src_rank={entry[0]} on rank: {idx}"  # type: ignore[index]
            )
        spec_mismatch = sharding_spec != entry[1]  # type: ignore[index]
        if spec_mismatch:
            raise ValueError(
                f"sharding_spec={sharding_spec} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with sharding_spec={entry[1]} on rank: {idx}"  # type: ignore[index]
            )

    return sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg)
def shard_parameter(
    module: torch.nn.Module,
    param_name: str,
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
):
    """
    Shard the parameter ``param_name`` of ``module`` according to
    ``sharding_spec`` and re-register it on the module.

    After validation the attribute is sharded through ``_shard_tensor`` and
    ``module.param_name`` is replaced with an ``nn.Parameter`` wrapping the
    resulting :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank forwarded to
            ``_shard_tensor``. Default: 0.
        process_group (ProcessGroup, optional): The process group to work on.
            If None, the default process group will be used.

    Raises:
        AttributeError: if ``module`` has no attribute ``param_name``.
        ValueError: if the attribute is not a Tensor or is not contiguous.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Validate the target attribute before doing any distributed work.
    attribute_exists = hasattr(module, param_name)
    if not attribute_exists:
        raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`")

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
        )
    if not tensor.is_contiguous():
        raise ValueError(f"param: {param_name} is not a contiguous Tensor")

    sharded = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace param with ShardedTensor.
    replacement = nn.Parameter(sharded)
    module.register_parameter(param_name, replacement)
# Tracks the current process group in the load context manager.
# `load_with_process_group` sets it on entry and resets it to None on exit;
# `_get_current_process_group` reads it.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None
@contextmanager
def load_with_process_group(process_group):
    """
    Context manager that sets the process group used to load a ShardedTensor.

    The group is stored in a module-level variable for the duration of the
    ``with`` block and cleared again on exit, including on error. The context
    manager yields ``process_group``.

    Raises:
        RuntimeError: if a process group is already set, i.e. on nested use.
    """
    global _CURRENT_PROCESS_GROUP

    already_active = _CURRENT_PROCESS_GROUP is not None
    if already_active:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            "context manager"
        )

    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None
def _get_current_process_group():
    """
    Return the process group set by ``load_with_process_group``, falling back
    to the default group when none is set.
    """
    active_group = _CURRENT_PROCESS_GROUP
    if active_group is not None:
        return active_group
    return distributed_c10d._get_default_group()
def _reshard_output(
module: torch.nn.Module, resharding_spec: ShardingSpec
) -> torch.nn.Module:
"""
Hook a module with output resharding in the forward pass according
to the given ``resharding_spec``.
Args:
module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
The specification describing how the output of the module will be resharded.
Returns:
A :class:`torch.nn.Module` object with reshard API hooked.
"""
def hook_func(_module, _input, output):
if isinstance(output, ShardedTensor):
return output.reshard(resharding_spec)
return output
module.register_forward_hook(hook_func)
return module
def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
"""
Hook a module with local shards collection in the forward pass.
This API is typically used to convert a sharded representation back to data parallel
representation. In particular, it returns the local tensor for this Shard. If the
size along the sharding dimension for the local tensor is 1, this dimension is removed
from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
a local Tensor of size [16] across each rank and not [1, 16] across each rank.
Args:
module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the
local tensor value needs to be returned.
Returns:
A :class:`torch.nn.Module` object with collection API hooked.
"""
def hook_func(_module, _input, output):
if isinstance(output, ShardedTensor):
local_tensor = output.local_tensor()
# Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
sharding_spec = output._sharding_spec
if (
isinstance(sharding_spec, ChunkShardingSpec)
and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type]
):
local_tensor = local_tensor.squeeze(
output._sharding_spec.dim # type: ignore[attr-defined]
)
return local_tensor
module.register_forward_hook(hook_func)
return module
def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
    """
    Shards a given module according to the provided sharding `plan`. This method
    first shards all the parameters according to the given sharding `plan`. Then if
    `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it
    will tag the output of modules according to `output_plan` and convert the module's
    output back to data parallel according to `return_local_tensor`.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specifies, per name, the ShardingSpec to apply
            to a parameter or the Sharder to apply to a submodule.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Raises:
        RuntimeError: if a parameter entry lies under a path handled by a Sharder.
        KeyError: if a Sharder entry has an empty module path.
        TypeError: if a plan entry is neither a ``ShardingSpec`` nor a
            ``Sharder``, or an ``output_plan`` entry is not a ``ShardingSpec``.
    """
    # record Sharder paths for sanity check on the plan to ensure items in the plan
    # do not conflict with the submodule tree that the Sharder is working with
    sharder_paths = [
        name for name, spec in plan.plan.items() if isinstance(spec, Sharder)
    ]

    # shard the parameter according to the ShardingPlan
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(
                        f"ShardingPlan is in-valid, trying to shard a parameter: {name},"
                        f" but there's already a Sharder entry for module {sharder_path},"
                        f" parameter sharding should not conflict with the submodule tree"
                        f" that a Sharder is working with!"
                    )

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod, param_name, spec, src_rank=src_rank, process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # Swap this submodule with the sharded module. `setattr` is needed
            # because the attribute name is held in the variable `mod_name`;
            # plain attribute assignment would create an attribute literally
            # called "mod_name" and leave the original submodule in place.
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(
                f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'"
            )

    # reshard output if there's an entry in `reshard_output` for this module
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(
                    f"Only `ShardingSpec` is supported as output_plan for '{module_path}'"
                )

    # convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan, we will call `_collect_local_shard`
    # to collect the local tensor for output of modules
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)

View File

@ -0,0 +1,19 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed.checkpoint` package.
import sys
import warnings
import torch
from torch.distributed.checkpoint import * # noqa: F403
# Emit the deprecation notice with the "always" filter action in effect inside
# the `catch_warnings` block; the previous filter state is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter("always")
    warnings.warn(
        "`torch.distributed._shard.checkpoint` will be deprecated, "
        "use `torch.distributed.checkpoint` instead",
        DeprecationWarning,
        stacklevel=2,
    )

# Point the old module path at the new package in `sys.modules`.
sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint

View File

@ -0,0 +1,65 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch.utils import _pytree as pytree
def _basic_validation(op, args=(), kwargs=None):
"""
Common validation across all ops go in here.
"""
from torch.distributed._shard.sharded_tensor import ShardedTensor
if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
raise ValueError(f" No input for '{op.__name__}'!")
# Validate types
has_distributed_tensor = False
def is_distributed_tensor(e):
nonlocal has_distributed_tensor
if isinstance(e, ShardedTensor):
has_distributed_tensor = True
pytree.tree_map_(is_distributed_tensor, args)
pytree.tree_map_(is_distributed_tensor, kwargs)
if not has_distributed_tensor:
raise TypeError(
f"torch function '{op.__name__}', with args: {args} and "
f"kwargs: {kwargs} are called without any distributed tensor!"
)
# Validate all distributed tensors use the same PG.
cur_pg: Optional[torch.distributed.ProcessGroup] = None
def validate_pg(e):
nonlocal cur_pg
if isinstance(e, ShardedTensor):
if cur_pg is not None and e._process_group is not cur_pg:
raise RuntimeError(
"All distributed tensors should use the "
"same ProcessGroup if used together in an op."
)
cur_pg = e._process_group
pytree.tree_map_(validate_pg, args)
pytree.tree_map_(validate_pg, kwargs)
def _register_default_op(op, decorator):
@decorator(op)
def tensor_default_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the default tensor ops that
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
``torch.Tensor.dtype``. We simply lower to the real op call with
DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
to avoid recursions.
"""
if kwargs is None:
kwargs = {}
with torch._C.DisableTorchFunctionSubclass():
return op(*args, **kwargs)

View File

@ -0,0 +1,64 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import reduce
from typing import List, Optional, Union
from torch.distributed.remote_device import _remote_device
@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`):
            Specifies the placement of this shard. A string is converted
            to a ``_remote_device``.

    Raises:
        ValueError: if ``shard_offsets`` and ``shard_sizes`` differ in
            length, or if any offset or size is negative.
    """

    __slots__ = ["shard_offsets", "shard_sizes", "placement"]

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]

    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None,
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement
        if len(self.shard_offsets) != len(self.shard_sizes):
            # Report both lengths; the second value used to print the whole
            # `shard_sizes` list instead of its length.
            raise ValueError(
                f"shard_offsets and shard_sizes should have "
                f"the same number of elements, found {len(self.shard_offsets)} "
                f"and {len(self.shard_sizes)} respectively"
            )

        for offset, size in zip(self.shard_offsets, self.shard_sizes):
            if offset < 0:
                raise ValueError("shard_offsets should be >=0")
            if size < 0:
                raise ValueError("shard_sizes should be >= 0")

    def __hash__(self):
        """Combine the hashes of all offsets, all sizes and the placement."""

        def _hash_reduce(a, b):
            return (a << 8) + hash(b)

        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res

View File

@ -0,0 +1,41 @@
# mypy: allow-untyped-defs
import functools
from inspect import signature
from .common_op_utils import _basic_validation
"""
Common utilities to register ops on ShardedTensor
and PartialTensor.
"""
def _register_op(op, func, op_table):
"""
Performs basic validation and registers the provided op in the given
op_table.
"""
if len(signature(func).parameters) != 4:
raise TypeError(
f"Custom sharded op function expects signature: "
f"(types, args, kwargs, process_group), but received "
f"signature: {signature(func)}"
)
op_table[op] = func
def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` so that ``_basic_validation`` runs before every
    call, register the wrapper for ``op`` in ``op_table`` and return it.
    """

    @functools.wraps(wrapped_func)
    def validated(types, args, kwargs, process_group):
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    _register_op(op, validated, op_table)
    return validated

View File

@ -0,0 +1,52 @@
from typing import Iterator, Tuple, Union
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from .api import ShardedOptimizer
def named_params_with_sharded_tensor(
    module: nn.Module,
    prefix: str = "",
    recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ShardedTensor]]]:
    r"""Returns an iterator over module parameters (together with the
    ShardedTensor parameters), yielding both the name of the parameter
    as well as the parameter itself. This is typically passed to a
    :class:`torch.distributed._shard.sharded_optim.ShardedOptimizer`

    Args:
        module (nn.Module): module whose parameters are enumerated.
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        (str, Union[Tensor, ShardedTensor]): Tuple containing
            the name and parameter (or ShardedTensor parameter)

    Example::

        >>> # xdoctest: +SKIP
        >>> model = torch.nn.Linear(*linear_size)
        >>> shard_parameter(model, "weight", spec)
        >>> for name, param in named_params_with_sharded_tensor(model):
        >>>     if name in ['weight']:
        >>>         print(param.size())

    """
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]

    # ShardedTensors are plain attributes of a module (found through
    # ``vars(mod)``), so they are collected by hand; ``memo`` makes sure a
    # tensor shared between several modules is yielded only once.
    memo = set()
    for mod_prefix, mod in modules:
        for name, val in vars(mod).items():
            if isinstance(val, ShardedTensor) and val not in memo:
                memo.add(val)
                name = mod_prefix + ("." if mod_prefix else "") + name
                yield name, val

    # Regular nn.Parameters. ``prefix`` and ``recurse`` are forwarded so that
    # these names follow the same rules as the ShardedTensor names above.
    yield from module.named_parameters(prefix=prefix, recurse=recurse)

View File

@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Mapping, Union
import torch.optim as optim
from torch import Tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor
class ShardedOptimizer(optim.Optimizer):
    """Optimizer facade that runs a local optimizer over regular tensors and
    the local shards of ShardedTensors."""

    def __init__(
        self,
        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
        optimizer_class,
        *optimizer_args,
        **optimizer_kwargs,
    ):
        """
        ShardedOptimizer collects all tensors and local shard tensors of
        ShardedTensor, then use these tensors as ``params`` for optimizers

        Args:
            named_params (Mapping[str, Union[Tensor, ShardedTensor]]) : a mapping
                of parameters, where key is the parameter key, value is either
                Tensor or ShardedTensor parameter.
            optimizer_class (torch.optim.Optimizer): the Optimizer to use
                locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
            *optimizer_args: the arguments to initialize the optimizer.
            **optimizer_kwargs: the key-word arguments to initialize the optimizer.

        """
        # Flatten the mapping: a ShardedTensor contributes one tensor per local
        # shard, every other value is used as is.
        tensors: List[Tensor] = []
        for value in named_params.values():
            if isinstance(value, ShardedTensor):
                tensors.extend(shard.tensor for shard in value.local_shards())
            else:
                tensors.append(value)

        self.named_params = named_params
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        # Expose the wrapped optimizer's groups and state under the usual names.
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).

        """
        self._optim.zero_grad(set_to_none)

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        # Propagate the wrapped optimizer's result instead of dropping it, so
        # callers that pass a closure can read the loss back.
        return self._optim.step(closure)

    def state_dict(self) -> Dict[str, Any]:
        """
        Returned state and param_groups will contain parameter keys
        instead of parameter indices like torch.optim.Optimizer.
        This allows for advanced functionality like optimizer re-sharding to be implemented.

        Raises:
            NotImplementedError: always; not implemented yet.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ShardedOptimizer state.

        Args:
            state_dict (dict): ShardedOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            NotImplementedError: always; not implemented yet.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError(
            "ShardedOptimizer load_state_dict not implemented yet!"
        )

    def add_param_group(self, param_group: Any):
        r"""Add a new param group

        Raises:
            NotImplementedError: always; not implemented yet.
        """
        # TODO: implement add_param_group
        raise NotImplementedError(
            "ShardedOptimizer add_param_group not implemented yet!"
        )

View File

@ -0,0 +1,490 @@
# mypy: allow-untyped-defs
import functools
from typing import List, TYPE_CHECKING
import torch
from torch.distributed._shard.op_registry_utils import _decorator_func
from .api import (
_CUSTOM_SHARDED_OPS,
_SHARDED_OPS,
Shard,
ShardedTensor,
ShardedTensorBase,
ShardedTensorMetadata,
TensorProperties,
)
from .metadata import ShardMetadata # noqa: F401
if TYPE_CHECKING:
from torch.distributed._shard.sharding_spec import ShardingSpec
else:
ShardingSpec = "ShardingSpec"
def empty(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with uninitialized data.
    Needs to be called on all ranks in an SPMD fashion.

    Every argument is forwarded unchanged to the :class:`ShardedTensor`
    constructor.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    return ShardedTensor(sharding_spec, *size, **tensor_options)
def ones(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 1.
    Needs to be called on all ranks in an SPMD fashion.

    Implemented as a call to :func:`full` with ``fill_value=1``.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    # ``size`` is handed over as the collected tuple; ``full`` unpacks it.
    return full(sharding_spec, size, fill_value=1, **tensor_options)
def zeros(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 0.
    Needs to be called on all ranks in an SPMD fashion.

    Implemented as a call to :func:`full` with ``fill_value=0``.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    # ``size`` is handed over as the collected tuple; ``full`` unpacks it.
    return full(sharding_spec, size, fill_value=0, **tensor_options)
def full(
    sharding_spec: ShardingSpec,
    size,
    fill_value,
    *,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with fill_value.
    Needs to be called on all ranks in an SPMD fashion.

    The tensor is first constructed from ``sharding_spec`` and ``size`` and is
    then filled through ``torch.nn.init.constant_``.

    NOTE(review): earlier documentation stated that the dtype is inferred from
    ``fill_value``; this function only forwards ``dtype`` to the
    :class:`ShardedTensor` constructor -- confirm the actual behaviour there.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.
        fill_value (Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    filled = ShardedTensor(sharding_spec, *size, **tensor_options)
    torch.nn.init.constant_(filled, fill_value)  # type: ignore[arg-type]
    return filled
def rand(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution
    on the interval :math:`[0, 1)`. The shape of the tensor is defined by the
    variable argument `size`. Needs to be called on all ranks in an SPMD fashion.

    The tensor is constructed first and then initialised through
    ``torch.nn.init.uniform_(tensor, 0, 1)``.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    sampled = ShardedTensor(sharding_spec, *size, **tensor_options)
    torch.nn.init.uniform_(sampled, 0, 1)  # type: ignore[arg-type]
    return sampled
def randn(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a normal distribution
    with mean `0` and variance `1` (also called standard normal distribution). The shape
    of the tensor is defined by the variable argument `size`. Needs to be called on all ranks
    in an SPMD fashion.

    The tensor is constructed first and then initialised through
    ``torch.nn.init.normal_(tensor, 0, 1)``.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_options = dict(
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    sampled = ShardedTensor(sharding_spec, *size, **tensor_options)
    torch.nn.init.normal_(sampled, 0, 1)  # type: ignore[arg-type]
    return sampled
def init_from_local_shards(
    local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False
) -> ShardedTensor:
    """
    Creates an :class:`ShardedTensor` from local shards and the global metadata.
    Needs to be called on all ranks in an SPMD fashion.

    Thin wrapper that delegates to ``ShardedTensor._init_from_local_shards``.

    Args:
        local_shards (List[:class:`torch.distributed._shard.sharded_tensor.Shard`]): A list
            of shards that represent the local shards on this rank.
        global_size (int...): a list, tuple, or `torch.Size` of integers defining the
            shape of the overall sharded tensor.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object handle on this rank

    Examples:
        Suppose we want to construct a sharded tensor on two ranks, global size = (10, 5),
        where each shard has a (5, 5) local tensor. We can do it like below:

        on rank 0:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[0, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:0/cuda:0"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])

        on rank 1:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[5, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:1/cuda:1"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
    """
    group_options = dict(process_group=process_group, init_rrefs=init_rrefs)
    return ShardedTensor._init_from_local_shards(
        local_shards, *global_size, **group_options
    )
def state_dict_hook(module, destination, prefix, local_metadata):
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using
    :meth:`torch.nn.Module._register_state_dict_hook`.

    Walks ``module.named_modules()`` and copies every attribute that is a
    :class:`ShardedTensor` into ``destination`` under the key
    ``prefix + submodule_name + "." + attr_name`` (the ``"."`` is omitted when
    ``prefix + submodule_name`` is empty). ``local_metadata`` is not used.

    NOTE(review): the separator is added whenever ``prefix + submodule_name``
    is non-empty, so a ``prefix`` that already ends in ``"."`` combined with an
    empty submodule name would give a doubled dot -- confirm against callers.
    """
    for submodule_name, submodule in module.named_modules():
        mod_prefix = prefix + submodule_name
        separator = "." if mod_prefix else ""
        for attr_name, attr in submodule.__dict__.items():
            if not isinstance(attr, ShardedTensor):
                continue
            destination[mod_prefix + separator + attr_name] = attr
def pre_load_state_dict_hook(
    module,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """
    Pre-load state dict hook to add ShardedTensor to the module.

    For every attribute name already present in a submodule's ``__dict__``,
    the key ``prefix + submodule_name + "." + attr_name`` (``"."`` omitted when
    ``prefix + submodule_name`` is empty) is looked up in ``state_dict``; if
    the entry is a :class:`ShardedTensor` it is assigned onto the submodule.
    The remaining hook arguments are not used.
    """
    for submodule_name, submodule in module.named_modules():
        mod_prefix = prefix + submodule_name
        separator = "." if mod_prefix else ""
        for attr_name in submodule.__dict__.keys():
            key = mod_prefix + separator + attr_name
            if key in state_dict and isinstance(state_dict[key], ShardedTensor):
                setattr(submodule, attr_name, state_dict[key])
def custom_sharded_op_impl(func):
    """
    Provides a way for users to write their own custom sharded operator. This
    can be used to override existing ShardedTensor operators or write a new
    one not supported by ShardedTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> @custom_sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ...
        >>> input = torch.rand(10, 32)
        >>> weight = sharded_tensor.rand(32, 16)
        >>> bias = torch.rand(16)
        >>> # This will call 'my_custom_sharded_linear'
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
    There is an additional ``process_group`` parameter which is the
    process_group used for the ShardedTensor and can be used by
    implementations for communications within a sharded implementation.

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)

    Returns:
        A decorator: ``_decorator_func`` with ``op`` bound to ``func`` and
        ``op_table`` bound to ``_CUSTOM_SHARDED_OPS``.
    """
    register_custom_impl = functools.partial(
        _decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS
    )
    return register_custom_impl
def _sharded_op_impl(func):
    """
    Decorator to register a default sharded op.

    Args:
        func (Callable): Torch function for which a default sharded
            implementation is being registered.

    Returns:
        A decorator: ``_decorator_func`` with ``op`` bound to ``func`` and
        ``op_table`` bound to ``_SHARDED_OPS``.
    """
    register_default_impl = functools.partial(
        _decorator_func, op=func, op_table=_SHARDED_OPS
    )
    return register_default_impl
# Import all builtin sharded ops
from ._ops import * # noqa: F403

View File

@ -0,0 +1,13 @@
import torch.distributed._shard.sharded_tensor._ops.misc_ops
import torch.distributed._shard.sharded_tensor._ops.tensor_ops
# Import all ChunkShardingSpec ops
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import (
sharded_embedding,
)
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import (
sharded_embedding_bag,
)
from .binary_cmp import allclose, equal
from .init import constant_, kaiming_uniform_, normal_, uniform_

View File

@ -0,0 +1,113 @@
# mypy: allow-untyped-defs
import functools
from torch.distributed._shard.common_op_utils import _basic_validation
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
Shard,
ShardedTensor,
)
def _sharded_op_common(op, early_stop_func, extra_check):
    """
    Inject sharded tensor op registration with common logics executed before
    different behaviors are done on either local shards or a local tensor.

    The produced wrapper, in order: validates the call with
    ``_basic_validation``, takes ``args[0]`` as the sharded tensor, replaces a
    ``None`` ``kwargs`` with an empty dict, runs ``extra_check`` (if given),
    returns the sharded tensor unchanged when ``early_stop_func`` (if given)
    evaluates truthy, and otherwise calls the decorated function.

    Example::
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> op = torch.transpose
        >>> @_sharded_op_impl(op)
        >>> @_sharded_op_common(op, early_stop_func, extra_check)
        >>> def sharded_tensor_op(types, args, kwargs, process_group):
        >>>     ...
        >>>
        >>> st = sharded_tensor.rand(32, 16)
        >>> st.transpose(1, 2)
        >>> # This will call '_sharded_op_common'

    Args:
        op: The op to be registered and applied to all shards of the st.
        early_stop_func (Callable, optional): the func for early stop.
            Default: if ``None``, no early stop.
        extra_check (Callable, optional): the func for extra condition check.
            Default: if ``None``, no extra check.

    Return:
        func (Callable): decorator that wraps a sharded implementation
            (ex: for torch.transpose) with the common logic above.
    """

    def decorator_sharded_func(wrapped_func):
        @functools.wraps(wrapped_func)
        def common_logic(types, args=(), kwargs=None, pg=None):
            _basic_validation(op, args, kwargs)

            sharded = args[0]
            kwargs = {} if kwargs is None else kwargs
            if extra_check:
                extra_check(*args, **kwargs)
            # Short-circuit: hand back the input untouched when asked to.
            if early_stop_func and early_stop_func(*args, **kwargs):
                return sharded
            return wrapped_func(types, args, kwargs, pg)

        return common_logic

    return decorator_sharded_func
def _register_sharded_op_on_local_shards(
    op, early_stop_func=None, extra_check=None, customized_func=None
):
    """
    Handles ``__torch_function__`` dispatch for ops which are performed on
    each shard of the sharded tensor such as elementwise op like
    ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.

    For more complicated ops, a customized func can be used to generate
    the new shards and sharded tensor size.

    This function expects that the original ShardingSpec for the ShardedTensor
    is preserved irrespective of whether or not a customized function is used.

    Args:
        op: The op to be registered and applied to all shards of the st.
        early_stop_func (Callable, optional): the func for early stop.
            Default: if ``None``, no early stop.
        extra_check (Callable, optional): the func for extra condition check.
            Default: if ``None``, no extra check.
        customized_func (Callable, optional): the func for customized logic
            to generate new shards and sharded tensor size. It is called as
            ``customized_func(args, kwargs, pg)`` and must return
            ``(local_shards, sharded_tensor_metadata)``.
            Default: if ``None``, we simply lower to the real op call with
            all local shards of the st.

    Return:
        func (Callable): registered implementation for sharded op for
        ``__torch_function__`` dispatch.
    """

    @_sharded_op_impl(op)
    @_sharded_op_common(op, early_stop_func, extra_check)
    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
        st = args[0]
        st_metadata = st.metadata()
        if customized_func:
            local_shards_new, st_metadata = customized_func(args, kwargs, pg)
        else:
            # Apply ``op`` shard by shard: the local tensor takes the place of
            # the sharded tensor, all other arguments are passed through.
            extra_args = args[1:]
            local_shards_new = [
                Shard(
                    op(local_shard.tensor, *extra_args, **kwargs),
                    local_shard.metadata,
                )
                for local_shard in st.local_shards()
            ]
        return ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards_new,
            st_metadata,
            process_group=pg,
            init_rrefs=st._init_rrefs,
            sharding_spec=st.sharding_spec(),
        )

    # Hand the registered implementation back, as the docstring promises.
    return sharded_tensor_op_on_local_shards

View File

@ -0,0 +1,79 @@
# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor
def _communicate_result(result, pg):
    """
    Combine a per-rank boolean across ``pg`` and report whether all ranks agreed on True.

    Each rank contributes a one-element tensor (1 if ``result`` is truthy,
    otherwise 0). The tensors are combined with ``dist.all_reduce`` and the
    outcome is compared with ``dist.get_world_size(pg)``.

    Args:
        result: this rank's local outcome; only its truthiness is used.
        pg: process group to communicate over.

    Returns:
        bool: ``True`` iff the reduced value equals the group's world size.
    """
    # Use the current CUDA device when CUDA is usable; otherwise fall back to
    # the CPU instead of failing inside ``torch.cuda.current_device()``.
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")

    # Gather results from all ranks.
    if result:
        result_tensor = torch.ones(1, device=device)
    else:
        result_tensor = torch.zeros(1, device=device)

    dist.all_reduce(result_tensor, group=pg)

    expected_result = torch.ones(1, device=device) * dist.get_world_size(pg)
    return torch.equal(result_tensor, expected_result)
def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
    """
    Compare two ShardedTensors shard by shard with ``cmp_fun``.

    Local checks (metadata, number of local shards, per-shard metadata and
    ``cmp_fun`` on each pair of local tensors) are turned into a group-wide
    answer through ``_communicate_result``.

    Args:
        cmp_fun (Callable): comparison applied to each pair of local shard
            tensors (ex: ``torch.equal``); called with ``**kwargs``.
        types: unused here.
        args: exactly two ShardedTensors.
        kwargs (dict, optional): extra keyword arguments for ``cmp_fun``.
        process_group: unused here; the tensors' own process group is used.

    Returns:
        bool: ``False`` right away if the tensors' process groups differ;
        if the current rank is outside the group, whether that holds for
        both tensors; otherwise the result of ``_communicate_result``.

    Raises:
        ValueError: if ``args`` does not hold exactly two elements.
        TypeError: if either argument is not a ShardedTensor.
    """
    if len(args) != 2:
        raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}")

    st1, st2 = args[0], args[1]
    if not isinstance(st1, ShardedTensor) or not isinstance(st2, ShardedTensor):
        raise TypeError(
            f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor"
        )

    # Verify same PG
    pg1, pg2 = st1._process_group, st2._process_group
    if pg1 != pg2:
        return False

    if distributed_c10d._rank_not_in_group(pg1) or distributed_c10d._rank_not_in_group(
        pg2
    ):
        return distributed_c10d._rank_not_in_group(
            pg1
        ) == distributed_c10d._rank_not_in_group(pg2)

    # Verify metadata
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, pg1)

    # Verify number of local shards
    shards1 = st1.local_shards()
    shards2 = st2.local_shards()
    if len(shards1) != len(shards2):
        return _communicate_result(False, pg1)

    # kwargs must be dict-like
    if kwargs is None:
        kwargs = {}

    # Verify each local shard
    for shard1, shard2 in zip(shards1, shards2):
        if shard1.metadata != shard2.metadata:
            return _communicate_result(False, pg1)
        if not cmp_fun(shard1.tensor, shard2.tensor, **kwargs):
            return _communicate_result(False, pg1)

    return _communicate_result(True, pg1)
@_sharded_op_impl(torch.equal)
def equal(types, args, kwargs, process_group):
    """Sharded implementation of ``torch.equal``; delegates to ``binary_cmp``."""
    return binary_cmp(
        torch.equal, types, args, kwargs=kwargs, process_group=process_group
    )
@_sharded_op_impl(torch.allclose)
def allclose(types, args, kwargs, process_group):
    """Sharded implementation of ``torch.allclose``; delegates to ``binary_cmp``."""
    return binary_cmp(
        torch.allclose, types, args, kwargs=kwargs, process_group=process_group
    )

View File

@ -0,0 +1,151 @@
# mypy: allow-untyped-defs
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharded_tensor import _sharded_op_impl
def validate_param(param, param_name):
    """
    Ensure a required parameter was supplied.

    Args:
        param: the value to check.
        param_name (str): name used in the error message.

    Raises:
        ValueError: if ``param`` is ``None``.
    """
    if param is not None:
        return
    raise ValueError(f"param: {param_name} shouldn't be None!")
@_sharded_op_impl(torch.nn.init.uniform_)
def uniform_(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the Tensor in tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    All inputs are read from ``kwargs``; each one must be present and not
    ``None``.

    Args:
        tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        The sharded tensor passed in as ``tensor``, initialised in place.
    """
    validate_param(kwargs, "kwargs")
    # Fetch and validate one key at a time so failures surface in key order.
    params = {}
    for key in ("tensor", "a", "b"):
        params[key] = kwargs[key]
        validate_param(params[key], key)

    st = params["tensor"]
    for shard in st.local_shards():
        torch.nn.init.uniform_(shard.tensor, a=params["a"], b=params["b"])
    return st
@_sharded_op_impl(torch.nn.init.normal_)
def normal_(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the Tensors in tensor.local_shards with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    All inputs are read from ``kwargs``; each one must be present and not
    ``None``.

    Args:
        tensor: tensor sharded across devices
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        The sharded tensor passed in as ``tensor``, initialised in place.
    """
    validate_param(kwargs, "kwargs")
    # Fetch and validate one key at a time so failures surface in key order.
    params = {}
    for key in ("tensor", "mean", "std"):
        params[key] = kwargs[key]
        validate_param(params[key], key)

    st = params["tensor"]
    for shard in st.local_shards():
        torch.nn.init.normal_(shard.tensor, mean=params["mean"], std=params["std"])
    return st
@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the Tensors in tensor.local_shards with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    All inputs are read from ``kwargs``; each one must be present and not
    ``None``. ``torch.nn.init.kaiming_uniform_`` is applied to every local
    shard tensor individually.

    Args:
        tensor: tensor sharded across devices
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        The sharded tensor passed in as ``tensor``, initialised in place.
    """
    validate_param(kwargs, "kwargs")
    # Fetch and validate one key at a time so failures surface in key order.
    params = {}
    for key in ("tensor", "a", "mode", "nonlinearity"):
        params[key] = kwargs[key]
        validate_param(params[key], key)

    st = params["tensor"]
    for shard in st.local_shards():
        torch.nn.init.kaiming_uniform_(
            shard.tensor,
            a=params["a"],
            mode=params["mode"],
            nonlinearity=params["nonlinearity"],
        )
    return st
@_sharded_op_impl(torch.nn.init.constant_)
def constant_(types, args=(), kwargs=None, pg=None):
    r"""
    Fills the input ShardedTensor with the value ``val``.

    All inputs are read from ``kwargs``; each one must be present and not
    ``None``.

    Args:
        tensor: tensor sharded across devices
        val: the value to fill the tensor with

    Returns:
        The sharded tensor passed in as ``tensor``, filled in place.
    """
    validate_param(kwargs, "kwargs")
    # Fetch and validate one key at a time so failures surface in key order.
    params = {}
    for key in ("tensor", "val"):
        params[key] = kwargs[key]
        validate_param(params[key], key)

    st = params["tensor"]
    for shard in st.local_shards():
        torch.nn.init.constant_(shard.tensor, val=params["val"])
    return st
# Maps each ``torch.*_like`` creation op to the ``sharded_tensor`` factory
# function that is called when that op is dispatched for a ShardedTensor.
tensor_like_creation_op_map = {
torch.full_like: sharded_tensor.full,
torch.empty_like: sharded_tensor.empty,
torch.zeros_like: sharded_tensor.zeros,
torch.ones_like: sharded_tensor.ones,
torch.rand_like: sharded_tensor.rand,
torch.randn_like: sharded_tensor.randn,
}
# tensor ops that behave the same as the default tensor
def register_tensor_creation_op(op):
    """
    Register a sharded implementation for the ``*_like`` creation op ``op``.

    The implementation is registered through ``_sharded_op_impl(op)``; nothing
    is returned.
    """

    @_sharded_op_impl(op)
    def tensor_creation_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for tensor creation ops that
        takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
        ``torch.full_like``.
        """
        # Resolved per call, so an unsupported op fails when it is invoked.
        factory = tensor_like_creation_op_map.get(op, None)
        if factory is None:
            raise RuntimeError(f"Tensor creation {op} not supported!")

        call_kwargs = {} if kwargs is None else kwargs
        source = args[0]
        # The new tensor reuses the source's sharding spec and size; remaining
        # positional arguments (e.g. a fill value) are passed along.
        return factory(  # type: ignore[operator]
            source.sharding_spec(), source.size(), *args[1:], **call_kwargs
        )
# Register an implementation for each op listed in
# ``tensor_like_creation_op_map``.
register_tensor_creation_op(torch.full_like)
register_tensor_creation_op(torch.empty_like)
register_tensor_creation_op(torch.zeros_like)
register_tensor_creation_op(torch.ones_like)
register_tensor_creation_op(torch.rand_like)
register_tensor_creation_op(torch.randn_like)

View File

@ -0,0 +1,12 @@
# mypy: allow-untyped-defs
import torch
from torch.distributed._shard.sharded_tensor import _sharded_op_impl
# This is used by `_apply()` within module.py to set new
# parameters after applying a certain method. We should follow
# the future behavior of overwriting the existing tensor
# instead of doing an in-place change using `.data = `.
@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
    """Always answer ``False`` for ``torch._has_compatible_shallow_copy_type``."""
    return False

View File

@ -0,0 +1,218 @@
# mypy: allow-untyped-defs
import copy
import torch
from torch.distributed._shard.common_op_utils import _register_default_op
from torch.distributed._shard.sharded_tensor import (
_sharded_op_impl,
Shard,
ShardedTensor,
)
from ._common import _register_sharded_op_on_local_shards
# Ops below are registered through ``_register_default_op`` rather than given
# a custom ShardedTensor implementation (presumably they fall back to the
# default tensor behaviour -- see ``common_op_utils``; TODO confirm).

# Tensor properties access
_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, _sharded_op_impl)
_register_default_op(torch.Tensor.dim, _sharded_op_impl)
_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
_register_default_op(torch.Tensor.is_floating_point, _sharded_op_impl)
# __reduce_ex__ to dispatch to get_state/set_state
_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
# autograd related properties
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
# TODO: set grad with a ShardedTensor that consists of all local grads
_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl)  # type: ignore[union-attr]
_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl)  # type: ignore[union-attr]
_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
# The device property is ambiguous from a global perspective:
# ShardedTensor.device consists of multiple devices (possibly even across hosts).
# We choose to return the current device of the local tensor to represent
# the device property on each rank.
@_sharded_op_impl(torch.Tensor.device.__get__)
def tensor_device(types, args=(), kwargs=None, pg=None):
    """
    Return the device that represents the ShardedTensor on this rank.

    Resolution order:
      1. the device of the first local shard, when there is one;
      2. CPU, when a process group is given and its backend name is ``"gloo"``;
      3. otherwise the current CUDA device.

    Raises:
        TypeError: if the first positional argument is not a ShardedTensor.
    """
    st = args[0]
    # Validate types
    if not isinstance(st, ShardedTensor):
        raise TypeError("input needs to be a ShardedTensor")

    if st._local_shards:
        return st._local_shards[0].tensor.device
    if pg and pg._get_backend_name() == "gloo":
        return torch.device("cpu")
    return torch.device(torch.cuda.current_device())
@_sharded_op_impl(torch.Tensor.is_meta.__get__)  # type: ignore[attr-defined]
def st_is_meta(types, args=(), kwargs=None, pg=None):
    """Return ``is_meta`` of the ShardedTensor's local tensor (``args[0].local_tensor()``)."""
    return args[0].local_tensor().is_meta
def sharded_type_as_check(*args, **kwargs):
    """
    Validate the arguments of the sharded ``type_as`` op.

    The second positional argument (the tensor whose type is adopted) must be
    present and must be either a Tensor or a ShardedTensor.

    Args: same as ``torch.Tensor.type_as``.

    Return: None

    Raises:
        ValueError: if the target is missing or has an unsupported type.
    """
    if len(args) < 2:
        raise ValueError("Needs to give a tensor to cast type as!")

    target = args[1]
    if isinstance(target, torch.Tensor) or isinstance(target, ShardedTensor):
        return
    raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
def same_dtype(*args, **kwargs):
    """
    Early-stop predicate for ``type_as``: nothing to do when dtypes already match.

    Args: same as ``torch.Tensor.type_as``.

    Return (bool): Whether to return early or not.
    """
    source, target = args[0], args[1]
    return source.dtype == target.dtype
def sharded_type_as(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.

    Each local shard is converted with ``Tensor.type_as`` and a deep copy of
    the metadata is returned with its dtype updated to the target dtype.

    Args: same as ``torch.Tensor.type_as``.

    Return:
        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
    """
    st, target = args[0], args[1]
    # A ShardedTensor target is represented by its local tensor.
    if isinstance(target, ShardedTensor):
        target = target.local_tensor()

    converted_shards = [
        Shard(local_shard.tensor.type_as(target), local_shard.metadata)
        for local_shard in st.local_shards()
    ]
    # Copy so that the source tensor's metadata is left untouched.
    new_metadata = copy.deepcopy(st._metadata)
    new_metadata.tensor_properties.dtype = target.dtype
    return converted_shards, new_metadata


_register_sharded_op_on_local_shards(
    torch.Tensor.type_as,
    early_stop_func=same_dtype,
    extra_check=sharded_type_as_check,
    customized_func=sharded_type_as,
)
def sharded_deepcopy(args, kwargs, pg):
    """Return deep copies of the local shards and metadata of ``args[0]``."""
    # NOTE: the deepcopy magic method is implemented directly rather than
    # relying on the default tensor.__deepcopy__ plus clone(). The default
    # tensor deepcopy copies every attribute, but the process_group held by
    # a ShardedTensor cannot be deep copied.
    source_st = args[0]
    return (
        copy.deepcopy(source_st.local_shards()),
        copy.deepcopy(source_st.metadata()),
    )
# Use ``sharded_deepcopy`` as the customized implementation of
# ``torch.Tensor.__deepcopy__`` for ShardedTensor.
_register_sharded_op_on_local_shards(
    torch.Tensor.__deepcopy__,
    customized_func=sharded_deepcopy,
)
@_sharded_op_impl(torch.Tensor.copy_)
def sharded_inplace_copy(types, args, kwargs, pg):
    """
    In-place ``copy_`` between two ShardedTensors with identical local layout.

    ``args[0]`` is the destination and ``args[1]`` the source. All local
    shards are validated before any data is copied, so a failed validation
    leaves the destination untouched.

    Keyword Args:
        non_blocking (bool): forwarded to ``Tensor.copy_`` (default ``False``).

    Returns:
        The destination ShardedTensor.

    Raises:
        RuntimeError: if the two tensors do not hold the same number of local
            shards, or any pair of corresponding shards has different metadata.
    """
    # NOTE: inplace op don't need to rewrap
    kwargs = {} if kwargs is None else kwargs
    self_st = args[0]
    new_st = args[1]
    nonblocking = kwargs.get("non_blocking", False)

    dst_shards = self_st.local_shards()
    src_shards = new_st.local_shards()
    # zip() stops at the shorter sequence, so without the length check a
    # differing number of local shards would result in a silent partial copy.
    if len(dst_shards) != len(src_shards) or any(
        dst.metadata != src.metadata for dst, src in zip(dst_shards, src_shards)
    ):
        raise RuntimeError(
            "inplace copy can only happen between two ShardedTensor with same metadata!"
        )

    for dst, src in zip(dst_shards, src_shards):
        dst.tensor.copy_(src.tensor, nonblocking)
    return self_st
def sharded_clone(args, kwargs, pg):
    """
    Clone every local shard of ``args[0]`` together with its metadata.

    Only ``torch.preserve_format`` (or no ``memory_format`` at all) is accepted.

    Returns:
        A tuple of the cloned local shards and a deep copy of the
        ShardedTensor metadata.

    Raises:
        RuntimeError: if another ``memory_format`` is requested.
    """
    st = args[0]
    memory_format = kwargs.get("memory_format", None)
    if memory_format and memory_format != torch.preserve_format:
        raise RuntimeError("Only support torch.preserve_format for ShardedTensor!")

    clones = []
    for shard in st.local_shards():
        cloned_tensor = shard.tensor.clone(memory_format=memory_format)
        clones.append(Shard(cloned_tensor, metadata=copy.deepcopy(shard.metadata)))
    return clones, copy.deepcopy(st.metadata())


_register_sharded_op_on_local_shards(
    torch.Tensor.clone,
    customized_func=sharded_clone,
)
def sharded_detach(args, kwargs, pg):
    """
    Detach every local shard of ``args[0]``.

    Returns:
        A tuple of the detached local shards and a deep copy of the
        ShardedTensor metadata whose ``requires_grad`` is set to ``False``.
    """
    st = args[0]
    detached = []
    for shard in st.local_shards():
        detached_tensor = shard.tensor.detach()
        detached.append(Shard(detached_tensor, metadata=copy.deepcopy(shard.metadata)))

    detached_metadata = copy.deepcopy(st.metadata())
    detached_metadata.tensor_properties.requires_grad = False
    return detached, detached_metadata


_register_sharded_op_on_local_shards(
    torch.Tensor.detach,
    customized_func=sharded_detach,
)
@_sharded_op_impl(torch.Tensor.requires_grad_)
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
    """
    In-place ``requires_grad_`` for ShardedTensor.

    The flag is taken from the second positional argument when present,
    otherwise from the ``requires_grad`` keyword (default ``True``). It is
    applied to every local shard, to the wrapper tensor itself and to the
    ShardedTensor metadata.

    Returns:
        The same ShardedTensor.

    Raises:
        TypeError: if the first positional argument is not a ShardedTensor.
    """
    st = args[0]
    # Validate types
    if not isinstance(st, ShardedTensor):
        raise TypeError("input needs to be a ShardedTensor")

    if len(args) > 1:
        flag = args[1]
    else:
        flag = ({} if kwargs is None else kwargs).get("requires_grad", True)

    # Nothing to update when the flag already matches.
    if flag == st.requires_grad:
        return st

    for shard in st.local_shards():
        shard.tensor.requires_grad_(flag)

    # update the wrapper class property
    with torch._C.DisableTorchFunctionSubclass():
        st.requires_grad_(flag)
    # keep the metadata in sync as well
    st._metadata.tensor_properties.requires_grad = flag
    return st

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Tuple
from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers
__all__: List[str] = []
def _get_or_create_logger() -> logging.Logger:
    """
    Return the logger bound to the default logging handler.

    The logger is named after the handler's class, set to ``DEBUG`` level and
    detached from its ancestors (``propagate = False``); the handler receives
    a formatter that includes file, line, level, process and thread names.
    """
    handler, handler_name = _get_logging_handler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
        )
    )

    sharding_logger = logging.getLogger(f"sharding-spec-{handler_name}")
    sharding_logger.setLevel(logging.DEBUG)
    sharding_logger.propagate = False
    sharding_logger.addHandler(handler)
    return sharding_logger
def _get_logging_handler(
    destination: str = "default",
) -> Tuple[logging.Handler, str]:
    """
    Look up the handler registered in ``_log_handlers`` for ``destination``.

    Returns:
        A tuple of the handler and the name of its class.

    Raises:
        KeyError: if no handler is registered under ``destination``.
    """
    handler = _log_handlers[destination]
    return handler, type(handler).__name__

View File

@ -0,0 +1,17 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List
# Nothing in this module is part of the public API.
__all__: List[str] = []
# Logging handlers keyed by destination name. Only the "default"
# destination is defined here, and it maps to a ``logging.NullHandler``.
_log_handlers: Dict[str, logging.Handler] = {
    "default": logging.NullHandler(),
}

View File

@ -0,0 +1,95 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from enum import Enum
from typing import List
import torch
from torch.distributed._shard.metadata import ShardMetadata
class MEM_FORMAT_ENCODING(Enum):
    """Picklable stand-ins for the ``torch.memory_format`` values stored in ``TensorProperties``."""

    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2


@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`"""

    # Regular tensor fields
    # The default dtype is resolved when an instance is created (not when this
    # module is imported) so that ``torch.set_default_dtype`` is honoured.
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    layout: torch.layout = field(default=torch.strided)
    requires_grad: bool = False
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    pin_memory: bool = False

    def __getstate__(self):
        """
        Return the pickle state as a 5-tuple
        ``(dtype, layout, requires_grad, mem_format_encoding, pin_memory)``.

        Raises:
            RuntimeError: if ``memory_format`` is not one of the supported formats.
        """
        # Since torch.memory_format cannot be pickled!
        memory_format = self.memory_format
        if memory_format == torch.contiguous_format:
            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
        elif memory_format == torch.channels_last:
            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
        elif memory_format == torch.preserve_format:
            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        """
        Restore the fields from the tuple produced by ``__getstate__``.

        Raises:
            RuntimeError: if the memory-format encoding is unknown.
        """
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state

        # Translate the picklable encoding back to the torch.memory_format.
        if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
            memory_format = torch.contiguous_format
        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
            memory_format = torch.channels_last
        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
            memory_format = torch.preserve_format
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

        self.memory_format = memory_format

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """
        Build the properties describing ``tensor``.

        ``memory_format`` is always recorded as ``torch.contiguous_format``;
        it is not derived from the tensor.
        """
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )
@dataclass
class ShardedTensorMetadata:
    """
    Represents metadata for :class:`ShardedTensor`.

    Every field has a default, so ``ShardedTensorMetadata()`` describes a
    tensor with no shards, an empty size and default ``TensorProperties``.
    """

    # Metadata about each shard of the Tensor
    shards_metadata: List[ShardMetadata] = field(default_factory=list)
    # Size of each dim of the overall Tensor.
    size: torch.Size = field(default=torch.Size([]))
    # Properties (dtype, layout, requires_grad, ...) of the overall Tensor.
    tensor_properties: TensorProperties = field(default_factory=TensorProperties)

View File

@ -0,0 +1,246 @@
# mypy: allow-untyped-defs
import copy
from typing import List, Tuple
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch._C._distributed_c10d import ProcessGroup
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,
get_split_size,
)
from torch.distributed.nn.functional import all_to_all, all_to_all_single
from .shard import Shard
def get_idx_from_placements(placements, current_rank) -> int:
    """
    Find where ``current_rank`` sits in ``placements``.

    Args:
        placements: Placement of each shard of the Tensor, one entry per
            shard. Each entry is asked for its rank via ``.rank()``.
        current_rank (int): rank of the current device.

    Returns:
        The index of the first placement whose rank equals ``current_rank``.

    Raises:
        RuntimeError: if no placement belongs to ``current_rank``.
    """
    for position, device in enumerate(placements):  # type: ignore[attr-defined]
        if device.rank() != current_rank:  # type: ignore[union-attr]
            continue
        return position
    raise RuntimeError("current_rank not in the placement.")
def build_reshard_metadata(
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    world_size: int,
) -> Tuple[List[ShardMetadata], List[int]]:
    """
    Compute the offset and size of every shard described by ``sharding_spec``
    and wrap each result in a ``ShardMetadata``.

    Args:
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded.
        world_size (int): number of ranks.

    Returns:
        A Tuple of the followings:
            A List[`ShardMetadata`] indexed by rank, holding the offsets,
                lengths and device placement of that rank's shard.
            A List[int] which contains the ranks in the order of placement.
    """
    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[shard_dim], world_size)

    # One slot per rank; filled in placement order below.
    shards_metadata = [None] * world_size
    ranks = []
    running_offsets = [0] * len(st_size)
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        rank = placement.rank()
        ranks.append(rank)
        chunk = get_chunked_dim_size(st_size[shard_dim], split_size, idx)
        shard_sizes = list(st_size)
        shard_sizes[shard_dim] = chunk
        shards_metadata[rank] = ShardMetadata(  # type: ignore[call-overload]
            shard_offsets=list(running_offsets),
            shard_sizes=shard_sizes,
            placement=placement,
        )
        # The next shard starts where this one ends along the sharding dim.
        running_offsets[shard_dim] += chunk
    return shards_metadata, ranks  # type: ignore[return-value]
def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is the same as the
    original sharding dim. Logically we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor to
    the new shard directly based on the resharding spec.

    Args:
        local_shard (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the followings:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
    # Get input split size for all2all.
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    # Position of this rank in the original placements; the rank found at the
    # same position in the resharding placements receives the whole local shard.
    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
    # NOTE(review): the receive size is stored at index ``new_rank`` (the rank
    # this process sends to), not at the index of the rank it receives from --
    # confirm this is intended for placements that are not a pairwise swap.
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
    # The split sizes apply to dim 0, so move reshard_dim to the front.
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(
        gathered_input_size, device=local_shard.device, dtype=local_shard.dtype
    )
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    # Undo the transpose done before the collective.
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
def reshard_local_shard(
    local_tensor: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: shard_spec.ShardingSpec,
    resharding_spec: shard_spec.ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
    different from the original sharding dim, we need to do two steps logically:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending each rank the new
    shard based on the resharding spec.

    Args:
        local_tensor (Tensor): Local tensor stored in the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the followings:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    current_sharding_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(
        st_size, resharding_spec, world_size
    )
    # Compute expected size: the length along reshard_dim of the piece that
    # goes to each rank (shards_metadata is indexed by rank).
    input_split_sizes = []
    for metadata in shards_metadata:
        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
    # True when the placement order of the resharding spec is not ascending by rank.
    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
    if rearrange_input:
        # Need to re-arrange reshard_dim of local_tensor before all2all.
        # Gather, in rank order, the index range each rank's new shard covers.
        indices: List[int] = []
        for metadata in shards_metadata:
            offset_start_idx = metadata.shard_offsets[reshard_dim]
            split_size = metadata.shard_sizes[reshard_dim]
            indices += range(offset_start_idx, offset_start_idx + split_size)
        local_tensor = local_tensor.index_select(
            reshard_dim, torch.tensor(indices, device=local_tensor.device)
        )
    # Because reshard_dim != original shard_dim. We need to compute the
    # size of tensor from each rank.
    # Placeholder entries; every slot named by a placement is replaced below.
    output_tensor_list = [torch.tensor(1)] * world_size
    split_size = get_split_size(st_size[current_sharding_dim], world_size)
    rearrange_output_list = False
    indices = []
    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
        sharded_dim_size = get_chunked_dim_size(
            st_size[current_sharding_dim], split_size, idx
        )
        # Receive buffer for the piece coming from ``placement.rank()``: that
        # rank's extent along the original sharding dim and this rank's new
        # extent along reshard_dim.
        output_tensor_size = list(st_size)
        output_tensor_size[current_sharding_dim] = sharded_dim_size
        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
        output_tensor_list[
            placement.rank()
        ] = torch.empty(  # type: ignore[union-attr, index]
            output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
        )
        # Remember the placement order so the outputs can be re-ordered later.
        indices.append(placement.rank())  # type: ignore[union-attr, index, arg-type]
        if idx != placement.rank():  # type: ignore[union-attr]
            rearrange_output_list = True
    # Perform autograd enabled all2all.
    input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim)
    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple]
    output_tensor_list = all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )
    if rearrange_output_list:
        # Need to re-arrange original shard_dim of output_tensor_list.
        output_tensor_list = [output_tensor_list[idx] for idx in indices]  # type: ignore[call-overload]
    # Stitch the received pieces back together along the original sharding dim.
    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata

View File

@ -0,0 +1,63 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import List
import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed.remote_device import _remote_device
@dataclass
class Shard:
    """
    Container which holds the data for a shard as a Tensor and also
    the associated metadata for that shard.

    Args:
        tensor(torch.Tensor): Local tensor for the shard.
        metadata(:class:`torch.distributed._shard.sharded_tensor.ShardMetadata`):
            The metadata for the shard, including offsets, lengths and device placement.

    Raises:
        ValueError: if the tensor size differs from ``metadata.shard_sizes`` or
            the tensor device differs from the placement device.
    """

    __slots__ = ["tensor", "metadata"]
    tensor: torch.Tensor
    metadata: ShardMetadata

    def __post_init__(self):
        # verification between local tensor and metadata
        if list(self.tensor.size()) != self.metadata.shard_sizes:
            raise ValueError(
                "Shard tensor size does not match with metadata.shard_lengths! "
                f"Found shard tensor size: {list(self.tensor.size())}, "
                f"metadata.shard_lengths: {self.metadata.shard_sizes}, "
            )
        # A missing placement is accepted; the device is only checked when
        # the metadata carries one.
        placement_device = self.metadata.placement
        if (
            placement_device is not None
            and placement_device.device() != self.tensor.device
        ):
            raise ValueError(
                f"Local shard tensor device does not match with local Shard's placement! "
                f"Found local shard tensor device: {self.tensor.device}, "
                f"local shard metadata placement device: {placement_device.device()}"
            )

    @classmethod
    def from_tensor_and_offsets(
        cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
    ):
        """
        Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.

        The shard sizes are taken from ``tensor`` and the placement is built
        from ``rank`` and the tensor's device.

        Args:
            tensor(torch.Tensor): Local tensor for the shard.
            shard_offsets(List[int]): List of integers specifying the offset
                of the shard on each dimension.
            rank(int): Specify the rank for the shard.

        Returns:
            An instance of the class the method is called on.
        """
        shard_sizes = list(tensor.size())
        placement = _remote_device(f"rank:{rank}/{str(tensor.device)}")
        shard_meta = ShardMetadata(
            shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement
        )
        # Build through ``cls`` so that subclasses get an instance of their own type.
        return cls(tensor, shard_meta)

View File

@ -0,0 +1,267 @@
# mypy: allow-untyped-defs
import collections.abc
import copy
from typing import List, Optional, Sequence, TYPE_CHECKING
import torch
from torch.distributed import distributed_c10d as c10d, rpc
from torch.distributed._shard.sharding_spec._internals import (
check_tensor,
validate_non_overlapping_shards_metadata,
)
from .metadata import ShardedTensorMetadata, TensorProperties
from .shard import Shard
if TYPE_CHECKING:
from torch.distributed._shard.metadata import ShardMetadata
def _parse_and_validate_remote_device(pg, remote_device):
if remote_device is None:
raise ValueError("remote device is None")
worker_name = remote_device.worker_name()
rank = remote_device.rank()
device = remote_device.device()
# Validate rank, skip validation if rank is not part of process group.
if rank is not None and not c10d._rank_not_in_group(pg):
pg_global_ranks = c10d.get_process_group_ranks(pg)
if rank not in pg_global_ranks:
raise ValueError(
f"Global rank {rank} does not exist in input process group: {pg_global_ranks}"
)
if worker_name is not None:
if not rpc._is_current_rpc_agent_set():
raise RuntimeError(
f"RPC framework needs to be initialized for using worker names: {worker_name}"
)
workers = rpc._get_current_rpc_agent().get_worker_infos()
for worker in workers:
if worker.name == worker_name:
return worker.id, device
raise ValueError(f"Invalid worker name: {worker_name}")
return rank, device
def _validate_output_tensor_for_gather(
my_rank: int,
dst_rank: int,
size: torch.Size,
dst_tensor: Optional[torch.Tensor],
) -> None:
if dst_rank == my_rank:
if dst_tensor is None:
raise ValueError(
f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
)
if tuple(size) != (dst_tensor.size()):
raise ValueError(
f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())},"
f"but should be {tuple(size)}"
)
elif dst_tensor:
raise ValueError(
"Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks."
)
def _flatten_tensor_size(size) -> torch.Size:
"""
Checks if tensor size is valid, then flatten/return a torch.Size object.
"""
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
dims = list(*size)
else:
dims = list(size)
for dim in dims:
if not isinstance(dim, int):
raise TypeError(f"size has to be a sequence of ints, found: {dims}")
return torch.Size(dims)
def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
if is_local:
assert isinstance(ranks, int)
if expected != actual:
raise ValueError(
f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! "
f"Found one local shard tensor {prop_name}={expected}, "
f"the other local shard tensor {prop_name}={actual}."
)
else:
# compare failure check across ranks, ranks list should have two rank
assert len(ranks) == 2
if expected != actual:
raise ValueError(
f"ShardedTensor {prop_name} property does not match from different ranks! "
f"Found {prop_name}={expected} on rank:{ranks[0]}, "
f"and {prop_name}={actual} on rank:{ranks[1]}."
)
def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: c10d.ProcessGroup,
) -> ShardedTensorMetadata:
    """
    Validate the shards held by this rank and describe them in a
    ``ShardedTensorMetadata``.

    Every shard must be strided and contiguous, be placed on ``current_rank``
    and on the device its tensor lives on, match the size recorded in its
    metadata, and share pin_memory, dtype and requires_grad with the first shard.

    Args:
        local_shards: shards held by the current rank; must not be empty.
        global_size: size of the overall sharded tensor.
        current_rank: rank of this process.
        pg: process group the placements are validated against.

    Returns:
        Metadata covering the local shards only.

    Raises:
        ValueError: if any of the validations above fails.
    """
    assert len(local_shards) > 0, "must have local shards!"

    # The first shard is the reference all other shards are compared with.
    reference = local_shards[0].tensor
    first_shard_dtype = reference.dtype
    first_shard_layout = reference.layout
    first_shard_requires_grad = reference.requires_grad
    first_shard_is_pinned = reference.is_pinned()

    local_shard_metadatas: List[ShardMetadata] = []

    # 1). Validate local tensors and associated metadatas
    for shard in local_shards:
        shard_tensor = shard.tensor
        shard_meta = shard.metadata
        local_shard_metadatas.append(shard_meta)
        rank, local_device = _parse_and_validate_remote_device(
            pg, shard_meta.placement
        )

        if (
            shard_tensor.layout != torch.strided
            or shard_tensor.layout != first_shard_layout
        ):
            raise ValueError(
                f"Only torch.strided layout is currently supported, but found "
                f"{shard_tensor.layout} on rank:{current_rank}!"
            )

        if not shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported!"
            )

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match with the rank in its process group! "
                f"Found current rank in the process group: {current_rank}, "
                f"local ShardMetadata placement's rank: {rank}"
            )

        if shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match with local Shard's placement! "
                f"Found local shard tensor device: {shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(
            shard_meta.shard_sizes, list(shard_tensor.size()), "size", current_rank
        )
        _raise_if_mismatch(
            shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank
        )
        _raise_if_mismatch(shard_tensor.dtype, first_shard_dtype, "dtype", current_rank)
        _raise_if_mismatch(
            shard_tensor.requires_grad,
            first_shard_requires_grad,
            "requires_grad",
            current_rank,
        )

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then
    #    do all_gather to collect local_sharded_tensor_metadata from all ranks
    return ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=TensorProperties(
            dtype=first_shard_dtype,
            layout=first_shard_layout,
            requires_grad=first_shard_requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=first_shard_is_pinned,
        ),
    )
def build_global_metadata(
    gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]],
):
    """
    Merge the per-rank metadata into one global ``ShardedTensorMetadata``.

    Entries that are ``None`` are skipped. The first remaining entry is deep
    copied and becomes the reference; each later entry must agree with it on
    global size, dtype, requires_grad and pin_memory, and contributes its
    shards to the result.

    Args:
        gathered_metadatas: metadata indexed by rank; ``None`` entries are ignored.

    Returns:
        The merged metadata, after checking that shards do not overlap and are
        compatible with the global size.

    Raises:
        ValueError: if properties differ between ranks or every entry is ``None``.
    """
    merged = None
    merged_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if merged is None:
            merged = copy.deepcopy(rank_metadata)
            merged_rank = rank
            continue

        rank_pair = [merged_rank, rank]
        _raise_if_mismatch(
            merged.size, rank_metadata.size, "global_size", rank_pair, is_local=False
        )
        # don't need to check layout and memory format as we already checked in local shards validation stage
        for prop in ("dtype", "requires_grad", "pin_memory"):
            _raise_if_mismatch(
                getattr(merged.tensor_properties, prop),
                getattr(rank_metadata.tensor_properties, prop),
                prop,
                rank_pair,
                is_local=False,
            )
        # pass all validations, extend shards metadata
        merged.shards_metadata.extend(rank_metadata.shards_metadata)

    if merged is None:
        raise ValueError("ShardedTensor have no local shards on all ranks!")

    # check if shards_metadata have overlap shards
    validate_non_overlapping_shards_metadata(merged.shards_metadata)
    # check if the shards_metadata is compatible with global size of the sharded tensor.
    check_tensor(merged.shards_metadata, merged.size)
    return merged

View File

@ -0,0 +1,29 @@
import abc
import torch.nn as nn
class Sharder(abc.ABC):
    """
    This is an interface which allows users to create more advanced
    sharding strategies that cannot easily be composed with the
    `ShardingSpec`.

    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could
    take an object of the `Sharder` and call `shard` to shard the module,
    then replace the original module with the sharded module returned.
    """

    @abc.abstractmethod
    def shard(self, module: nn.Module) -> nn.Module:
        """
        Shard a module based on the implementation of this method, and
        return the sharded version of the module.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.
        Returns:
            A :class:`torch.nn.Module` object that represents a module
            that's already been sharded.
        """

View File

@ -0,0 +1 @@
from .api import ShardingPlan, ShardingPlanner

Some files were not shown because too many files have changed in this diff Show More