I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,4 @@
from .checkpoint_activation import checkpoint
from .contract import _get_registry, contract
from .fully_shard import fully_shard
from .replicate import replicate


@ -0,0 +1,126 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils.checkpoint import (
_checkpoint_without_reentrant_generator,
_DEFAULT_DETERMINISM_MODE,
)
from .contract import contract
@contextmanager
def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None):
r"""
Disable hooks installed by checkpoint to avoid unintentional recursion
during backward recomputation.
"""
with user_ctx if user_ctx else nullcontext():
orig_enable_hook = checkpoint.state(module).enable_hook
checkpoint.state(module).enable_hook = False
try:
yield
finally:
checkpoint.state(module).enable_hook = orig_enable_hook
@contract()
def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
r"""
This is a composable activation checkpointing API. Unlike functional
activation checkpointing APIs, this one does not require changing model
source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
this one does not modify model structure or fully-qualified names either.
Under the hood, it registers activation checkpointing logic as pre- and
post-forward hooks. Hence, this API can be easily applied to any model or
sub-modules in the model.
Args:
module (nn.Module): the target model or sub-module to apply activation
            checkpointing to.
Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)
>>>
>>> def forward(self, x):
>>> return self.l2(self.l1(x))
>>>
>>> model = MyModel()
>>> checkpoint(model.l1) # apply activation checkpointing only to l1
>>> model(torch.zeros(2, 10)).sum().backward()
"""
torch._C._log_api_usage_once("torch.distributed.checkpoint")
use_reentrant = kwargs.pop("use_reentrant", False)
if use_reentrant:
raise NotImplementedError(
"use_reentrant=True is not supported in composable checkpoint. "
"Please use torch.utils.checkpoint.checkpoint instead."
)
preserve_rng_state = kwargs.pop("preserve_rng_state", True)
user_context_fns = kwargs.pop("context_fn", None)
determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE)
debug = kwargs.pop("debug", False)
if kwargs:
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
def forward_pre_hook(
module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> None:
if checkpoint.state(module).enable_hook:
def context_fns():
if user_context_fns is not None:
ctx1, ctx2 = user_context_fns()
return ctx1, _no_hook(module, ctx2)
else:
return nullcontext(), _no_hook(module)
checkpoint.state(
module
)._ac_generator = _checkpoint_without_reentrant_generator(
module,
preserve_rng_state,
context_fns,
determinism_check,
debug,
*args,
**kwargs,
)
next(checkpoint.state(module)._ac_generator)
def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
if checkpoint.state(module).enable_hook:
try:
next(checkpoint.state(module)._ac_generator)
except StopIteration:
pass
else:
raise RuntimeError(
"Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
)
# Ensure that we no longer hold on to the generator. always_call=True helps ensure we
# clear this even in the case of exception in fwd pass.
checkpoint.state(module)._ac_generator = None
checkpoint.state(module).enable_hook = True
module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
module.register_forward_hook(forward_hook, prepend=True, always_call=True)
return module
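# --- Editor's illustrative usage sketch (not part of this commit) ---
# The docstring example above covers the default call; this sketch shows one of
# the extra keyword arguments the wrapper pops (`preserve_rng_state`, alongside
# `context_fn`, `determinism_check`, and `debug`). The toy model and helper
# name are made up.
def _example_composable_checkpoint() -> None:
    import torch
    import torch.nn as nn
    from torch.distributed._composable import checkpoint

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
    # Apply activation checkpointing to the first linear layer only; module
    # structure and FQNs are left unchanged.
    checkpoint(model[0], preserve_rng_state=False)
    model(torch.randn(2, 10)).sum().backward()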


@ -0,0 +1,224 @@
# mypy: allow-untyped-defs
import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.distributed.utils import _get_root_modules
def generate_state_key(string="__composable_api_state_key"):
return f"{string}_{str(uuid.uuid4())}"
STATE_KEY = generate_state_key()
REGISTRY_KEY = generate_state_key()
# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
pass
def contract(state_cls: Type[_State] = _State):
r"""
Decorate a function as a composable distributed API, where the first
argument of the function must be an :class:`nn.Module` instance or sequence
of :class:`nn.Module` instances.
The decorator verifies that the decorated function does not modify
fully-qualified names (FQNs) for parameters, buffers, or modules. The
decorated function can return different module instances than the input
modules; the FQN invariant will be enforced following the input order.
When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed on the decorated
function. Then you can retrieve and modify the state on a module by calling
``func.state(module)``.
Example::
>>> # xdoctest: +SKIP
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)
>>>
>>> def forward(self, x):
>>> return self.l2(self.l1(x))
>>>
>>> @contract()
>>> def my_feature(module: nn.Module) -> nn.Module:
>>> my_feature.state(module).some_state = "any value"
>>> return module
>>>
>>> model = MyModel()
>>> my_feature(model.l1)
>>> assert my_feature.state(model.l1).some_state == "any value"
>>> my_feature(model.l2)
>>> model(torch.randn(2, 10)).sum().backward()
"""
# wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
@wraps(state_cls)
def inner(func):
@wraps(func)
def wrapper(
module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs
) -> Optional[nn.Module]:
inp_module = module
if isinstance(module, nn.Module):
modules = [module]
else:
# If the user passes a sequence of modules, then we assume that
# we only need to insert the state object on the root modules
# (i.e. those without a parent) among the passed-in modules.
modules = _get_root_modules(list(module))
state = state_cls() # shared across all modules
registry_item = RegistryItem() # shared across all modules
# `func` is allowed to return different module instances than the
# input modules as long as FQNs are preserved following the input
# module order
all_orig_named_params: List[Dict[str, nn.Parameter]] = []
all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
all_orig_named_modules: List[Dict[str, nn.Module]] = []
for module in modules:
default_all_state: Dict[Callable, _State] = OrderedDict()
default_registry: Dict[str, RegistryItem] = OrderedDict()
all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
STATE_KEY, default_all_state
)
if not isinstance(all_state, dict):
raise AssertionError(
f"Distributed composable API states corrupted: {all_state}"
)
registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
REGISTRY_KEY, default_registry
)
if not isinstance(registry, dict):
raise AssertionError(
f"Distributed composable API registry corrupted: {registry}"
)
if func in all_state or func.__name__ in registry:
raise AssertionError(
"Each distinct composable distributed API can only be applied to a "
f"module once. {func.__name__} has already been applied to the "
f"following module:\n{module}"
)
all_state.setdefault(func, state)
registry.setdefault(func.__name__, registry_item)
all_orig_named_params.append(OrderedDict(module.named_parameters()))
all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
all_orig_named_modules.append(OrderedDict(module.named_modules()))
updated = func(inp_module, *args, **kwargs)
if updated is None:
updated = inp_module
if isinstance(updated, nn.Module):
updated_modules = [updated]
else:
updated_modules = _get_root_modules(list(inp_module))
all_new_named_params: List[Dict[str, nn.Parameter]] = []
all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
all_new_named_modules: List[Dict[str, nn.Module]] = []
for module in updated_modules:
all_new_named_params.append(OrderedDict(module.named_parameters()))
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
all_new_named_modules.append(OrderedDict(module.named_modules()))
num_orig_modules = len(all_orig_named_modules)
num_new_modules = len(all_new_named_modules)
if num_orig_modules != num_new_modules:
raise AssertionError(
f"{func.__name__} should return the same number of modules as input modules"
f"Inputs: {num_orig_modules} modules\n"
f"Outputs: {num_new_modules} modules"
)
def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
if orig_fqns == new_fqns:
return
orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
orig_only = orig_fqn_set - new_fqn_set
new_only = new_fqn_set - orig_fqn_set
if len(orig_only) or len(new_only):
raise RuntimeError(
f"{check_key}"
"Composable distributed API implementations cannot modify FQNs.\n"
f"FQNs only in original: {orig_only}\n"
f"FQNs only in new: {new_only}"
)
else:
raise RuntimeError(
f"{check_key}"
"Composable distributed API implementations cannot modify "
"the order of FQNs.\n"
f"Original FQNs: {orig_only}\n"
f"New FQNs: {new_only}"
)
for orig_named_params, new_named_params in zip(
all_orig_named_params, all_new_named_params
):
check_fqn(
list(orig_named_params.keys()),
list(new_named_params.keys()),
"Checking parameters: ",
)
for orig_named_buffers, new_named_buffers in zip(
all_orig_named_buffers, all_new_named_buffers
):
check_fqn(
list(orig_named_buffers.keys()),
list(new_named_buffers.keys()),
"Checking buffers: ",
)
for orig_named_modules, new_named_modules in zip(
all_orig_named_modules, all_new_named_modules
):
check_fqn(
list(orig_named_modules.keys()),
list(new_named_modules.keys()),
"Checking modules: ",
)
# TODO: verify that installed distributed paradigms are compatible with
# each other.
return updated
def get_state(module: nn.Module) -> Optional[_State]:
return module.__dict__.setdefault( # type: ignore[call-overload]
STATE_KEY,
{}, # TODO(@yhcharles): this is a temporary fix, need a better way
).get(
func
) # type: ignore[call-overload]
wrapper.state = get_state # type: ignore[attr-defined]
return wrapper
return inner
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
r"""
Get an ``OrderedDict`` of composable APIs that have been applied to the
``module``, indexed by the API name. If no API has been applied, then this
returns ``None``.
"""
return getattr(module, REGISTRY_KEY, None)
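# --- Editor's illustrative sketch (not part of this commit) ---
# How `_get_registry` pairs with `contract`: after a composable API has been
# applied, the registry maps the API's name to its `RegistryItem`. The
# `my_feature` API and the helper name below are hypothetical, mirroring the
# `contract` docstring above.
def _example_registry_lookup() -> None:
    import torch.nn as nn

    @contract()
    def my_feature(module: nn.Module) -> nn.Module:
        my_feature.state(module).some_state = "any value"
        return module

    lin = nn.Linear(4, 4)
    assert _get_registry(lin) is None  # no composable API applied yet
    my_feature(lin)
    registry = _get_registry(lin)
    assert registry is not None and "my_feature" in registry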


@ -0,0 +1,2 @@
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method


@ -0,0 +1,80 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass(frozen=True)
class MixedPrecisionPolicy:
"""
This configures FSDP's mixed precision. Unlike autocast, this applies mixed
precision at the module level, not op level, which means low-precision
activations are saved for backward and high-to-low-precision casts are
incurred only at module boundaries.
FSDP works well with module-level mixed precision since it keeps the
high-precision sharded parameters in memory anyway. In other words, FSDP
does not require any extra memory to keep a high-precision copy of the
parameters for the optimizer step.
Attributes:
param_dtype (Optional[torch.dtype]): This specifies the dtype for
the unsharded parameter and hence the dtype for forward/backward
computation and the parameter all-gather. If this is ``None``, then
the unsharded parameter uses the original dtype. The optimizer step
uses the sharded parameter in the original dtype. (Default:
``None``)
reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
gradient reduction (i.e. reduce-scatter or all-reduce). If this is
``None`` but ``param_dtype`` is not ``None``, then the reduction
uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If
            gradient reduction is also disabled via :meth:`set_requires_gradient_sync`,
then FSDP will accumulate gradients using ``reduce_dtype``.
(Default: ``None``)
output_dtype (Optional[torch.dtype]): This specifies the dtype for
casting floating-point forward outputs. This can be used to
help implement cases where different modules have different mixed
precision policies. (Default: ``None``)
cast_forward_inputs (bool): This specifies whether FSDP should cast the
forward's floating-point input tensors to ``param_dtype`` or not.
"""
param_dtype: Optional[torch.dtype] = None
reduce_dtype: Optional[torch.dtype] = None
output_dtype: Optional[torch.dtype] = None
cast_forward_inputs: bool = True
def __post_init__(self):
# Clamp `reduce_dtype` to `None` if no casting is required: since
# gradients are computed in `param_dtype`, if `reduce_dtype` matches,
# then we do not need extra casting
if self.param_dtype == self.reduce_dtype:
# Bypass the frozen dataclass checks
object.__setattr__(self, "reduce_dtype", None)
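# --- Editor's illustrative usage sketch (not part of this commit) ---
# A common bf16-compute / fp32-reduce policy passed to `fully_shard` via its
# `mp_policy` keyword. The toy model is made up, and the sketch assumes a
# torchrun launch where each rank has a CUDA device.
def _example_mixed_precision_policy() -> None:
    import torch
    import torch.nn as nn
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # all-gather and compute in bf16
        reduce_dtype=torch.float32,  # reduce-scatter gradients in fp32
    )
    model = nn.Linear(8, 8, device="cuda")
    fully_shard(model, mp_policy=policy)
    out = model(torch.randn(4, 8, device="cuda"))  # forward runs in bf16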
@dataclass
class OffloadPolicy:
"""This base class represents the policy of no offloading."""
@dataclass
class CPUOffloadPolicy(OffloadPolicy):
"""
This offload policy offloads parameters, gradients, and optimizer states to
CPU. Sharded parameters are copied host-to-device before all-gather. The
all-gathered parameters are freed according to ``reshard_after_forward``.
Sharded gradients are copied device-to-host in backward, and the optimizer
step runs on CPU with CPU optimizer states.
Attributes:
pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows H2D/D2H copies to run without blocking
            the CPU and, in turn, to overlap with compute, but pinned memory cannot be
used by other processes. Set this to ``False`` if you have
insufficient CPU memory. (Default: ``True``)
"""
pin_memory: bool = True
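# --- Editor's illustrative usage sketch (not part of this commit) ---
# Enabling CPU offloading through `fully_shard`'s `offload_policy` keyword.
# The toy model is made up; like the sketch above, this assumes a torchrun
# launch with a CUDA device per rank.
def _example_cpu_offload() -> None:
    import torch.nn as nn
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

    model = nn.Linear(8, 8)
    # Sharded parameters, gradients, and optimizer states live in (pinned) CPU
    # memory; pass pin_memory=False if host memory is scarce.
    fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))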


@ -0,0 +1,477 @@
# mypy: allow-untyped-decorators
from typing import cast, List, NamedTuple, Optional, Tuple, Union
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor
from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,
_to_dtype_if_needed,
)
from ._fsdp_param import FSDPParam, ShardedState
class AllGatherResult(NamedTuple):
all_gather_output: torch.Tensor
all_gather_event: Optional[torch.cuda.Event]
all_gather_work: Optional[dist.distributed_c10d.Work]
# For each parameter, the all-gather input dtype for each input
param_all_gather_input_dtypes: List[List[torch.dtype]]
# For each parameter, the all-gather input numel for each input
param_all_gather_input_numels: List[List[int]]
# 1D flattened version of `param_all_gather_input_numels` saved to avoid
# CPU overhead from recomputing
all_gather_input_split_sizes: List[int]
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
lib.define(
"""
all_gather_copy_in(
Tensor[] all_gather_inputs,
SymInt[] inp_split_sizes,
SymInt all_gather_input_numel,
SymInt world_size,
SymInt rank,
ScalarType dtype,
Device device
) -> (Tensor, Tensor)
"""
)
@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device="meta"
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
return all_gather_input, all_gather_output
@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device=device
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
with torch.no_grad():
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
return all_gather_input, all_gather_output
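# --- Editor's illustrative sketch (not part of this commit) ---
# What the copy-in above computes, reproduced with plain tensor ops so it can
# run without a process group: each rank packs its per-parameter all-gather
# inputs into its own slice of the flat all-gather output buffer. The sizes,
# world size, and rank below are made up.
def _example_all_gather_copy_in_layout() -> None:
    import torch

    world_size, rank = 4, 1
    inputs = [torch.randn(6), torch.randn(10)]  # two parameters' shards
    inp_split_sizes = [t.numel() for t in inputs]
    input_numel = sum(inp_split_sizes)
    output = torch.empty(input_numel * world_size)
    my_slice = output.narrow(0, input_numel * rank, input_numel)
    torch._foreach_copy_(list(torch.split(my_slice, inp_split_sizes)), inputs)
    assert torch.equal(my_slice[:6], inputs[0])
    assert torch.equal(my_slice[6:], inputs[1])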
lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)
@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
dim: int,
out: List[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)
@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)
@torch.no_grad()
def foreach_all_gather(
fsdp_params: List[FSDPParam],
group: dist.ProcessGroup,
async_op: bool,
all_gather_copy_in_stream: torch.cuda.Stream,
all_gather_stream: torch.cuda.Stream,
device: torch.device,
) -> Optional[AllGatherResult]:
world_size, rank = group.size(), group.rank()
with torch.cuda.stream(all_gather_copy_in_stream):
param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
(
param_all_gather_input_dtypes,
param_all_gather_input_numels,
dtype,
) = _get_all_gather_input_metadatas(param_all_gather_inputs)
if dtype == torch.uint8:
all_gather_inputs = [
t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
]
else:
all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts]
inp_split_sizes = [t.numel() for t in all_gather_inputs]
all_gather_input_numel = sum(inp_split_sizes)
all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
all_gather_inputs,
inp_split_sizes,
all_gather_input_numel,
world_size,
rank,
dtype,
device,
)
del param_all_gather_inputs
all_gather_stream.wait_stream(all_gather_copy_in_stream)
with torch.cuda.stream(all_gather_stream):
all_gather_work = dist.all_gather_into_tensor(
output_tensor=all_gather_output,
input_tensor=all_gather_input,
group=group,
async_op=async_op,
)
all_gather_event = all_gather_stream.record_event()
return AllGatherResult(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
inp_split_sizes,
)
@torch.no_grad()
def _get_param_all_gather_inputs(
fsdp_params: List[FSDPParam],
) -> List[List[torch.Tensor]]:
if ca.compiled_autograd_enabled:
return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]
# Intentionally try to run a fast-path that bypasses abstractions for the
# common FSDP case of bf16/fp32 mixed precision in order to use foreach
# copy for lower CPU overhead and more efficient copying in eager
def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
return (
fsdp_param.param_dtype is not None
and not fsdp_param.offload_to_cpu
and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
)
param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
foreach_copy_indices: List[int] = []
foreach_copy_inputs: List[torch.Tensor] = []
foreach_copy_input_numels: List[int] = []
# 1st pass: for foreach-copy parameters, get inputs and metadata for the
# foreach copy, and for the others, actually get their all-gather inputs
for i, fsdp_param in enumerate(fsdp_params):
if use_foreach_copy(fsdp_param):
foreach_copy_indices.append(i)
all_gather_input = (
fsdp_param._sharded_param_data
if fsdp_param.sharded_state == ShardedState.SHARDED
else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
)
foreach_copy_inputs.append(all_gather_input)
foreach_copy_input_numels.append(all_gather_input.numel())
else:
param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
# 2nd pass: use foreach copy to compute the remaining all-gather inputs
if foreach_copy_inputs:
fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
flat_foreach_copy_input = torch.empty(
(sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
)
splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
torch._foreach_copy_(splits, foreach_copy_inputs)
for i, split in zip(foreach_copy_indices, splits):
param_all_gather_inputs[i] = [split]
return param_all_gather_inputs
@torch.no_grad()
def foreach_all_gather_copy_out(
all_gather_result: AllGatherResult,
fsdp_params: List[FSDPParam],
group: dist.ProcessGroup,
) -> None:
(
all_gather_output,
all_gather_event,
all_gather_work,
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_input_split_sizes,
) = all_gather_result
if all_gather_event is not None: # sync op
torch.cuda.current_stream().wait_event(all_gather_event)
if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op
all_gather_work.wait()
world_size, device = group.size(), all_gather_output.device
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
):
if ca.compiled_autograd_enabled:
fsdp_param.init_all_gather_outputs(
all_gather_input_numels,
all_gather_input_dtypes,
world_size,
device,
# NOTE: Under compile, make sure we always recreate all_gather_outputs
# per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
force_recreate=True,
)
else:
fsdp_param.init_all_gather_outputs(
all_gather_input_numels, all_gather_input_dtypes, world_size, device
) # no-op after 1st call
fsdp_param.alloc_all_gather_outputs()
all_gather_output = all_gather_output.view(world_size, -1)
gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
if all_gather_output.dtype == torch.uint8:
out = [t.view(world_size, -1).view(torch.uint8) for t in gen]
else:
out = [t.view(world_size, -1) for t in gen]
torch.ops.fsdp.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=1, out=out
)
@torch.no_grad()
def foreach_reduce(
fsdp_params: List[FSDPParam],
unsharded_grads: List[torch.Tensor],
reduce_scatter_group: dist.ProcessGroup,
reduce_scatter_stream: torch.cuda.Stream,
orig_dtype: torch.dtype,
reduce_dtype: Optional[torch.dtype],
device: torch.device,
reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]],
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
all_reduce_stream: torch.cuda.Stream,
all_reduce_grads: bool,
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
"""
``unsharded_grads`` owns the references to the gradients computed by
autograd, so clearing the list frees the gradients.
"""
grad_dtypes = {grad.dtype for grad in unsharded_grads}
if len(grad_dtypes) != 1:
# Check this at runtime since it could be a real runtime error if e.g.
# fp8 weights do not produce the correct higher precision gradients
_raise_assert_with_print(
f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
)
grad_dtype = unsharded_grads[0].dtype
reduce_dtype = reduce_dtype or grad_dtype
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
reduce_scatter_group, all_reduce_group, reduce_dtype
)
world_size = reduce_scatter_group.size()
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
reduce_scatter_input = torch.empty(
(reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
)
foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
current_stream = torch.cuda.current_stream()
# Only after the copy-in finishes can we free the gradients
unsharded_grads.clear()
reduce_scatter_stream.wait_stream(current_stream)
with torch.cuda.stream(reduce_scatter_stream):
reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
_div_if_needed(reduce_scatter_input, predivide_factor)
if reduce_scatter_reduce_op is None:
if predivide_factor is None:
reduce_scatter_reduce_op = ReduceOp.AVG
else:
reduce_scatter_reduce_op = ReduceOp.SUM
dist.reduce_scatter_tensor(
output=reduce_output,
input=reduce_scatter_input,
group=reduce_scatter_group,
op=reduce_scatter_reduce_op,
)
reduce_scatter_event = reduce_scatter_stream.record_event()
post_reduce_stream = reduce_scatter_stream
if all_reduce_group is not None: # HSDP
# Accumulations must run in the reduce-scatter stream
if not all_reduce_grads:
if partial_reduce_output is not None:
partial_reduce_output += reduce_output
else:
partial_reduce_output = reduce_output
return (
reduce_scatter_input,
reduce_scatter_event,
post_reduce_stream.record_event(),
partial_reduce_output,
)
if partial_reduce_output is not None:
reduce_output += partial_reduce_output
post_reduce_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with torch.cuda.stream(all_reduce_stream):
dist.all_reduce(
reduce_output,
group=all_reduce_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
with torch.cuda.stream(post_reduce_stream):
_div_if_needed(reduce_output, postdivide_factor)
reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
# View out and accumulate sharded gradients
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, fsdp_param in zip(
padded_unsharded_sizes, fsdp_params
):
new_sharded_grad = torch.as_strided(
reduce_output,
size=fsdp_param.sharded_size,
stride=fsdp_param.contiguous_sharded_stride,
storage_offset=flat_grad_offset,
)
to_accumulate_grad = fsdp_param.sharded_param.grad is not None
if fsdp_param.offload_to_cpu:
# Only overlap the D2H copy (copying to pinned memory) if not
# accumulating gradients since the CPU add kernel depends on
# the copy result and we cannot run the add as a callback
non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
# Since the GPU sharded gradient is allocated in the RS stream,
# we can free it here by not keeping a ref without waiting for
# the D2H copy since future RS-stream ops run after the copy
new_sharded_grad = new_sharded_grad.to(
torch.device("cpu"), non_blocking=non_blocking
)
if non_blocking:
# Record an event on which to block the CPU thread to
# ensure that the D2H copy finishes before the optimizer
fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
if to_accumulate_grad:
assert isinstance(fsdp_param.sharded_param.grad, DTensor)
fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
else:
new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
new_sharded_grad
)
fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
if not ca.compiled_autograd_enabled:
for hook in (
getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
or {}
).values():
hook(fsdp_param.sharded_param)
padded_sharded_numel = padded_unsharded_size.numel() // world_size
flat_grad_offset += padded_sharded_numel
post_reduce_event = post_reduce_stream.record_event()
# The RS output is allocated in the RS stream and used in the default
# stream (for optimizer). To ensure its memory is not reused for later
# RSs, we do not need extra synchronization since the sharded parameters
# hold refs through the end of backward.
return reduce_scatter_input, reduce_scatter_event, post_reduce_event, None
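# --- Editor's illustrative sketch (not part of this commit) ---
# How the "view out" step above carves each parameter's sharded gradient out of
# the flat reduce output with `torch.as_strided`: contiguous strides plus a
# running storage offset (the real code advances by the padded per-parameter
# numel). The shapes below are made up.
def _example_view_out_sharded_grads() -> None:
    import math

    import torch

    flat = torch.arange(10.0)  # stand-in for the reduce-scatter output
    sharded_sizes = [(2, 3), (4,)]  # two parameters' sharded shapes
    offset, grads = 0, []
    for size in sharded_sizes:
        stride = tuple(math.prod(size[i + 1 :]) for i in range(len(size)))
        grads.append(torch.as_strided(flat, size, stride, storage_offset=offset))
        offset += math.prod(size)
    assert torch.equal(grads[0], flat[:6].view(2, 3))
    assert torch.equal(grads[1], flat[6:10])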
def foreach_reduce_scatter_copy_in(
unsharded_grads: List[torch.Tensor],
reduce_scatter_input: torch.Tensor,
world_size: int,
) -> None:
reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
torch.ops.fsdp.chunk_cat(
unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
)
def _get_all_gather_input_metadatas(
param_all_gather_inputs: List[List[torch.Tensor]],
) -> Tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
param_all_gather_input_dtypes: List[List[torch.dtype]] = []
param_all_gather_input_numels: List[List[int]] = []
all_gather_dtype = param_all_gather_inputs[0][0].dtype
for all_gather_inputs in param_all_gather_inputs:
input_dtypes: List[torch.dtype] = []
input_numels: List[int] = []
for all_gather_input in all_gather_inputs:
if all_gather_input.dtype != all_gather_dtype:
all_gather_dtype = torch.uint8
input_dtypes.append(all_gather_input.dtype)
input_numels.append(all_gather_input.numel())
param_all_gather_input_dtypes.append(input_dtypes)
param_all_gather_input_numels.append(input_numels)
return (
param_all_gather_input_dtypes,
param_all_gather_input_numels,
all_gather_dtype,
)
def _get_gradient_divide_factors(
reduce_scatter_group: dist.ProcessGroup,
all_reduce_group: Optional[dist.ProcessGroup],
reduce_dtype: torch.dtype,
) -> Union[Tuple[None, None], Tuple[float, float]]:
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
# use NCCL's built-in division to avoid separate div kernels
if reduce_dtype in (torch.float32, torch.bfloat16):
return None, None
data_parallel_size = reduce_scatter_group.size()
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
# overflow/underflow. For N data parallel workers, each worker computes
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
factor: int = 1
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
factor *= 2
factor = float(factor)
return (factor, data_parallel_size / factor)
def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
if div_factor is not None and div_factor > 1:
tensor.div_(div_factor)
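# --- Editor's worked example (not part of this commit) ---
# For fp16 reduction, `_get_gradient_divide_factors` keeps doubling the
# predivide factor while it divides the data-parallel size N and N / factor
# still exceeds it, so the overall division by N is split roughly as sqrt(N)
# before and sqrt(N) after the reduction:
#   N = 64 -> (predivide, postdivide) = (8.0, 8.0)
#   N = 48 -> (predivide, postdivide) = (8.0, 6.0)
#   N = 2  -> (predivide, postdivide) = (2.0, 1.0)
# For fp32/bf16 it returns (None, None) and the reduction instead uses
# ReduceOp.AVG, as in `foreach_reduce` above.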


@ -0,0 +1,152 @@
# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@dataclass
class DataParallelMeshInfo:
mesh: DeviceMesh
shard_mesh_dim: Optional[int] = None
replicate_mesh_dim: Optional[int] = None
def __post_init__(self):
if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
raise AssertionError(
"At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
)
@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
def __post_init__(self):
super().__post_init__()
if self.shard_mesh_dim is None:
raise AssertionError("Expects non-None shard_mesh_dim")
self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim)
self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim)
self.shard_mesh_rank: int = self.shard_process_group.rank()
@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
def __post_init__(self):
super().__post_init__()
if self.replicate_mesh_dim is None:
raise AssertionError("Expects non-None replicate_mesh_dim")
self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim)
self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim)
self.replicate_mesh_rank: int = self.replicate_process_group.rank()
@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
def __post_init__(self):
# Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
super().__post_init__()
class TrainingState(Enum):
"""Describes the training state of one FSDP state / parameter group."""
# Transition to forward starting pre-forward until post-forward
FORWARD = auto()
# Transition to pre-backward when unsharding in backward
PRE_BACKWARD = auto()
# Transition to post-backward when resharding and reducing gradients
POST_BACKWARD = auto()
# Idle before/after forward or before pre-backward/after post-backward
IDLE = auto()
def _raise_assert_with_print(*args: Any, **kwargs: Any):
print(f"[Rank {dist.get_rank()}] ", end="")
print(*args, **kwargs)
traceback.print_stack()
raise AssertionError(*args, **kwargs)
def _is_composable_with_fsdp(module: nn.Module) -> bool:
registry = _get_registry(module)
if registry is None:
return True
# Registry keys by function name
return "replicate" not in registry
def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:])
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> List[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))
return chunks
def _get_dim0_chunked_size(
chunk: torch.Tensor, unchunked_size: torch.Size
) -> torch.Size:
if chunk.numel() > 0:
return chunk.size()
# For 0 numel, we need to preserve trailing dims for DTensor APIs
return cast(torch.Size, torch.Size([0]) + unchunked_size[1:])
def _from_local_no_grad(
local_tensor: torch.Tensor,
sharding_spec: DTensorSpec,
) -> DTensor:
"""
This method is similar to ``DTensor.from_local()`` except that in eager mode
it avoids some CPU overhead by avoiding default args and not being differentiable.
"""
if not ca.compiled_autograd_enabled:
return DTensor(
# Use the local tensor directly instead of constructing a new tensor
# variable, e.g. with `view_as()`, since this is not differentiable
local_tensor,
sharding_spec,
requires_grad=local_tensor.requires_grad,
)
else:
return DTensor.from_local(
local_tensor,
sharding_spec.mesh,
sharding_spec.placements,
shape=sharding_spec.shape,
stride=sharding_spec.stride,
)
def _to_dtype_if_needed(
tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
return tensor
def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
if (
not isinstance(x, torch.Tensor)
or not torch.is_floating_point(x)
or x.dtype == dtype
):
return x
return x.to(dtype)


@ -0,0 +1,168 @@
import itertools
from typing import List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state
def _get_post_forward_mesh_info(
reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
shard_mesh_size = mesh_info.shard_mesh_size
if not isinstance(reshard_after_forward, (bool, int)):
raise ValueError(
"reshard_after_forward should be a bool or an int representing the "
f"group size to reshard to, not {reshard_after_forward}"
)
# NOTE: `isinstance(False, int)` returns `True`.
if not isinstance(reshard_after_forward, bool) and isinstance(
reshard_after_forward, int
):
if (
reshard_after_forward < 1
or reshard_after_forward > shard_mesh_size
or shard_mesh_size % reshard_after_forward != 0
):
raise ValueError(
"If passing reshard_after_forward as an int, it should be a "
f"factor of {shard_mesh_size}, not {reshard_after_forward}"
)
elif reshard_after_forward == 1:
reshard_after_forward = False
elif reshard_after_forward == shard_mesh_size:
reshard_after_forward = True
post_forward_mesh_info = None
if reshard_after_forward is True:
post_forward_mesh_info = mesh_info
elif reshard_after_forward is not False: # int case
# For HSDP, we can flatten the two replicate dims into the 0th dim
post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
post_forward_mesh = DeviceMesh(
mesh_info.mesh.device_type, post_forward_mesh_tensor
)
post_forward_mesh_info = HSDPMeshInfo(
post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
)
return post_forward_mesh_info
def _init_default_fully_shard_mesh() -> DeviceMesh:
"""Default to global CUDA mesh if possible else global CPU mesh."""
if not dist.distributed_c10d.is_initialized():
dist.distributed_c10d.init_process_group()
default_pg = dist.distributed_c10d._get_default_group()
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, mesh_shape=(default_pg.size(),))
return mesh
def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
if mesh.device_type == "cpu":
return torch.device("cpu")
device_handle = _get_device_handle(mesh.device_type)
return torch.device(mesh.device_type, device_handle.current_device())
def _get_managed_modules(root_modules: Tuple[nn.Module, ...]) -> List[nn.Module]:
modules: List[nn.Module] = []
root_modules_set = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
visited_modules: Set[nn.Module] = set()
def dfs(module: nn.Module) -> None:
"""
Runs a DFS to collect managed modules, not recursing into modules with
a non-composable API or ``fully_shard`` already applied.
"""
if not _is_composable_with_fsdp(module):
return
elif (
module not in root_modules_set
and _get_module_fsdp_state(module) is not None
):
return # nested `fully_shard` module
visited_modules.add(module)
for submodule in module.children():
if submodule not in visited_modules:
dfs(submodule)
modules.append(module)
for root_module in root_modules:
dfs(root_module)
return modules
def _verify_managed_param(name: str, param: nn.Parameter) -> None:
"""
    Verify if the parameter is accepted by fully_shard. The only restriction now
    is that the parameter cannot be a scalar tensor (i.e. a 0-dim tensor) since
    we need at least one dim to shard.
"""
if len(param.shape) == 0:
raise ValueError(
"fully_shard doesn't support salar parameters. "
f"Change {name} to a 1D tensor with numel equal to 1."
)
def _get_managed_states(
modules: List[nn.Module],
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
params: List[nn.Parameter] = []
buffers: List[torch.Tensor] = []
# Track visited parameters/buffers to avoid visiting shared parameters and
# buffers multiple times
visited_params: Set[nn.Parameter] = set()
visited_buffers: Set[torch.Tensor] = set()
for module in modules:
for name, param in module.named_parameters(recurse=False):
if param not in visited_params:
_verify_managed_param(name, param)
params.append(param)
visited_params.add(param)
for buffer in module.buffers(recurse=False):
if buffer not in visited_buffers:
buffers.append(buffer)
visited_buffers.add(buffer)
return params, buffers
def _move_states_to_device(
params: List[nn.Parameter],
buffers: List[torch.Tensor],
device: torch.device,
) -> None:
"""
We have FSDP move states to device for simpler and faster initialization
    since FSDP almost always uses CUDA for training. We move parameters/buffers
    rather than modules so that we can support ignoring parameters/buffers in
    the future.
"""
# Follow the logic in `nn.Module._apply`
for tensor in itertools.chain(params, buffers):
if tensor.device == device or tensor.device.type == "meta":
# Keep meta-device tensors on meta device for deferred init
continue
if isinstance(tensor, DTensor):
if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
raise ValueError(
"Requires DTensor to have mesh of the same type as the FSDP mesh "
f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
)
raise AssertionError(
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
)
tensor_ = tensor
if is_traceable_wrapper_subclass(tensor_):
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
tensor_on_device = nn.Parameter(tensor.to(device))
torch.utils.swap_tensors(tensor, tensor_on_device)
else:
tensor.data = tensor.to(device)


@ -0,0 +1,754 @@
# mypy: allow-untyped-defs
import itertools
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, cast, List, Optional, Sequence, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import _mesh_resources
from torch.distributed.tensor.placement_types import _StridedShard, Placement
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
_chunk_with_empty,
_from_local_no_grad,
_get_dim0_chunked_size,
_raise_assert_with_print,
_to_dtype_if_needed,
FSDPMeshInfo,
HSDPMeshInfo,
)
"""
[Note: FSDP tensors]
FSDP considers the following tensors:
- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
on the module when applying FSDP
- Sharded parameter: sharding the original parameter on dim-0 as a DTensor
over the main mesh
- All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather,
derived from the sharded parameter
- All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from
all-gathering the all-gather inputs
- Unsharded parameter: parameter used for forward/backward computation, derived
from the all-gather output; autograd leaf
We define these tensors to describe the general framework that can accommodate
extensions, where:
- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
- unsharded-parameter = post-all-gather-transform(all-gather-outputs)
For the default ``torch.Tensor`` case, there is only one all-gather input, and
it shares the same underlying tensor data as the sharded parameter, meaning
that they can be thought of as the same tensors. The same applies for the
all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions,
these equivalences may no longer hold due to the pre/post-all-gather
transforms, and some may have multiple all-gather inputs/outputs (e.g.
quantized data and scales).
[Note: FSDP and autograd]
FSDP dynamically frees and allocates the unsharded parameter. Since autograd
can pack a reference to it or a view to save for backward, we use storage
resizing to implement the freeing/allocation since that preserves the aliasing.
This implies that we construct the unsharded parameter object once and write to
it in-place thereafter. For the default ``torch.Tensor`` original parameter
case, the all-gather output and unsharded parameter share the same
data, so we use storage resizing on the all-gather output.
"""
lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")
@torch.library.impl(lib, "set_", "Meta")
@torch.library.impl(lib, "set_", "CUDA")
@torch.library.impl(lib, "set_", "CPU")
def set_(tensor, data):
tensor.set_(data)
"""
[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op
(i.e. they show up as a mutation op in the middle of the AOT joint graph).
Reason:
Traceable FSDP2 compiled autograd BWD graphs have the following traits:
(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
So this doesn't work at all for Traceable FSDP2.
The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
This avoids the problem above, because from AOTAutograd point-of-view there are no mutations
that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process.
Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner
(i.e. effectively "re-functionalizing" those ops). Put another way, we avoid functionalization for those two ops just to
make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
This requires a custom FX pass but we believe it's not hard to write and maintain.
Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors?
A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use,
so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input.
This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
nn.Parameter in order to see the result of .set_.
"""
@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
torch._sync(tensor)
torch._sync(data)
# AOTDispatcher needs to know if any inputs had their storages mutated.
# (Why? It sometimes detaches inputs before sending them into the graph,
# when it sees that they do not need to have any gradients computed)
torch._functionalize_set_storage_changed(tensor)
tensor_inner = torch._from_functional_tensor(tensor)
data_inner = torch._from_functional_tensor(data)
with torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
):
torch.ops.fsdp.set_.default(tensor_inner, data_inner)
torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
class ShardedState(Enum):
"""
- ``SHARDED``: The sharded parameter is registered to the module. It is the
only contributor to parameter memory.
- ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
smaller world size. Since this data should not be used for computation,
we do not register it to the module. Users should reshard the module
before any in-place modifications. Both it and the sharded parameter
contribute to parameter memory.
- ``UNSHARDED``: The unsharded parameter is registered to the module. Both
it and the sharded parameter contribute to parameter memory.
"""
SHARDED = auto()
SHARDED_POST_FORWARD = auto()
UNSHARDED = auto()
@dataclass
class ParamModuleInfo:
"""
For a parameter, this stores the module and the parameter name to be able
to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
the parameter via ``getattr(module, param_name)``. We additionally save
shared modules and shared parameter names to update them accordingly.
"""
# Parameter names are unprefixed, e.g. "weight", not "lin.weight"
module: nn.Module
param_name: str
shared_modules: List[nn.Module] = field(default_factory=list)
shared_param_names: List[str] = field(default_factory=list)
@dataclass
class ExtensionsData:
# User-defined metadata passed from pre to post-all-gather
all_gather_metadata: Optional[Any] = None
# Save the all-gather input sizes to unflatten the all-gather outputs to ND
all_gather_input_sizes: Sequence[torch.Size] = () # ND
def clear(self):
self.all_gather_metadata = None
self.all_gather_input_sizes = ()
class FSDPParam:
"""
This class manages a parameter with FSDP or FSDP variants applied,
implementing dim-0 per-parameter sharding.
"""
orig_dtype: torch.dtype
param_dtype: Optional[torch.dtype]
reduce_dtype: Optional[torch.dtype]
_orig_size: torch.Size # ND
sharded_size: torch.Size # ND
contiguous_sharded_stride: Tuple[int, ...]
padded_sharded_param_size: torch.Size # ND
sharded_post_forward_size: torch.Size # ND
contiguous_sharded_post_forward_stride: Tuple[int, ...]
_sharded_param_data: torch.Tensor # 1D
sharded_param: nn.Parameter # ND
_sharded_post_forward_param_data: Optional[torch.Tensor] # 1D
_sharded_post_forward_param: Optional[nn.Parameter] # ND
_unsharded_param: nn.Parameter # ND
unsharded_accumulated_grad: Optional[torch.Tensor] # ND
_sharding_spec: DTensorSpec
# DTensor attributes (only defined for DTensor `param`):
_tp_spec: DTensorSpec
all_gather_outputs: List[torch.Tensor] # 1D
# All-gather extension attributes
_extensions_data: ExtensionsData
_unsharded_inner_tensors: List[torch.Tensor]
def __init__(
self,
param: nn.Parameter,
module_info: ParamModuleInfo,
mesh_info: FSDPMeshInfo,
post_forward_mesh_info: Optional[FSDPMeshInfo],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
offload_policy: OffloadPolicy,
):
self._module_info: ParamModuleInfo = module_info
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
self.device = device
self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
self.pin_memory = (
self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
)
self.grad_offload_event: Optional[torch.cuda.Event] = None
self._init_sharded_param(param, device)
if self.post_forward_mesh_info:
self._init_sharded_post_forward_param_metadata(param)
self._init_extensions()
self.all_gather_outputs: List[torch.Tensor] = []
self.unsharded_accumulated_grad = None
self._param_fqn: Optional[str] = None # prefixed from root module
# TODO: Remove this padding logic once DTensor pads the local tensor:
# https://github.com/pytorch/pytorch/issues/113045
self._post_load_hook_handle = (
module_info.module.register_load_state_dict_post_hook(
lambda *args, **kwargs: self.reset_sharded_param()
)
)
@torch.no_grad()
def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
if param.device != device and param.device.type != "meta":
raise AssertionError(
f"Expects the parameter to already be moved to device {device} but got {param.device}"
)
# TODO: Replace the sharded DTensor parameter construction logic with
# `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
# TODO: Simplify the following sharded parameter padding logic after
# https://github.com/pytorch/pytorch/issues/113045
self.is_dtensor = isinstance(param, DTensor)
if self.is_dtensor:
self._tp_spec = cast(DTensor, param)._spec
dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
if dp_global_mesh != tp_global_mesh or (
dp_global_mesh is None or tp_global_mesh is None
):
raise AssertionError(
"FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
)
name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
assert dp_mesh.mesh_dim_names is not None, name_dims_error
assert tp_mesh.mesh_dim_names is not None, name_dims_error
submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
self._spmd_mesh = dp_global_mesh[submesh_names]
if len(self._tp_spec.placements) != 1:
raise NotImplementedError(
f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
)
split_factor = self._tp_spec.num_shards_map[0]
assert (
2 <= self._spmd_mesh.ndim <= 3
), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
self._spmd_placements: Tuple[Placement, ...]
dp_shard_tp_placement = (
(
_StridedShard(0, split_factor=split_factor)
if split_factor > 1
else Shard(0)
),
self._tp_spec.placements[0],
)
if self._spmd_mesh.ndim == 2:
self._spmd_placements = dp_shard_tp_placement
else:
assert self.mesh_info.replicate_mesh_dim == 0
self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
self._sharding_spec = DTensorSpec(
self._spmd_mesh,
self._spmd_placements,
tensor_meta=self._tp_spec.tensor_meta,
)
# NOTE: FSDP+TP does not support uneven sharding for now
# TODO: enable uneven sharding for FSDP+TP
if split_factor > 1: # FSDP has strided sharding on tensor dim 0
num_shards = self._sharding_spec.num_shards_map[0]
tensor_size_dim_0 = self._sharding_spec.shape[0]
if tensor_size_dim_0 % num_shards != 0:
raise NotImplementedError(
"FSDP+TP sharding does not support uneven sharding for now: "
f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
f"evenly sharded into {num_shards} shards."
)
param_data = cast(DTensor, param)._local_tensor
else:
self._spmd_mesh = self.mesh_info.mesh
if isinstance(self.mesh_info, HSDPMeshInfo):
self._spmd_placements = (Replicate(), Shard(0))
else:
self._spmd_placements = (Shard(0),)
self._sharding_spec = DTensorSpec(
self._spmd_mesh,
self._spmd_placements,
tensor_meta=TensorMeta(
param.size(),
param.stride(),
param.dtype,
),
)
param_data = param
self._orig_size = param_data.size()
self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
shard_rank = self.mesh_info.shard_mesh_rank
shard_world_size = self.mesh_info.shard_mesh_size
chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
sharded_param = chunks[shard_rank]
self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size())
self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
padded_sharded_size = chunks[0].size() # 0th always padded
padded_sharded_param = param_data.new_zeros(padded_sharded_size)
self.padded_sharded_param_size = padded_sharded_param.size()
if sharded_param.numel() > 0:
padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
if self.offload_to_cpu and not padded_sharded_param.is_meta:
padded_sharded_param = padded_sharded_param.cpu()
if self.pin_memory:
padded_sharded_param = padded_sharded_param.pin_memory()
self._sharded_param_data = padded_sharded_param.view(-1)
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)])
)
self.sharded_param.requires_grad_(param.requires_grad)
# Let `param_data` be freed normally when its ref count reaches 0 when
# the `fully_shard` call returns to allow provided parameters to alias
self._setattr_on_modules(self.sharded_param)
self.sharded_state = ShardedState.SHARDED
def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
mesh_info = self.post_forward_mesh_info
assert mesh_info is not None # mypy
param_data = param._local_tensor if isinstance(param, DTensor) else param
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
self.sharded_post_forward_size = _get_dim0_chunked_size(
chunks[mesh_info.shard_mesh_rank], param_data.size()
)
self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
self.sharded_post_forward_size
)
def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
self.orig_dtype = self.sharded_param.dtype
# Clamp `param_dtype` to `None` if no casting is required
if param_dtype == self.orig_dtype:
param_dtype = None
self.param_dtype = param_dtype
self.reduce_dtype = reduce_dtype
# None indicates that the mixed precision is not enabled
def _init_extensions(self) -> None:
inner_tensor = self._sharded_local_tensor
has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
raise AssertionError(
"Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
f"if using all-gather extensions: {inner_tensor}"
)
if has_fsdp_pre_all_gather:
if self.padded_sharded_param_size != self._sharded_local_tensor.size():
raise NotImplementedError(
"FSDP all-gather extensions require even sharding on dim-0.\n"
f"{self._orig_size} is not divisible by FSDP world size {self.mesh_info.mesh.size()}."
)
self._extensions_data = ExtensionsData()
self._unsharded_inner_tensors: List[torch.Tensor] = []
def init_all_gather_outputs(
self,
all_gather_input_numels: List[int],
all_gather_input_dtypes: List[torch.dtype],
world_size: int,
device: torch.device,
force_recreate: bool = False,
):
if not force_recreate and len(self.all_gather_outputs) > 0:
return # already initialized
self.all_gather_outputs = [
torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
]
def init_unsharded_param(self):
"""
[Note: Invariants for torch.compile Traceable FSDP2]
1. Under compile, we always re-populate the content of `self._unsharded_param`
per AllGather using the slow path.
2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
This is to ensure the buffer creation is internal to the graph and
avoid `self.all_gather_outputs` being captured as a graph input.
3. Under compile, at the end of `free_unsharded_param()`, we always clean up
`self.all_gather_outputs` and `self._unsharded_inner_tensors`,
to avoid them being captured as graph output.
With these invariants, only these tensors will be inputs to the graph:
- Sharded parameters
- Placeholders for the `self._unsharded_param` nn.Parameter
"""
if not ca.compiled_autograd_enabled and hasattr(
self, "_unsharded_param"
): # after the 1st all-gather
inner_tensor = self._sharded_local_tensor
if not hasattr(inner_tensor, "fsdp_post_all_gather"):
return # already initialized
for tensor in self._unsharded_inner_tensors:
alloc_storage(tensor)
all_gather_outputs = self._unflatten_all_gather_outputs()
inner_tensor.fsdp_post_all_gather(
all_gather_outputs,
self._extensions_data.all_gather_metadata,
self.param_dtype or self.orig_dtype,
out=self._unsharded_param,
)
self._extensions_data.clear()
return
inner_tensor = self._sharded_local_tensor
if not ca.compiled_autograd_enabled and hasattr(
inner_tensor, "fsdp_post_all_gather"
):
all_gather_outputs = self._unflatten_all_gather_outputs()
(
unsharded_tensor,
self._unsharded_inner_tensors,
) = inner_tensor.fsdp_post_all_gather(
all_gather_outputs,
self._extensions_data.all_gather_metadata,
self.param_dtype or self.orig_dtype,
)
self._extensions_data.clear()
else:
# For the default path (no post-all-gather), the all-gather output
# gives the unsharded parameter data directly
assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
unsharded_tensor = self.all_gather_outputs[0]
unsharded_param = torch.as_strided(
unsharded_tensor,
self._orig_size,
self._contiguous_orig_stride,
storage_offset=0,
)
if self.is_dtensor:
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
if hasattr(self, "_unsharded_param"):
assert ca.compiled_autograd_enabled
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
self._unsharded_param
):
torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)
def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
return tuple(
t.view(-1, *s[1:])
for t, s in zip(
self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
)
)
def to_sharded(self) -> None:
self._setattr_on_modules(self.sharded_param)
self.free_unsharded_param()
self.sharded_state = ShardedState.SHARDED
def to_sharded_post_forward(self) -> None:
if self.is_dtensor:
raise NotImplementedError(
"Resharding to smaller mesh with TP is not supported yet"
)
self._assert_in_states(ShardedState.UNSHARDED)
assert self.post_forward_mesh_info is not None # mypy
assert len(self.all_gather_outputs) == 1
shard_world_size = self.post_forward_mesh_info.shard_mesh_size
if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
_raise_assert_with_print(
f"All-gather output size ({numel}) must be divisible by the shard "
f"world size ({shard_world_size})"
)
shard_rank = self.post_forward_mesh_info.shard_mesh_rank
sharded_numel = numel // shard_world_size
self._sharded_post_forward_param_data = (
self.all_gather_outputs[0].narrow(
0, sharded_numel * shard_rank, sharded_numel
)
).clone() # clone to be able to free all-gather output
sharded_post_forward_tensor = torch.as_strided(
self._sharded_post_forward_param_data,
size=self.sharded_post_forward_size,
stride=self.contiguous_sharded_post_forward_stride,
storage_offset=0,
)
self._sharded_post_forward_param = nn.Parameter(
self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
)
self._setattr_on_modules(self._sharded_post_forward_param)
self.free_unsharded_param()
self.sharded_state = ShardedState.SHARDED_POST_FORWARD
def to_unsharded(self) -> None:
# Assume that the data has been allocated and all-gathered
set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
self._setattr_on_modules(self._unsharded_param)
if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
# The data is allocated in the default stream via the post-forward
# reshard and must be kept alive for the next all-gather copy-in.
# Since we call this method after the copy-out, the data's lifetime
# is ensured without further synchronization.
self._sharded_post_forward_param = None
self._sharded_post_forward_param_data = None # free
self.sharded_state = ShardedState.UNSHARDED
def _setattr_on_modules(self, param: nn.Parameter) -> None:
unsafe_setattr_param(
self._module_info.module, self._module_info.param_name, param
)
for shared_module, shared_param_name in zip(
self._module_info.shared_modules, self._module_info.shared_param_names
):
unsafe_setattr_param(shared_module, shared_param_name, param)
def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
"""
Converts a local tensor representing either the sharded parameter or
sharded gradient to DTensor.
"""
if tensor.shape != self.sharded_size:
_raise_assert_with_print(
f"Expects size {self.sharded_size} but got {tensor.shape}"
)
return _from_local_no_grad(
tensor,
self._sharding_spec,
)
def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
if tensor.shape != self.sharded_post_forward_size:
_raise_assert_with_print(
f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
)
assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
# TODO: Prefer this DTensor to be read-only and generalize the
# placement once we support TP.
post_forward_sharding_spec = DTensorSpec(
self.post_forward_mesh_info.mesh,
(Replicate(), Shard(0)),
tensor_meta=self._sharding_spec.tensor_meta,
)
return _from_local_no_grad(tensor, post_forward_sharding_spec)
def to_accumulated_grad_if_needed(self) -> None:
# Access `_unsharded_param` to bypass the sharded state check since we
# prefer to reshard before upcasting the gradient to save memory
if (
self.reduce_dtype is None
or self._unsharded_param.grad is None
or self._unsharded_param.grad.dtype == self.reduce_dtype
):
return
unsharded_grad = self._unsharded_param.grad
self._unsharded_param.grad = None
self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)
def accumulate_unsharded_grad_if_needed(self) -> None:
if (
self.unsharded_accumulated_grad is not None
and self.unsharded_param.grad is not None
):
self.unsharded_accumulated_grad += self.unsharded_param.grad
self.unsharded_param.grad = None
def alloc_all_gather_outputs(self) -> None:
for tensor in self.all_gather_outputs:
alloc_storage(tensor)
def free_unsharded_param(self) -> None:
for tensor in itertools.chain(
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)
if ca.compiled_autograd_enabled:
self.all_gather_outputs = []
self._unsharded_inner_tensors = []
@property
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D
self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
if self.sharded_state == ShardedState.SHARDED:
if not ca.compiled_autograd_enabled and hasattr(
self._sharded_local_tensor, "fsdp_pre_all_gather"
):
sharded_local_tensor = self._sharded_local_tensor
if self.offload_to_cpu:
sharded_local_tensor = sharded_local_tensor.to(
self.device, non_blocking=True
)
(
all_gather_inputs,
self._extensions_data.all_gather_metadata,
) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
self._extensions_data.all_gather_input_sizes = [
t.size() for t in all_gather_inputs
]
return [t.view(-1) for t in all_gather_inputs]
sharded_param_data = self._sharded_param_data
if self.offload_to_cpu:
sharded_param_data = sharded_param_data.to(
self.device, non_blocking=True
)
return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
if not ca.compiled_autograd_enabled and hasattr(
self._sharded_local_tensor, "fsdp_pre_all_gather"
):
raise NotImplementedError
all_gather_input = _to_dtype_if_needed(
cast(torch.Tensor, self._sharded_post_forward_param_data),
self.param_dtype,
)
return [all_gather_input]
return [torch.empty(0)] # mypy
@property
def unsharded_param(self) -> nn.Parameter: # ND
self._assert_in_states(ShardedState.UNSHARDED)
return self._unsharded_param
@property
def unsharded_grad_data(self) -> torch.Tensor:
grad = self.unsharded_param.grad
assert grad is not None, "Expects unsharded_param.grad to not be None"
return self._get_grad_inner_tensor(grad)
@property
def unsharded_accumulated_grad_data(self) -> torch.Tensor:
grad = self.unsharded_accumulated_grad
assert grad is not None, "Expects unsharded_accumulated_grad to not be None"
return self._get_grad_inner_tensor(grad)
def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
if self.is_dtensor:
if isinstance(grad, AsyncCollectiveTensor):
grad = grad.wait()
assert isinstance(grad, DTensor), f"{type(grad)}"
if any(pl.is_partial() for pl in grad.placements):
placements = [
Replicate() if pl.is_partial() else pl for pl in grad.placements
]
grad = grad.redistribute(placements=placements)
grad = grad._local_tensor
return grad
@property
def _sharded_local_tensor(self) -> torch.Tensor:
return cast(DTensor, self.sharded_param)._local_tensor
def _assert_in_states(self, *states: ShardedState) -> None:
if self.sharded_state not in states:
_raise_assert_with_print(
f"Expects to be in one of {states}, not {self.sharded_state}"
)
def reset_sharded_param(self):
# For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
# that change the sharded parameter tensor, we may need to re-pad the
# sharded local tensor and re-save the reference.
module_info = self._module_info
new_param = getattr(module_info.module, module_info.param_name)
if new_param is not self.sharded_param:
if torch.__future__.get_swap_module_params_on_conversion():
raise AssertionError(
f"Expects swap_tensors to preserve object but got {new_param} "
f"instead of {self.sharded_param}"
)
self.sharded_param = new_param
local_tensor = new_param._local_tensor
if local_tensor.is_meta:
return
padded_sharded_size = self.padded_sharded_param_size
if local_tensor.size() != padded_sharded_size:
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor)
local_tensor = padded_local_tensor
if self.pin_memory and not local_tensor.is_pinned():
local_tensor = local_tensor.cpu().pin_memory()
self._sharded_param_data = local_tensor.view(-1)
assert isinstance(self.sharded_param, DTensor) # mypy
self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]
def __repr__(self):
return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"
def alloc_storage(tensor: torch.Tensor) -> None:
size = tensor.numel() * tensor.itemsize
if (storage := tensor.untyped_storage()).size() != size:
storage.resize_(size)
def free_storage(tensor: torch.Tensor) -> None:
if (storage := tensor.untyped_storage()).size() != 0:
storage.resize_(0)
# NOTE: These functions bypass the `nn.Module.__setattr__` checks (which incur
# non-trivial CPU overhead) when the module did not override `__setattr__`.
# For FSDP, we know we do not need those checks when transitioning between
# sharded/unsharded parameters.
def unsafe_setattr_param(
module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
module._parameters[param_name] = param
else: # slow path
setattr(module, param_name, param)
def set_requires_grad_if_needed(
src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
# Only call `requires_grad_` if needed to avoid the Python <> C++ context
# switch overhead
if src_tensor.requires_grad != dst_tensor.requires_grad:
dst_tensor.requires_grad_(src_tensor.requires_grad)

View File

@@ -0,0 +1,614 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
AllGatherResult,
foreach_all_gather,
foreach_all_gather_copy_out,
foreach_reduce,
)
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
logger = logging.getLogger("torch.distributed._composable.fsdp")
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict
"""
[Note: Overlapping all-gather copy-in and all-gather]
For implicit forward prefetching, we want to overlap the next copy-in with the
current all-gather. We do so using a separate copy-in stream. However, since
we have the all-gather input as a view into the output, we must make sure to
copy into different memory from the current all-gather's output. Thus, we keep
a reference to the current all-gather's output and have the next FSDP parameter
group free it after its copy-in. Finally, we have the last FSDP state flush the
reference to avoid holding onto memory after forward.
"""
class FSDPCommContext:
"""This has the communication state shared across FSDP states/parameter groups."""
def lazy_init(self):
if not torch.cuda.is_available():
raise RuntimeError("FSDP requires CUDA for streams")
# Setting the all-gather/reduce-scatter streams to be higher priority
# can help avoid some issues where their copies in/out are delayed and
# block computation (this is different from high-pri NCCL streams)
high_priority = -1
# All-gather state and copy-in stream allow overlapping the next
# copy-in with the current all-gather in forward; copy-in overlaps with
# reduce-scatter in backward without the separate copy-in stream
self.all_gather_copy_in_stream = torch.cuda.Stream(priority=high_priority)
# All-gather stream allows overlapping next all-gather with current
# forward compute
self.all_gather_stream = torch.cuda.Stream(priority=high_priority)
# Reduce-scatter stream gives separate execution "thread" for post-
# backward logic like pre/post-gradient division and reduce-scatter
self.reduce_scatter_stream = torch.cuda.Stream(priority=high_priority)
# Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
# since collectives use different network resources and can overlap
# in the typical intra-node sharding / inter-node replication case
self.all_reduce_stream = torch.cuda.Stream()
# All-gather/reduce-scatter states keep references to collective
# tensors produced in one stream and used in another and accompanying
# CUDA events for synchronization
self.all_gather_state: Optional[AllGatherState] = None
self.reduce_scatter_state: Optional[ReduceScatterState] = None
# Post-forward order for explicit backward prefetching
self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles
def get_all_gather_streams(
self, training_state: TrainingState
) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD):
# Use separate streams for implicit prefetching
return self.all_gather_copy_in_stream, self.all_gather_stream
current_stream = torch.cuda.current_stream()
return current_stream, current_stream
# See [Note: Overlapping all-gather copy-in and all-gather]
class AllGatherState(NamedTuple):
all_gather_result: AllGatherResult
event: torch.cuda.Event # all-gather copy-out
class ReduceScatterState(NamedTuple):
reduce_scatter_input: torch.Tensor
event: torch.cuda.Event # reduce-scatter event
class FSDPParamGroup:
"""This class represents a parameter group to communicate together."""
_orig_dtype: torch.dtype
_reduce_dtype: Optional[torch.dtype]
def __init__(
self,
params: List[nn.Parameter],
modules: Tuple[nn.Module, ...],
mesh_info: FSDPMeshInfo,
post_forward_mesh_info: Optional[FSDPMeshInfo],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
offload_policy: OffloadPolicy,
):
self.modules = modules # permit ref cycle because 1:1 lifetime
param_module_infos = _get_param_module_infos(params, modules)
self.fsdp_params = [
FSDPParam(
param,
module_info,
mesh_info,
post_forward_mesh_info,
device,
mp_policy,
offload_policy,
)
for param, module_info in zip(params, param_module_infos)
]
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
self.device = device
self.mp_policy = mp_policy
self._training_state = TrainingState.IDLE
# Group's sharded state always matches its parameters' sharded states
self._sharded_state = ShardedState.SHARDED
self._module_fqn: Optional[str] = None # prefixed from root module
# Only consider resetting sharded parameters once in lazy init since it
# can incur nontrivial overhead to reset them
self._reset_sharded_params: bool = False
# - Hook state
self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}
# - Communication and communication/computation overlap
self.comm_ctx = FSDPCommContext()
# Group's indices in the shared post-forward order
self._post_forward_indices: List[int] = []
# Whether to reduce gradients at all (whether for FSDP or HSDP)
self.reduce_grads: bool = True
# Whether to all-reduce gradients for HSDP; only used if
# `self.reduce_grads` is true, in which case setting this to false
# means reduce-scatter but no all-reduce
self.all_reduce_grads: bool = True
# Whether to reshard parameters after backward (only useful for
# gradient accumulation)
self.reshard_after_backward: bool = True
# Optional custom reduce-scatter reduce op (e.g. to divide by a
# factor other than the shard world size)
self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None
# - CUDA events for stream synchronization
# Holds the all-gather output buffer, sync objects, and metadata
self._all_gather_result: Optional[AllGatherResult] = None
# Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of
# the group's post-backward (e.g. reduce-scatter, all-reduce and div), which
# should be waited on at the end of backward
self._post_reduce_event: Optional[torch.cuda.Event] = None
# Holds the reshard-after-forward CUDA event when resharding to a
# different world size, which should be waited on in the next unshard
self._reshard_after_forward_event: Optional[torch.cuda.Event] = None
# Only for HSDP, if accumulating gradients without all-reduce, save the
# partial reduce output (only reduce-scattered but not all-reduced)
self._partial_reduce_output: Optional[torch.Tensor] = None
# Initialization #
def _init_mp_dtypes(self) -> None:
for fsdp_param in self.fsdp_params:
fsdp_param.init_dtype_attrs(self.mp_policy)
orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params}
if len(orig_dtypes) != 1:
# This can be relaxed if we copy-out for the reduce-scatter
raise AssertionError(
f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
)
self._orig_dtype = next(iter(orig_dtypes))
reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params}
if len(reduce_dtypes) != 1:
# This can be relaxed if we issue one reduce-scatter per reduce
# dtype (but we would need a way for users to specify multiple
# reduce dtypes)
raise AssertionError(
f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
)
self._reduce_dtype = next(iter(reduce_dtypes))
def lazy_init(self):
# Lazy init should be idempotent
# Users may change or register parameters after construction time.
# For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
# other parameters (e.g. loaded from the state dict).
if self.is_sharded and not self._reset_sharded_params:
for fsdp_param in self.fsdp_params:
fsdp_param.reset_sharded_param()
self._reset_sharded_params = True
param_names_on_meta = [
fsdp_param._param_fqn
for fsdp_param in self.fsdp_params
if fsdp_param.sharded_param.device.type == "meta"
]
if param_names_on_meta:
raise RuntimeError(
"FSDP parameters should be materialized from meta device before training, "
f"but the following were still on meta device: {param_names_on_meta}\n"
"For example, call module.to_empty(device) to materialize to device and "
"call module.reset_parameters() on each module to initialize values."
)
# Initialize mixed precision attributes lazily in case the user changes
# the parameter dtypes after construction time but before forward
self._init_mp_dtypes()
self._register_state_dict_hooks()
# Runtime #
def unshard(self, async_op: bool = False):
if self._all_gather_result is not None: # already called, pending wait
return
if self.is_unsharded:
return # no-op
if self._reshard_after_forward_event is not None:
# Resharded parameter data is allocated in the default stream and
# used in the all-gather streams
self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
self._reshard_after_forward_event = None
with record_function(self._with_fqn("FSDP::all_gather")):
self._all_gather_result = foreach_all_gather(
self.fsdp_params,
self._all_gather_process_group,
async_op,
*self.comm_ctx.get_all_gather_streams(self._training_state),
self.device,
)
def wait_for_unshard(self):
"""
        1. In forward with implicit prefetching, to overlap the current copy-out
with the next all-gather, we save a reference to the current all-gather
result to free after the next copy-out.
2. Otherwise (explicit prefetching or in backward), we free the
all-gather result immediately after the current copy-out since we can
already overlap the current copy-out with the previous reduce-scatter.
"""
if not self._all_gather_result:
return # no preceding unshard
if self._training_state == TrainingState.FORWARD: # implicit prefetch
if prev_all_gather_state := self.comm_ctx.all_gather_state:
self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
self.comm_ctx.all_gather_state = None # free the all-gather result
with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
foreach_all_gather_copy_out(
self._all_gather_result,
self.fsdp_params,
self._all_gather_process_group,
)
for fsdp_param in self.fsdp_params:
fsdp_param.init_unsharded_param()
self._to_unsharded()
all_gather_copy_out_event = torch.cuda.Event()
all_gather_copy_out_event.record()
if self._training_state == TrainingState.FORWARD:
self.comm_ctx.all_gather_state = AllGatherState(
self._all_gather_result, all_gather_copy_out_event
)
else:
self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
self._all_gather_result = None # free unless saved in `all_gather_state`
def _wait_all_gather_streams_on_event(self, event: torch.cuda.Event):
# Calling `unshard` before lazy init means streams are not initialized
if hasattr(self.comm_ctx, "all_gather_copy_in_stream"):
self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
if hasattr(self.comm_ctx, "all_gather_stream"):
self.comm_ctx.all_gather_stream.wait_event(event)
def reshard(self):
if self._training_state == TrainingState.FORWARD:
if not self._reshard_after_forward:
return
if self._use_post_forward_mesh:
self._to_sharded_post_forward()
self._reshard_after_forward_event = torch.cuda.Event()
self._reshard_after_forward_event.record()
return
self._to_sharded()
def pre_forward(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
if not ca.compiled_autograd_enabled:
logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
with record_function(self._with_fqn("FSDP::pre_forward")):
self._training_state = TrainingState.FORWARD
self.unshard()
self.wait_for_unshard()
args, kwargs = self._register_post_backward_hook(args, kwargs)
return args, kwargs
def post_forward(self, module: nn.Module, input: Any, output: Any):
if not ca.compiled_autograd_enabled:
logger.debug("%s", self._with_fqn("FSDP::post_forward"))
with record_function(self._with_fqn("FSDP::post_forward")):
self.reshard()
self._record_post_forward()
self._training_state = TrainingState.IDLE
return output
def _record_post_forward(self) -> None:
# Since a group has one pre-backward unshard for each forward call
# before the backward, we record each usage (with multiplicity)
post_forward_index = len(self.comm_ctx.post_forward_order)
self.comm_ctx.post_forward_order.append(self)
self._post_forward_indices.append(post_forward_index)
def pre_backward(self, default_prefetch: bool, *unused: Any):
if self._training_state == TrainingState.PRE_BACKWARD:
return
if not ca.compiled_autograd_enabled:
logger.debug("%s", self._with_fqn("FSDP::pre_backward"))
with record_function(self._with_fqn("FSDP::pre_backward")):
self._training_state = TrainingState.PRE_BACKWARD
self.unshard() # no-op if prefetched
self.wait_for_unshard()
if default_prefetch and not ca.compiled_autograd_enabled:
self._backward_prefetch()
def post_backward(self, *unused: Any):
if not ca.compiled_autograd_enabled:
logger.debug("%s", self._with_fqn("FSDP::post_backward"))
self._training_state = TrainingState.POST_BACKWARD
with record_function(self._with_fqn("FSDP::post_backward_accumulate")):
for fsdp_param in self.fsdp_params:
fsdp_param.accumulate_unsharded_grad_if_needed()
with record_function(self._with_fqn("FSDP::post_backward_reshard")):
if not self.reduce_grads:
if self.reshard_after_backward:
self.reshard()
for fsdp_param in self.fsdp_params:
fsdp_param.to_accumulated_grad_if_needed()
return
# Save the autograd-computed gradients before resharding to only
# access the unsharded parameters when their data is present
fsdp_params_with_grad: List[FSDPParam] = []
unsharded_grads: List[torch.Tensor] = []
for fsdp_param in self.fsdp_params:
# May have an accumulated gradient of the reduce dtype if the
# previous backward did not reduce-scatter
if fsdp_param.unsharded_accumulated_grad is not None:
fsdp_params_with_grad.append(fsdp_param)
unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data)
fsdp_param.unsharded_accumulated_grad = None
elif fsdp_param.unsharded_param.grad is not None:
fsdp_params_with_grad.append(fsdp_param)
unsharded_grads.append(fsdp_param.unsharded_grad_data)
fsdp_param.unsharded_param.grad = None
if self.reshard_after_backward:
self.reshard()
if len(fsdp_params_with_grad) == 0:
return
with record_function(self._with_fqn("FSDP::post_backward_reduce")):
if self.comm_ctx.reduce_scatter_state is not None:
torch.cuda.current_stream().wait_event(
self.comm_ctx.reduce_scatter_state.event
)
self.comm_ctx.reduce_scatter_state = None
(
reduce_scatter_input,
reduce_scatter_event,
self._post_reduce_event,
self._partial_reduce_output,
) = foreach_reduce(
fsdp_params_with_grad,
unsharded_grads,
self._reduce_scatter_process_group,
self.comm_ctx.reduce_scatter_stream,
self._orig_dtype,
self._reduce_dtype,
self.device,
self.reduce_scatter_reduce_op,
self._all_reduce_process_group if self._is_hsdp else None,
self.comm_ctx.all_reduce_stream,
self.all_reduce_grads,
self._partial_reduce_output,
)
self.comm_ctx.reduce_scatter_state = ReduceScatterState(
reduce_scatter_input, reduce_scatter_event
)
def finalize_backward(self):
if self._post_reduce_event is not None:
torch.cuda.current_stream().wait_event(self._post_reduce_event)
self._post_reduce_event = None
for fsdp_param in self.fsdp_params:
if fsdp_param.grad_offload_event is not None:
fsdp_param.grad_offload_event.synchronize()
fsdp_param.grad_offload_event = None
self._post_forward_indices.clear()
def _backward_prefetch(self) -> None:
if self._training_state == TrainingState.PRE_BACKWARD:
if not self._post_forward_indices:
# Can be cleared if running multiple `backward`s
return
curr_index = self._post_forward_indices.pop()
if (target_index := curr_index - 1) < 0:
return
# Prefetch naively using the reverse post-forward order, which may
# have mistargeted prefetches if not all modules used in forward
# are used in this backward
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
self._prefetch_unshard(target_fsdp_param_group, "backward")
@staticmethod
def _prefetch_unshard(
target_fsdp_param_group: "FSDPParamGroup", pass_type: str
) -> None:
if pass_type == "backward":
training_state = TrainingState.PRE_BACKWARD
elif pass_type == "forward":
training_state = TrainingState.FORWARD
else:
raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
f"FSDP::{pass_type}_prefetch for {target_fqn}"
), target_fsdp_param_group.use_training_state(training_state):
target_fsdp_param_group.unshard()
# Utilities #
def _to_sharded(self):
if not self.is_sharded:
for fsdp_param in self.fsdp_params:
fsdp_param.to_sharded()
self._sharded_state = ShardedState.SHARDED
def _to_sharded_post_forward(self):
if not self.is_sharded_post_forward:
for fsdp_param in self.fsdp_params:
fsdp_param.to_sharded_post_forward()
self._sharded_state = ShardedState.SHARDED_POST_FORWARD
def _to_unsharded(self):
if not self.is_unsharded:
for fsdp_param in self.fsdp_params:
fsdp_param.to_unsharded()
self._sharded_state = ShardedState.UNSHARDED
@property
def is_sharded(self) -> bool:
return self._sharded_state == ShardedState.SHARDED
@property
def is_sharded_post_forward(self) -> bool:
return self._sharded_state == ShardedState.SHARDED_POST_FORWARD
@property
def is_unsharded(self) -> bool:
return self._sharded_state == ShardedState.UNSHARDED
@contextlib.contextmanager
def use_training_state(self, training_state: TrainingState):
old_training_state = self._training_state
self._training_state = training_state
try:
yield
finally:
self._training_state = old_training_state
# Hook Registration #
def _register_post_backward_hook(
self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
# Compile relies on `root_post_backward_callback` to call each
# `FSDPParamGroup.post_backward`
if ca.compiled_autograd_enabled:
return args, kwargs
if not torch.is_grad_enabled():
return args, kwargs
args_list, args_spec = tree_flatten(args)
kwargs_list, kwargs_spec = tree_flatten(kwargs)
args_kwargs_list = list(args_list) + list(kwargs_list)
inp_tensor_indices: List[int] = []
inp_tensors: List[torch.Tensor] = []
for i, obj in enumerate(args_kwargs_list):
if torch.is_tensor(obj) and obj.requires_grad:
inp_tensor_indices.append(i)
inp_tensors.append(obj)
if len(inp_tensors) == 0:
return args, kwargs # no tensors that require gradients
inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
args_kwargs_list[inp_tensor_idx] = inp_tensor
args_list = args_kwargs_list[: len(args_list)]
kwargs_list = args_kwargs_list[len(args_list) :]
args = tree_unflatten(args_list, args_spec)
kwargs = tree_unflatten(kwargs_list, kwargs_spec)
return args, kwargs
def _register_state_dict_hooks(self) -> None:
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
assert (
num_pre_save_hooks == num_pre_load_hooks
), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
if num_pre_save_hooks > 0:
return # already registered
modules_with_fsdp_params: Set[nn.Module] = {
fsdp_param._module_info.module for fsdp_param in self.fsdp_params
}
def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
self._to_sharded()
for module in modules_with_fsdp_params:
self._module_to_pre_save_state_dict_hook_handle[
module
] = module.register_state_dict_pre_hook(to_sharded_hook)
self._module_to_pre_load_state_dict_hook_handle[
module
] = module._register_load_state_dict_pre_hook(to_sharded_hook)
# Properties #
@property
def _reshard_after_forward(self) -> bool:
return self.post_forward_mesh_info is not None
@property
def _use_post_forward_mesh(self) -> bool:
return (
self._reshard_after_forward
and self.mesh_info != self.post_forward_mesh_info
)
@property
def _is_hsdp(self) -> bool:
return isinstance(self.mesh_info, HSDPMeshInfo)
@property
def _all_gather_process_group(self) -> dist.ProcessGroup:
mesh_info = (
cast(FSDPMeshInfo, self.post_forward_mesh_info)
if self.is_sharded_post_forward
else self.mesh_info
)
assert isinstance(mesh_info, FSDPMeshInfo)
return mesh_info.shard_process_group
@property
def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
assert isinstance(self.mesh_info, FSDPMeshInfo)
return self.mesh_info.shard_process_group
@property
def _all_reduce_process_group(self) -> dist.ProcessGroup:
assert isinstance(self.mesh_info, HSDPMeshInfo)
return self.mesh_info.replicate_process_group
def _with_fqn(self, label: str) -> str:
if self._module_fqn:
return f"{label} ({self._module_fqn})"
return label
def __repr__(self):
return f"FSDPParamGroup(fqn={self._module_fqn})"
def _get_param_module_infos(
params: List[nn.Parameter], modules: Tuple[nn.Module, ...]
) -> List[ParamModuleInfo]:
"""
Shared parameter: lin1.weight = lin2.weight
Shared module: mlp.lin1 = mlp.lin2
We do not remove duplicates when traversing both modules and parameters to
find shared modules' parameters and shared parameters within a module.
"""
params_set = set(params)
param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
for module in modules:
for _, submodule in module.named_modules(remove_duplicate=False):
for param_name, param in _named_parameters_with_duplicates(
submodule, recurse=False
):
if param in params_set:
if param not in param_to_module_info:
param_to_module_info[param] = ParamModuleInfo(
submodule, param_name
)
else:
param_to_module_info[param].shared_modules.append(submodule)
param_to_module_info[param].shared_param_names.append(
param_name
)
if len(param_to_module_info) != len(params):
raise AssertionError(f"Some parameters are not in the module tree of {module}")
return [param_to_module_info[param] for param in params]
class RegisterPostBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
# All tensors in `inputs` should require gradient
ctx.param_group = param_group
return inputs
@staticmethod
def backward(ctx, *grads: torch.Tensor):
ctx.param_group.post_backward()
return (None,) + grads

View File

@@ -0,0 +1,383 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
)
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._logging import warning_once
from torch.autograd import Variable
from torch.autograd.graph import _MultiHandle
from torch.distributed._composable_state import (
_get_module_state,
_insert_module_state,
_State,
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map
from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
if TYPE_CHECKING:
from ._fsdp_param import FSDPParam
logger = logging.getLogger("torch.distributed._composable.fsdp")
class FSDPStateContext:
"""This has state shared across FSDP states."""
def __init__(self) -> None:
# All FSDP states in the root state's module tree
self.all_states: List[FSDPState] = []
# Iteration's forward root runs the once-per-forward logic; this root
# may not be the overall root set by lazy initialization in cases where
# only a submodule runs forward (e.g. encoder-only for eval)
self.iter_forward_root: Optional[FSDPState] = None
# Final callback should only be queued once per backward
self.post_backward_final_callback_queued: bool = False
# Whether to finalize backward in this backward's final callback
self.is_last_backward: bool = True
# Optional user-provided event recorded after optimizer for the
# all-gather streams to wait on in the root pre-forward
self.post_optim_event: Optional[torch.cuda.Event] = None
def disable_if_config_true(func):
@functools.wraps(func)
def fsdp_hook_wrapper(*args, **kwargs):
if torch._dynamo.config.skip_fsdp_hooks:
return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
else:
return func(*args, **kwargs)
return fsdp_hook_wrapper
class FSDPState(_State):
def __init__(self) -> None:
super().__init__()
self._fsdp_param_group: Optional[FSDPParamGroup] = None
self._is_root: Optional[bool] = None # root set during lazy init
self._state_ctx = FSDPStateContext()
self._comm_ctx = FSDPCommContext()
self._training_state: TrainingState = TrainingState.IDLE
self._states_to_forward_prefetch: List[FSDPState] = []
self._states_to_backward_prefetch: List[FSDPState] = []
self._modules_to_run_forward: Set[nn.Module] = set()
# Define a separate init since `__init__` is called in the contract
def init(
self,
modules: Tuple[nn.Module, ...],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
) -> None:
for module in modules:
_insert_module_state(module, self)
self._modules = modules
self._device = device
self._mp_policy = mp_policy
if len(modules) == 1:
self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
self._pre_forward, prepend=True, with_kwargs=True
)
self._post_forward_hook_handle = modules[0].register_forward_hook(
self._post_forward, prepend=False
)
else:
hook_handle = _register_group_forward_hooks(
modules,
self._pre_forward,
self._post_forward,
self._modules_to_run_forward,
)
self._pre_forward_hook_handle = hook_handle
self._post_forward_hook_handle = hook_handle
def _root_pre_forward(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
self._lazy_init()
if self._state_ctx.iter_forward_root is not None:
return args, kwargs
if not ca.compiled_autograd_enabled:
logger.debug("FSDP::root_pre_forward")
self._state_ctx.iter_forward_root = self
with torch.profiler.record_function("FSDP::root_pre_forward"):
# Wait for optimizer before implicitly prefetched all-gathers
if (event := self._state_ctx.post_optim_event) is not None:
self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
self._comm_ctx.all_gather_stream.wait_event(event)
self._state_ctx.post_optim_event = None
else:
current_stream = torch.cuda.current_stream()
self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
self._comm_ctx.all_gather_stream.wait_stream(current_stream)
if self._device.type == "cuda":
with torch.profiler.record_function("FSDP::inputs_to_device"):
args_tuple, kwargs_tuple = _to_kwargs(
args, kwargs, self._device, False
) # same as DDP
args, kwargs = args_tuple[0], kwargs_tuple[0]
return args, kwargs
def _lazy_init(self) -> None:
"""
        Lazy initialization runs once all modules' parallelisms have been
        finalized (e.g. FSDP has been applied to all desired modules). At that
        point, we can determine which state is the root: namely, the 1st state
        to run forward.
"""
if self._is_root is not None:
return # no-op: already initialized
self._is_root = True
if len(self._modules) > 1:
raise RuntimeError(
f"FSDP requires a single root module but got {self._modules}"
)
root_module = self._modules[0]
visited_states: Set[FSDPState] = set()
for module_name, module in root_module.named_modules():
if (state := _get_module_fsdp_state(module)) is None:
continue
if module is not root_module:
if state not in visited_states and state._is_root is not None:
raise RuntimeError(
"FSDP state has already been lazily initialized for "
f"{module_name}\nFSDP requires running forward through "
"the root module first"
)
state._is_root = False
self._state_ctx.all_states.append(state)
visited_states.add(state)
if self._fsdp_param_group:
# For the root, do not reshard after forward since for training,
# the parameters would be freed and all-gathered immediately
self._fsdp_param_group.post_forward_mesh_info = None
self._init_fqns()
self._init_shared_state()
# Run parameter group lazy inits after initializing FQNs for improved
# error messages
for state in self._state_ctx.all_states:
if state._fsdp_param_group:
state._fsdp_param_group.lazy_init()
def _init_shared_state(self) -> None:
self._comm_ctx.lazy_init()
for state in self._state_ctx.all_states:
state._state_ctx = self._state_ctx
state._comm_ctx = self._comm_ctx
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.comm_ctx = self._comm_ctx
def _init_fqns(self) -> None:
"""Sets module and parameter FQN attributes for debugging."""
assert self._is_root
root_module = self._modules[0]
param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
for state in self._state_ctx.all_states:
if fsdp_param_group := state._fsdp_param_group:
for fsdp_param in fsdp_param_group.fsdp_params:
param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
for module in fsdp_param_group.modules:
module_to_fsdp_param_group[module] = fsdp_param_group
for param_name, param in root_module.named_parameters():
if param in param_to_fsdp_param:
param_to_fsdp_param[param]._param_fqn = param_name
for module_name, module in root_module.named_modules():
if module in module_to_fsdp_param_group:
module_fqn = module_to_fsdp_param_group[module]._module_fqn
if module_fqn is None:
module_to_fsdp_param_group[module]._module_fqn = module_name
else:
assert isinstance(module_fqn, str), f"{module_fqn}"
module_fqn += f", {module_name}"
module_to_fsdp_param_group[module]._module_fqn = module_fqn
@disable_if_config_true
def _pre_forward(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # When composing with module-hook-based activation checkpointing,
        # the pre-backward hook is responsible for the unshard
if self._training_state == TrainingState.PRE_BACKWARD:
return args, kwargs
self._training_state = TrainingState.FORWARD
args, kwargs = self._root_pre_forward(module, args, kwargs)
if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
with torch.profiler.record_function("FSDP::cast_forward_inputs"):
cast_fn = functools.partial(
_cast_fp_tensor, self._mp_policy.param_dtype
)
args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
if self._fsdp_param_group:
args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
for fsdp_state in self._states_to_forward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
return args, kwargs
@disable_if_config_true
def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
# When composing with module-hook-based activation checkpointing, the
# post-backward hook is responsible for the reshard
if self._training_state == TrainingState.PRE_BACKWARD:
return output
if self._fsdp_param_group:
output = self._fsdp_param_group.post_forward(module, input, output)
output = self._register_pre_backward_hook(output)
self._training_state = TrainingState.IDLE
if self._state_ctx.iter_forward_root is self:
if all_gather_state := self._comm_ctx.all_gather_state:
# Free the last all-gather result if needed; refer to
# [Note: Overlapping all-gather copy-in and all-gather]
self._comm_ctx.all_gather_copy_in_stream.wait_event(
all_gather_state.event
)
self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
self._comm_ctx.all_gather_state = None # free the all-gather result
self._state_ctx.iter_forward_root = None
if self._mp_policy.output_dtype is not None:
with torch.profiler.record_function("FSDP::cast_forward_outputs"):
output = tree_map(
functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
output,
)
return output
def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
self._training_state = TrainingState.PRE_BACKWARD
self._register_root_post_backward_final_callback()
if self._fsdp_param_group:
default_prefetch = len(self._states_to_backward_prefetch) == 0
self._fsdp_param_group.pre_backward(default_prefetch)
for fsdp_state in self._states_to_backward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
return grad
def _root_post_backward_final_callback(self) -> None:
if not ca.compiled_autograd_enabled:
logger.debug("FSDP::root_post_backward")
with torch.profiler.record_function("FSDP::root_post_backward_callback"):
for state in self._state_ctx.all_states:
if state._fsdp_param_group and state._fsdp_param_group.is_unsharded:
# Run post-backward in case forward inputs did not require
# gradient so the autograd backward did not run
state._fsdp_param_group.post_backward()
state._training_state = TrainingState.IDLE
if state._fsdp_param_group:
state._fsdp_param_group._training_state = TrainingState.IDLE
if self._state_ctx.is_last_backward:
state._finalize_backward()
if self._state_ctx.is_last_backward:
self._comm_ctx.post_forward_order.clear()
if self._comm_ctx.reduce_scatter_state is not None:
torch.cuda.current_stream().wait_event(
self._comm_ctx.reduce_scatter_state.event
)
self._comm_ctx.reduce_scatter_state = None
self._state_ctx.post_backward_final_callback_queued = False
def _finalize_backward(self) -> None:
if self._modules_to_run_forward:
msg = (
f"{len(self._modules_to_run_forward)} of the {len(self._modules)} "
f"modules passed to fully_shard did not run forward before backward, "
"which is error-prone since FSDP post-forward/pre-backward logic "
"will not run for these modules. We recommend passing only modules "
"that run forward together. Modules that did not run forward: "
f"{list(self._modules_to_run_forward)}"
)
warning_once(logger, msg, stacklevel=2)
# Clear since we want the next forward to run
self._modules_to_run_forward.clear()
if self._fsdp_param_group:
self._fsdp_param_group.finalize_backward()
def _register_pre_backward_hook(self, output: Any) -> Any:
if not torch.is_grad_enabled():
return output
flat_outputs, _ = tree_flatten(output)
for t in flat_outputs:
if torch.is_tensor(t) and t.requires_grad:
t.register_hook(self._pre_backward)
return output
def _register_root_post_backward_final_callback(self):
if self._state_ctx.post_backward_final_callback_queued:
return
self._state_ctx.post_backward_final_callback_queued = True
Variable._execution_engine.queue_callback(
self._root_post_backward_final_callback
)
def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
state = _get_module_state(module)
if isinstance(state, FSDPState):
return state
return None
def _register_group_forward_hooks(
modules: Sequence[nn.Module],
pre_hook: Callable,
post_hook: Callable,
modules_to_run: Set[nn.Module],
):
"""
Registers group forward pre and post-hooks. The pre-hook runs upon the
first module pre-forward, and the post-hook runs upon the last. If at least
one module does not run forward, then the post-hook does not run.
"""
modules_set = set(modules)
@disable_if_config_true
@functools.wraps(pre_hook)
def wrapped_pre_hook(*args: Any, **kwargs: Any):
if len(modules_to_run) == 0: # first to run
modules_to_run.update(modules_set)
return pre_hook(*args, **kwargs)
@disable_if_config_true
def get_wrapped_post_hook(module: nn.Module):
@functools.wraps(post_hook)
def wrapped_post_hook(*args: Any, **kwargs: Any):
modules_to_run.discard(module)
if len(modules_to_run) == 0:
return post_hook(*args, **kwargs)
return wrapped_post_hook
pre_handles = [
module.register_forward_pre_hook(
wrapped_pre_hook, prepend=True, with_kwargs=True
)
for module in modules
]
post_handles = [
module.register_forward_hook(
get_wrapped_post_hook(module), prepend=False, always_call=True
)
for module in modules
]
return _MultiHandle(tuple(pre_handles + post_handles))

View File

@@ -0,0 +1,446 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union
import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _get_root_modules
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_init import (
_get_device_from_mesh,
_get_managed_modules,
_get_managed_states,
_get_post_forward_mesh_info,
_init_default_fully_shard_mesh,
_move_states_to_device,
)
from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState
cls_to_fsdp_cls: Dict[Type, Type] = {}
# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState) # type: ignore[operator]
def fully_shard(
module: Union[nn.Module, List[nn.Module]],
*,
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
):
"""
Shard module parameters across data parallel workers.
This function applies fully sharded data parallelism (FSDP) or a variant to
``module``, a technique for memory savings at the cost of communication.
Parameters are sharded across ``mesh``, and in turn, so are their gradients
and optimizer states.
The sharded parameters are all-gathered to construct the unsharded
parameters for forward or backward computation. The unsharded parameters
are freed after computation to save memory. The gradients are reduced
across the mesh and divided by the mesh size for data parallelism. The
optimizer step runs on the sharded parameters.
Each call to ``fully_shard`` constructs one communication group that
includes the parameters in ``module.parameters()`` except those already
assigned to a group from a nested call. Each group's parameters and its
gradients are communicated together in one collective, respectively.
Constructing multiple groups across the model (e.g. "layer by layer")
allows for peak memory savings and communication/computation overlap.
Implementation-wise, the sharded parameters are represented as
:class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
parameters, and a module forward hook frees them. Similar backward hooks
gather parameters and later free parameters/reduce gradients.
Args:
        module (Union[nn.Module, List[nn.Module]]): The module or modules to
shard with FSDP and group together for communication.
mesh (Optional[DeviceMesh]): This data parallel mesh defines the
sharding and device. If 1D, then parameters are fully sharded
across the 1D mesh (FSDP). If 2D, then parameters are sharded
across the 0th dim and replicated across the 1st dim (HSDP). The
mesh's device type gives the device type used for communication;
if a CUDA or CUDA-like device type, then we use the current device.
reshard_after_forward (Union[bool, int]): This controls the parameter
behavior after forward and can trade off memory and communication:
- If ``True``, then this reshards parameters after forward and
all-gathers in backward.
- If ``False``, then this keeps the unsharded parameters in memory
after forward and avoids the all-gather in backward.
- If an ``int``, then this represents the world size to reshard to
after forward. It should be a non-trivial divisor of the ``mesh``
shard dim size (i.e. excluding 1 and the dim size itself). A choice
may be the intra-node size (e.g. ``torch.cuda.device_count()``).
This allows the all-gather in backward to be over a smaller world
size at the cost of higher memory usage than setting to ``True``.
- The root FSDP state has its value specially set to ``False`` as a
heuristic since its parameters would typically be immediately
all-gathered for backward.
            - After forward, the parameters registered to the module depend on
              this value: the registered parameters are the sharded parameters
              if ``True``; unsharded parameters if ``False``; and the parameters
resharded to the smaller mesh otherwise. To modify the parameters
between forward and backward, the registered parameters must be the
sharded parameters. For ``False`` or an ``int``, this can be done
by manually resharding via :meth:`reshard`.
mp_policy (MixedPrecisionPolicy): This controls the mixed precision
policy, which offers parameter/reduction mixed precision for this
module. See :class:`MixedPrecisionPolicy` for details.
offload_policy (OffloadPolicy): This controls the offloading policy,
which offers parameter/gradient/optimizer state offloading. See
:class:`OffloadPolicy` and its subclasses for details.
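    As a hedged usage sketch (the model, optimizer, and data-loader names are
    illustrative and assume the default process group is already initialized),
    ``fully_shard`` is typically applied to submodules first and to the root
    module last:
    Example::
        >>> # xdoctest: +SKIP("requires a distributed setup")
        >>> model = Transformer()  # hypothetical module with `model.layers`
        >>> for layer in model.layers:
        >>>     fully_shard(layer)  # one communication group per layer
        >>> fully_shard(model)  # root group for the remaining parameters
        >>> optim = torch.optim.AdamW(model.parameters())
        >>> for batch in dataloader:
        >>>     loss = model(batch).sum()
        >>>     loss.backward()
        >>>     optim.step()
        >>>     optim.zero_grad()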
"""
if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
raise ValueError(
f"fully_shard does not support containers that do not implement forward: {module}"
)
mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim not in (1, 2):
raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
elif mesh.ndim == 1:
mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
else:
mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
post_forward_mesh_info = _get_post_forward_mesh_info(
reshard_after_forward, mesh_info
)
arg_module = module
modules = (
(module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = fully_shard.state(modules[0])
state.init(modules, device, mp_policy)
managed_modules = _get_managed_modules(modules)
params, buffers = _get_managed_states(managed_modules)
_move_states_to_device(params, buffers, device)
if params:
state._fsdp_param_group = FSDPParamGroup(
params,
modules,
mesh_info,
post_forward_mesh_info,
device,
mp_policy,
offload_policy,
)
# For Dynamo
for managed_module in managed_modules:
managed_module._is_fsdp_managed_module = True # type: ignore[assignment]
managed_module._fsdp_use_orig_params = True # type: ignore[assignment]
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
cls = module.__class__
new_cls = cls_to_fsdp_cls.get(cls, None)
if not new_cls:
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
cls_to_fsdp_cls[cls] = new_cls
module.__class__ = new_cls
return arg_module
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
raise AssertionError(
"FSDP does not support deepcopy. Please use state dict for serialization."
)
class FSDPModule:
def __new__(cls, *args, **kwargs):
"""
Override ``__new__`` to remove the FSDP class and directly construct
the original class for cases like indexing into a container module.
"""
# Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
# and index 1 is the `FSDPModule` class itself
orig_cls = cls.__mro__[2]
self = orig_cls.__new__(orig_cls, *args, **kwargs)
self.__init__(*args, **kwargs)
return self
def reshard(self) -> None:
"""
Reshards the module's parameters, registering the sharded parameters
to the module and freeing the unsharded parameters if needed. This
method is *not* recursive.
"""
state = self._get_fsdp_state()
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.reshard()
def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
"""
Unshards the module's parameters by allocating memory and all-gathering
the parameters. This method is *not* recursive.
Args:
async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
that has a :meth:`wait` method to wait on the unshard op. If
``False``, then returns ``None`` and waits on the handle inside
this function.
.. warning:: This method is experimental and subject to change.
.. note:: If ``async_op=True``, then the user does not have to call
:meth:`wait` on the returned handle if waiting on the unshard op
in the module's pre-forward is tolerable. FSDP will wait on the
pending unshard op in the pre-forward automatically.
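        A minimal sketch of explicit unsharding with ``async_op=True`` (the
        ``layer`` and ``x`` names are illustrative)::
            >>> # xdoctest: +SKIP
            >>> handle = layer.unshard(async_op=True)
            >>> # ... other CPU work while the all-gather is in flight ...
            >>> handle.wait()  # optional; the pre-forward would also wait
            >>> out = layer(x)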
"""
state = self._get_fsdp_state()
fsdp_param_group = state._fsdp_param_group
if fsdp_param_group is not None:
fsdp_param_group.lazy_init()
fsdp_param_group.unshard(async_op=async_op)
handle = UnshardHandle(fsdp_param_group)
if async_op:
return handle
handle.wait()
return None
def set_is_last_backward(self, is_last_backward: bool) -> None:
"""
Sets whether the next backward is the last one, meaning that FSDP
should wait for gradient reduction to finish and clear internal data
structures used for explicit prefetching.
"""
state = self._get_fsdp_state()
state._state_ctx.is_last_backward = is_last_backward
def set_requires_gradient_sync(
self, requires_gradient_sync: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should sync gradients. This can be used to implement
gradient accumulation without communication. For HSDP, this controls
both reduce-scatter and all-reduce together.
Args:
requires_gradient_sync (bool): Whether to reduce gradients for the
module's parameters.
recurse (bool): Whether to set for all submodules or just the
passed-in module.
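        Example::
            >>> # xdoctest: +SKIP
            >>> # Sketch of gradient accumulation without communication; the
            >>> # ``model``, ``microbatches``, and ``optim`` names are
            >>> # illustrative assumptions.
            >>> for i, microbatch in enumerate(microbatches):
            >>>     is_last = i == len(microbatches) - 1
            >>>     model.set_requires_gradient_sync(is_last)
            >>>     model(microbatch).sum().backward()
            >>> optim.step()
            >>> optim.zero_grad()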
"""
self_module = cast(nn.Module, self)
modules = list(self_module.modules()) if recurse else [self_module]
for module in modules:
if isinstance(module, FSDPModule):
state = module._get_fsdp_state()
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.reduce_grads = requires_gradient_sync
fsdp_param_group.all_reduce_grads = requires_gradient_sync
def set_requires_all_reduce(
self, requires_all_reduce: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should all-reduce gradients. This can be used to
implement gradient accumulation with only reduce-scatter but not
all-reduce for HSDP.
"""
self_module = cast(nn.Module, self)
modules = list(self_module.modules()) if recurse else [self_module]
for module in modules:
if isinstance(module, FSDPModule):
state = module._get_fsdp_state()
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.all_reduce_grads = requires_all_reduce
def set_reshard_after_backward(
self, reshard_after_backward: bool, *, recurse: bool = True
) -> None:
"""
Sets if the module should reshard parameters after backward. This can
be used during gradient accumulation to trade off higher memory for
reduced communication.
Args:
reshard_after_backward (bool): Whether to reshard parameters after
backward.
recurse (bool): Whether to set for all submodules or just the
passed-in module.
"""
self_module = cast(nn.Module, self)
modules = list(self_module.modules()) if recurse else [self_module]
for module in modules:
if isinstance(module, FSDPModule):
state = module._get_fsdp_state()
if fsdp_param_group := state._fsdp_param_group:
fsdp_param_group.reshard_after_backward = reshard_after_backward
def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
"""
Sets the FSDP modules for which this FSDP module should explicitly
prefetch all-gathers in forward. The prefetching runs after this
module's all-gather copy-out.
Passing a singleton list containing the next FSDP module gives the same
all-gather overlap behavior as the default overlap behavior, except the
prefetched all-gather is issued earlier from the CPU. Passing a list
with at least length two is required for more aggressive overlap and
will use more reserved memory.
Args:
modules (List[FSDPModule]): FSDP modules to prefetch.
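        Example::
            >>> # xdoctest: +SKIP
            >>> # Illustrative sketch: explicitly prefetch the next two blocks;
            >>> # ``layers`` is assumed to be a list of fully_shard-ed modules.
            >>> for i, layer in enumerate(layers):
            >>>     layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])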
"""
_assert_all_fsdp_modules(modules)
self._get_fsdp_state()._states_to_forward_prefetch = [
module._get_fsdp_state() for module in modules
]
def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
"""
Sets the FSDP modules for which this FSDP module should explicitly
prefetch all-gathers in backward. This overrides the default backward
        prefetching implementation that prefetches the next FSDP module based on
the reverse post-forward order.
Passing a singleton list containing the previous FSDP module gives the
same all-gather overlap behavior as the default overlap behavior.
Passing a list with at least length two is required for more aggressive
overlap and will use more reserved memory.
Args:
modules (List[FSDPModule]): FSDP modules to prefetch.
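        Example::
            >>> # xdoctest: +SKIP
            >>> # Illustrative sketch: in backward, prefetch the two preceding
            >>> # blocks; ``layers`` is assumed to be a list of fully_shard-ed
            >>> # modules in forward order.
            >>> for i, layer in enumerate(layers):
            >>>     prev = layers[max(i - 2, 0) : i]
            >>>     layer.set_modules_to_backward_prefetch(list(reversed(prev)))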
"""
_assert_all_fsdp_modules(modules)
self._get_fsdp_state()._states_to_backward_prefetch = [
module._get_fsdp_state() for module in modules
]
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
"""
Sets a post-optimizer-step event for the root FSDP module to wait the
all-gather streams on.
By default, the root FSDP module waits the all-gather streams on the
current stream to ensure that the optimizer step has finished before
all-gathering. However, this may introduce false dependencies if
there is unrelated computation after the optimizer step. This API
allows the user to provide their own event to wait on. After the root
waits on the event, the event is discarded, so this API should be
called with a new event each iteration.
Args:
event (torch.cuda.Event): Event recorded after the optimizer step
to wait all-gather streams on.
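        Example::
            >>> # xdoctest: +SKIP
            >>> # Illustrative sketch assuming ``model`` is the root FSDP module
            >>> # and ``optim`` its optimizer.
            >>> optim.step()
            >>> post_optim_event = torch.cuda.current_stream().record_event()
            >>> model.set_post_optim_event(post_optim_event)
            >>> # ... unrelated computation on the current stream ...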
"""
self._get_fsdp_state()._state_ctx.post_optim_event = event
def set_reduce_scatter_divide_factor(self, factor: float) -> None:
"""
Sets a custom divide factor for the reduce-scatter. This becomes a
custom reduce op using NCCL's PreMulSum, which allows multiplying by
the factor before reduction.
Args:
factor (float): Custom divide factor.
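        Example::
            >>> # xdoctest: +SKIP
            >>> # Illustrative: divide gradients by the full world size explicitly
            >>> # (``model`` is an assumed root FSDP module).
            >>> world_size = torch.distributed.get_world_size()
            >>> model.set_reduce_scatter_divide_factor(float(world_size))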
"""
state = self._get_fsdp_state()
if (fsdp_param_group := state._fsdp_param_group) is not None:
mul_factor = 1.0 / float(factor)
reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor)
fsdp_param_group.reduce_scatter_reduce_op = reduce_op
def _get_fsdp_state(self) -> FSDPState:
if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
raise AssertionError(f"No FSDP state found on {self}")
return state
def _apply(self, *args: Any, **kwargs: Any) -> Any:
# Reshard to ensure that sharded parameters are registered
self.reshard()
ret = super()._apply(*args, **kwargs) # type: ignore[misc]
state = self._get_fsdp_state()
if not (fsdp_param_group := state._fsdp_param_group):
return ret
# TODO: Remove this padding logic once DTensor pads the local tensor:
# https://github.com/pytorch/pytorch/issues/113045
with torch.no_grad():
for fsdp_param in fsdp_param_group.fsdp_params:
fsdp_param.reset_sharded_param()
return ret
class UnshardHandle:
"""
A handle to wait on the unshard op.
Args:
fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to
unshard. This should be ``None`` iff the FSDP module does not
manage any parameters, meaning the unshard is a no-op.
"""
def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
self._fsdp_param_group = fsdp_param_group
def wait(self):
"""
Waits on the unshard op.
This ensures that the current stream can use the unsharded parameters,
which are now registered to the module.
"""
if self._fsdp_param_group is not None:
self._fsdp_param_group.wait_for_unshard()
# Avoid keeping a reference
self._fsdp_param_group = None
def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
"""
Registers a method on ``module`` to be a forward method for FSDP.
FSDP only knows to run its pre-forward and post-forward hooks on the
    default :meth:`nn.Module.forward` method. This function patches a
    user-specified method to run the pre/post-forward hooks before/after the method,
respectively. If ``module`` is not an :class:`FSDPModule`, then this is a
no-op.
Args:
module (nn.Module): Module to register the forward method on.
method_name (str): Name of the forward method.
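    A minimal sketch (the ``Encoder`` class, its ``encode`` method, and the
    ``images`` input are illustrative)::
        >>> # xdoctest: +SKIP
        >>> model = Encoder()
        >>> fully_shard(model)
        >>> register_fsdp_forward_method(model, "encode")
        >>> features = model.encode(images)  # runs FSDP pre/post-forward hooks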
"""
if not isinstance(module, FSDPModule):
        # Make this a no-op so the same call site works with and without FSDP
return
if not hasattr(module, method_name):
raise ValueError(f"{type(module)} does not have a method {method_name}")
orig_method = getattr(module, method_name)
@functools.wraps(orig_method)
def wrapped_method(self, *args, **kwargs):
fsdp_state = self._get_fsdp_state()
args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
out = orig_method(*args, **kwargs)
return fsdp_state._post_forward(self, args, out)
# Use `__get__` to make `wrapped_method` an instance method
setattr(
module,
method_name,
wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined]
)
def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
for module in modules:
if not isinstance(module, FSDPModule):
raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")

View File

@ -0,0 +1,131 @@
# mypy: allow-untyped-decorators
from typing import Callable, Iterable, Optional, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hook,
_register_pre_forward_hook,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy
@contract(state_cls=_FSDPState)
@deprecated(
"`torch.distributed._composable.fully_shard` is being deprecated. "
"You can continue to use the wrapper based FSDP. "
"See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
"`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.",
category=FutureWarning,
)
def fully_shard(
module: nn.Module,
*,
process_group: Optional[dist.ProcessGroup] = None,
policy: Optional[_Policy] = None,
strategy: Optional[ShardingStrategy] = None,
mixed_precision: Optional[MixedPrecision] = None,
cpu_offload: Optional[CPUOffload] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
device_id: Optional[Union[int, torch.device]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
] = None,
) -> nn.Module:
"""Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``."""
torch._C._log_api_usage_once("torch.distributed.fully_shard")
# Enforce the new auto wrap policy
if policy is not None and not isinstance(policy, _Policy):
raise ValueError(f"Expects a `_Policy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
state = _init_device_handle(state, module, state._ignored_params, device_id)
_annotate_modules_for_dynamo(module, state._ignored_modules, True)
state = _init_process_group_state(state, process_group, strategy, policy)
if policy is not None:
root_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
"cpu_offload": cpu_offload,
"ignored_modules": ignored_modules,
"device_id": device_id,
"param_init_fn": param_init_fn,
"sync_module_states": sync_module_states,
"forward_prefetch": forward_prefetch,
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
root_kwargs,
fully_shard,
)
state = _init_core_state(
state,
strategy or ShardingStrategy.FULL_SHARD,
mixed_precision,
cpu_offload,
limit_all_gathers=True,
use_orig_params=True,
backward_prefetch_limit=1,
forward_prefetch_limit=1,
)
state = _init_runtime_state(state)
state = _init_prefetching_state(
state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
)
state = _init_buffer_state(state, module)
state = _init_param_handle_from_module(
state, module, device_id, param_init_fn, sync_module_states
)
state = _init_state_dict_state(state)
_register_all_state_dict_hooks(state)
_register_pre_forward_hook(state, module)
_register_post_forward_hook(state, module)
_register_root_pre_forward_hook(state, module) # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
_insert_module_state(module, state)
for submodule in module.modules():
if (
submodule in state._fully_sharded_module_to_handle
and _get_module_state(submodule) is None
):
_insert_module_state(submodule, state)
return module

View File

@ -0,0 +1,256 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel
from .contract import _get_registry, contract
_ROOT_MODULE_PREFIX = ""
class _ReplicateState(_State):
def __init__(self) -> None:
super().__init__()
self.module: nn.Module = nn.ParameterList()
self.has_initialized: bool = False
self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
self._orig_module = self.module
self._param_names: List[str] = []
self._no_sync: bool = False
self._init_args: Optional[Tuple[Any, ...]] = None
self._init_kwargs: Dict[str, Any] = {}
self._comm_hook_args: List[Any] = []
def _collect_params(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
prefix: str = _ROOT_MODULE_PREFIX,
) -> None:
# skip if managed by fully_sharded API
if _is_fully_sharded(module):
return
# if a module is ignored, all descendants of the module are ignored.
if module in ignored_modules:
return
recurse_prefix = (
f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
)
for n, p in module.named_parameters(recurse=False):
if p not in ignored_params:
self._param_list.append(p)
self._param_names.append(f"{recurse_prefix}{n}")
for name, child_module in module.named_children():
self._collect_params(
child_module,
ignored_modules,
ignored_params,
prefix=f"{recurse_prefix}{name}",
)
def lazy_init(self) -> None:
@torch._disable_dynamo(recursive=True)
def _lazy_init():
assert self._init_args is not None
self.init(*self._init_args, **self._init_kwargs)
self.register_comm_hook()
self._init_args = ()
self._init_kwargs = {}
_lazy_init()
def init(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
**kwargs,
) -> None:
if self.has_initialized:
return
self.has_initialized = True
device_mesh = kwargs.get("device_mesh", None)
self.module = module
ignored_params = {p for m in ignored_modules for p in m.parameters()}
for submodule in module.modules():
if _is_fully_sharded(submodule):
ignored_params.update(submodule.parameters())
from torch.distributed.tensor.parallel.ddp import _localize_dtensor
_localize_dtensor(module, ignored_params=ignored_params)
self._collect_params(module, ignored_modules, ignored_params)
if "device_id" in kwargs:
# replicate() supports a small usability enhancement where
# user can pass in device_id as a Union[int, torch.device] even for
# CPU devices so users don't have to change code for CPU/GPU runs.
# We derive the right device_ids to feed into DDP to support this.
if kwargs["device_id"] is not None:
device_id = kwargs["device_id"]
# Convert to device_ids that DDP expects.
if isinstance(device_id, torch.device) and device_id.type == "cpu":
# CPU modules receive device_ids None
kwargs["device_ids"] = None
else:
# GPU modules expect device_ids=[cuda_device]
kwargs["device_ids"] = [device_id]
else:
kwargs["device_ids"] = None
kwargs.pop("device_id")
self._ddp = DistributedDataParallel(self._param_list, **kwargs)
# Weakref to the DDP instance is currently only used for testing.
replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
def register_comm_hook(self) -> None:
for comm_args, comm_kwargs in self._comm_hook_args:
self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
self._comm_hook_args.clear()
def record_init_args(self, *args, **kwargs) -> None:
self._init_args = args
self._init_kwargs = kwargs
def forward_pre_hook(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
if self._init_args or self._init_kwargs:
self.lazy_init()
self._ddp.require_backward_grad_sync = not self._no_sync
return self._ddp._pre_forward(*args, **kwargs)
def forward_post_hook(
self,
module: nn.Module,
input: Tuple[torch.Tensor],
output: torch.Tensor,
) -> torch.Tensor:
return self._ddp._post_forward(output)
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
raise AssertionError(
"DDP does not support deepcopy. Please use state dict for serialization."
)
# Follow the same pattern as FSDP/fully_shard
class DDP:
def __new__(cls, *args, **kwargs):
"""
Override ``__new__`` to remove the DDP class and directly construct
the original class for cases like indexing into a container module.
"""
# Use index 2 since 0 is the dynamically constructed `DDP<...>` class
# and index 1 is the `DDP` class itself
orig_cls = cls.__mro__[2]
return orig_cls.__new__(orig_cls, *args, **kwargs)
def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
"""
Sets if the module should sync gradients. This can be used to implement
gradient accumulation without communication.
Args:
requires_gradient_sync (bool): Whether to reduce gradients for the
module's parameters.
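        A minimal sketch of gradient accumulation (``model`` is assumed to be a
        module already passed to ``replicate``)::
            >>> # xdoctest: +SKIP
            >>> model.set_requires_gradient_sync(False)
            >>> model(microbatch).sum().backward()  # no gradient all-reduce
            >>> model.set_requires_gradient_sync(True)
            >>> model(last_microbatch).sum().backward()  # all-reduce gradients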
"""
replicate.state(self)._no_sync = not requires_gradient_sync
def register_comm_hook(self, *args, **kwargs) -> None:
replicate.state(self)._comm_hook_args.append((args, kwargs))
@contract(state_cls=_ReplicateState)
def replicate(
module: nn.Module,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
**kwargs,
) -> nn.Module:
r"""Replicates a module
Args:
module (torch.nn.Module): module to replicate
Example::
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> module = nn.Linear(3, 3)
>>> replicate(module)
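        >>> # A hedged sketch of the ``device_id`` convenience: an int or
        >>> # ``torch.device`` is accepted, and CPU devices are mapped to
        >>> # ``device_ids=None`` for the underlying DDP.
        >>> replicate(nn.Linear(3, 3), device_id=torch.device("cpu"))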
"""
torch._C._log_api_usage_once("torch.distributed.replicate")
# TODO(fegin): using kwargs is not a good idea if we would like to make
# replicate a formal API to replace DDP.
if "device_id" in kwargs:
if not isinstance(kwargs["device_id"], (int, torch.device)):
raise RuntimeError(
"Expected device_id to be int or torch.device, "
f"but got {type(kwargs['device_id'])}"
)
if _is_fully_sharded(module):
raise RuntimeError(
"Cannot apply `replicate()` on a Module already managed by `fully_shard`"
)
if ignored_modules is None:
ignored_modules = {}
else:
ignored_modules = set(ignored_modules)
state = cast(_ReplicateState, replicate.state(module))
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
device_mesh = kwargs.get("device_mesh", None)
if device_mesh is not None:
from torch.distributed.device_mesh import _mesh_resources
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
        # If the root mesh is not the same as device_mesh, then device_mesh was
        # sliced out from the root mesh.
if root_mesh != device_mesh:
# TODO: This is a temporary work around to enable DDP + TP.
# We should do the logic in DDP so that the 2D implementation is
# sound and the state_dict works out of the box.
#
            # This won't conflict with what is done in the DDP class because the
            # module that replicate passes to DDP is NOT the original module.
from torch.distributed.tensor.parallel.ddp import (
_localize_dtensor,
_reconstruct_dtensor,
)
module.register_forward_pre_hook(_reconstruct_dtensor)
module.register_forward_hook(_localize_dtensor)
module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type]
state.record_init_args(module, ignored_modules, **kwargs)
# Place DDP leftmost for highest priority in the method resolution order
cls = module.__class__
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
module.__class__ = new_cls
return module
def _is_fully_sharded(module: nn.Module) -> bool:
r"""Check if module is marked with fully_shard."""
registry = _get_registry(module)
if registry is None:
return False
return "fully_shard" in registry