I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,39 @@
from ._flat_param import FlatParameter as FlatParameter
from .fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel,
LocalOptimStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
OptimStateDictConfig,
OptimStateKeyType,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictSettings,
StateDictType,
)
__all__ = [
"BackwardPrefetch",
"CPUOffload",
"FullOptimStateDictConfig",
"FullStateDictConfig",
"FullyShardedDataParallel",
"LocalOptimStateDictConfig",
"LocalStateDictConfig",
"MixedPrecision",
"OptimStateDictConfig",
"OptimStateKeyType",
"ShardedOptimStateDictConfig",
"ShardedStateDictConfig",
"ShardingStrategy",
"StateDictConfig",
"StateDictSettings",
"StateDictType",
]
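
The re-exports above make up the public FSDP surface. As an illustrative sketch (not part of this diff), the snippet below wraps a toy model with a few of the re-exported configuration classes; it assumes a process group has already been initialized (e.g. via torchrun) and that a CUDA device is available.

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).cuda()
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # shard params, grads, and optimizer state
    cpu_offload=CPUOffload(offload_params=False),     # keep shards on the GPU
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next all-gather early
)
out = fsdp_model(torch.randn(8, 16, device="cuda"))
out.sum().backward()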


@@ -0,0 +1,558 @@
# mypy: allow-untyped-defs
"""
This file includes private common utilities for FSDP.
"""
import logging
import traceback
import warnings
import weakref
from enum import auto, Enum
from functools import partial
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
List,
no_type_check,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
)
import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
)
from torch.distributed.utils import _apply_to_tensors
from torch.utils._mode_utils import no_dispatch
from .api import (
FullOptimStateDictConfig,
FullStateDictConfig,
OptimStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictType,
)
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from ._flat_param import FlatParamHandle
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"
# Save a global mapping from module to its input tensor dtype to be populated
# during the forward pre-hook and consumed in the forward post-hook when
# overriding a module's mixed precision
# NOTE: We currently take the last input tensor's dtype in the case of multiple
# floating-point input tensors, which may be incorrect. However, since there is
# not a 1:1 correspondence between input and output tensors, we must use *some*
# heuristic like this to predict the desired output dtype.
_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
class _FSDPDeviceHandle:
"""
This is a simple abstraction for FSDP computing devices,
which enables custom backends that implement CUDA-like
semantics to be integrated with FSDP.
"""
def __init__(self, device: torch.device, backend: Any = None):
if backend is None:
try:
self.__backend = getattr(torch, device.type)
self.__device = device
except AttributeError as exc:
raise AttributeError(
f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'."
) from exc
else:
self.__backend = backend
@classmethod
def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle":
"""
Return a device handle corresponding to the device, and through this handle,
operations with the same semantics as CUDA can be performed on the device.
If the device is CUDA, return ``torch.cuda`` directly to make attribute access faster.
A custom backend must first register a module on ``torch`` with the same name as ``device.type``.
"""
if device.type == "cuda":
return cast(_FSDPDeviceHandle, torch.cuda)
elif device.type == "mtia":
return cast(_FSDPDeviceHandle, torch.mtia)
return cls(device)
def __getattr__(self, __name: str) -> Any:
try:
return getattr(self.__backend, __name)
except AttributeError as exc:
raise AttributeError(
f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'"
) from exc
class _UninitializedDeviceHandle(_FSDPDeviceHandle):
def __init__(self) -> None:
pass
def __getattribute__(self, __name: str) -> Any:
raise RuntimeError("Trying to use an uninitialized device handle.")
class _FSDPState(_State):
def __init__(self) -> None:
# TODO: Move all the attributes to this class to enable typing for
# FSDP/fully_shard.
self._ignored_modules: Set[nn.Module] = set()
self._ignored_params: Set[nn.Parameter] = set()
# Buffer names are cleaned (without wrapper prefixes)
self._ignored_buffer_names: Set[str] = set()
self.process_group: Optional[dist.ProcessGroup] = None
self.rank: int = -1
self.world_size: int = -1
self._device_mesh: Optional[DeviceMesh] = None
self.sharding_strategy = ShardingStrategy.FULL_SHARD
self._use_orig_params: bool = False
self.training_state = TrainingState.IDLE
self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
self._state_dict_config: StateDictConfig = FullStateDictConfig()
self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
self._is_root: Optional[bool] = None
self._handle: Optional[flat_param_file.FlatParamHandle] = None
self._fully_sharded_module_to_handle: Dict[
nn.Module, Optional[flat_param_file.FlatParamHandle]
] = {}
self.compute_device: Optional[torch.device] = None
self._gradient_predivide_factor: int = 0
self._gradient_postdivide_factor: int = 0
self._comm_hook: Optional[Callable] = None
self._comm_hook_state: Optional[Any] = None
self._unshard_event: Optional[torch.Event] = None
# Abstract device handle for fsdp compute device. For now,
# the compute device must implement cuda semantics used by fsdp
self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
# All following attributes should only be used for root states:
# Save these static lists to avoid the repeated tree traversals
self._all_fsdp_states: List[_FSDPState] = []
self._all_handles: List[flat_param_file.FlatParamHandle] = []
self._fsdp_extension: Optional[FSDPExtensions] = None
def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
state = _get_module_state(module)
if state is None or not isinstance(state, _FSDPState):
return None
return state
def _get_module_fsdp_state_if_fully_sharded_module(
module: nn.Module,
) -> Optional[_FSDPState]:
state = _get_module_fsdp_state(module)
if state is None:
return None
if state == module: # FullyShardedDataParallel module case.
return state
if module in state._fully_sharded_module_to_handle: # fully_shard case.
return state
return None
class TrainingState(Enum):
"""
An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
"""
IDLE = auto()
FORWARD_BACKWARD = auto()
SUMMON_FULL_PARAMS = auto()
class HandleTrainingState(Enum):
"""
An enum that indicates the state of a ``FlatParamHandle``.
"""
IDLE = auto()
FORWARD = auto()
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
SUMMON_FULL_PARAMS = auto()
def _is_composable(state: _FSDPState):
# TODO: This is a temporary hack to differentiate between code paths.
return not isinstance(state, nn.Module)
@no_type_check
def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]:
"""
Returns the ``FlatParamHandle`` corresponding to ``module``, i.e. the
handle that contains some parameter in ``module``.
"""
if _is_composable(state):
# A valid FSDP state may have no managed parameters and hence no
# handles, meaning no entry in `_fully_sharded_module_to_handle`
if state._handle is None:
return None
assert (
module in state._fully_sharded_module_to_handle
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
return state._fully_sharded_module_to_handle[module]
else:
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
return module._handle
@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
"""Returns if ``module`` has parameters managed by FSDP."""
return _module_handle(state, module) is not None
def _get_sharding_strategy(handle):
"""
Returns the sharding strategy of the handle.
"""
return handle._sharding_strategy if handle else None
def clean_tensor_name(tensor_name: str) -> str:
"""
Cleans the parameter or buffer name by removing any module wrapper
prefixes.
"""
tensor_name = tensor_name.replace(FSDP_PREFIX, "")
# TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
# it couples `CheckpointWrapper` and FSDP and also does not scale for more
# module wrappers.
tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
return tensor_name
def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
"""
Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
avoid re-flattening it during nested construction.
"""
setattr(tensor, FSDP_FLATTENED, True)
def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
"""Returns if ``tensor`` has been marked as flattened by FSDP."""
return getattr(tensor, FSDP_FLATTENED, False)
def _named_parameters_with_duplicates(
module: nn.Module, **kwargs: Any
) -> List[Tuple[str, nn.Parameter]]:
"""
This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`.
"""
assert (
"remove_duplicate" not in kwargs
), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))
except AssertionError as e:
kwargs.pop("remove_duplicate")
ret = list(module.named_parameters(**kwargs))
return ret
def _get_param_to_fqns(
model: torch.nn.Module,
dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
"""
Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here,
we use canonical to mean the fully-qualified name assigned to the parameter
based on its position in the original nn.Module hierarchy before any wrapper
or parallelism has been applied to it. This is in contrast to FQNs that may be
generated after parallelisms or wrappers have been applied to the model.
Each normal parameter maps to a singleton list containing its FQN, while each
``FlatParameter`` maps to a list of its original parameter FQNs, which may
have length greater than one. All FQNs are prefixed starting from ``model``.
In the case where FSDP was applied with ``use_orig_params=True``, there should be no
``FlatParameter`` s registered to the model's modules and this mapping will only
contain mappings from ``nn.Parameter`` s to singleton FQN lists.
It is only in the case where FSDP was applied with ``use_orig_params=False`` where
a ``FlatParameter`` will be registered in place of the original parameters and there
will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the
original parameters.
Args:
model (torch.nn.Module): Root module (which may or may not be a
:class:`FullyShardedDataParallel` instance).
dedup_shared_params (bool): For shared parameters, if ``True``, only
includes the FQNs corresponding to the first encounter of the
shared parameter in the module traversal; if ``False``, then
includes the FQNs across all encounters. (Default: ``True``)
"""
def module_fn(module, prefix, tree_level, param_to_fqns):
for param_name, param in _named_parameters_with_duplicates(
module, recurse=False
):
local_fqns = (
param._fqns
if isinstance(param, flat_param_file.FlatParameter)
else [param_name]
) # prefixed from `module`
global_fqns = [
clean_tensor_name(prefix + name) for name in local_fqns
] # prefixed from the top level `model` (i.e. including `prefix`)
is_shared_param = param in param_to_fqns
if not is_shared_param:
param_to_fqns[param] = global_fqns
else:
if isinstance(param, flat_param_file.FlatParameter):
# DMP overwrites `named_parameters` and skips (advances to
# the next child module) the wrapped module (e.g.,
# _dmp_wrapped_module and _fsdp_wrapped_module). When a user
# calls `named_children` to traverse the module recursively and
# calls `named_parameters` with `recurse=False`, parameters
# will be traversed more than once.
# This hack is specifically designed for DMP + FSDP. We
# overwrite the flat_parameters traversal result to only obtain
# the last one, which happens to be the correct one.
#
# TODO: Remove this hack once DMP + FSDP is not supported.
warnings.warn(
"FlatParameter is being traversed more than once. "
"This case should only happen when using "
"DistributedModelParallel with FullyShardedDataParallel."
)
param_to_fqns[param] = global_fqns
elif not dedup_shared_params:
param_to_fqns[param].extend(global_fqns)
def return_fn(param_to_fqns):
return param_to_fqns
param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
return _apply_to_modules(
model,
module_fn,
return_fn,
[key for key, _ in _named_parameters_with_duplicates(model)],
param_to_unflat_param_names,
)
@no_type_check
def _log_post_backward_hook(
state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger
) -> None:
# Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
# Below logging of module names this post-bwd hook fires for can help debug certain
# cases where hooks don't fire, such as under certain activation checkpoint configs.
if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
param_fqns = _get_handle_fqns_from_root(state, handle)
logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)
@no_type_check
def _get_handle_fqns_from_root(
state: _FSDPState, handle: "FlatParamHandle"
) -> Optional[List[str]]:
if handle is None:
return None
param_to_fqn = state._exec_order_data.param_to_fqn
handle_params = handle.flat_param._params # only populated for use_orig_params
param_fqns = [
fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list
]
return param_fqns
def _apply_to_modules(
root_module: torch.nn.Module,
module_fn: Callable,
return_fn: Callable,
filter_fqns: Optional[List[str]] = None,
*args,
**kwargs,
):
"""
Performs a pre-order traversal of the modules in the hierarchy rooted at
``root_module``, applying ``module_fn`` at each module and finally
returning a value using ``return_fn``. The traversal constructs the full
module prefix name (e.g. "module.submodule." just like in model state dict)
and makes that available to ``module_fn``.
``filter_fqns`` is used because some modules may have their own prefix,
similar to ``FullyShardedDataParallel``, and overwrite ``named_parameters()``
to remove that prefix.
"""
def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
# Call the module function before recursing over children (pre-order)
module_fn(module, prefix, tree_level, *args, **kwargs)
for submodule_name, submodule in module.named_children():
if submodule is None:
continue
new_prefix = prefix + submodule_name + "."
new_tree_level = tree_level + 1
if filter_fqns is not None:
for fqn in filter_fqns:
if fqn.startswith(new_prefix):
break
else:
# DMP's named_parameters() will mess up the traversal with
# ``named_children`` + `named_parameters(recurse=False)`.
# This hack is a must to make the traversal work.
# TODO: Remove this hack once DMP + FSDP is not supported.
# It turns out that recursive wrapping may trigger this as
# well.
if (
submodule_name == "_fsdp_wrapped_module"
or submodule_name == "_dmp_wrapped_module"
):
new_prefix = prefix
elif submodule_name == "module":
new_prefix = prefix
f(submodule, new_prefix, new_tree_level, *args, **kwargs)
f(root_module, "", 0, *args, **kwargs)
return return_fn(*args, **kwargs)
@no_type_check
def _assert_in_training_states(
state: _FSDPState,
training_states: List[TrainingState],
) -> None:
"""Asserts that FSDP is in the states ``_training_states``."""
# Raise a `ValueError` instead of using `assert` to ensure that these
# logical assertions run even if `assert`s are disabled
if state.training_state not in training_states:
msg = (
f"expected to be in states {training_states} but current state is "
f"{state.training_state}"
)
# Print the error on rank 0 in case this is called in the backward pass
if state.rank == 0:
if isinstance(state, nn.Module):
print(f"Asserting FSDP instance is: {state}")
print(f"ERROR: {msg}")
traceback.print_stack()
raise ValueError(msg)
def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
"""
Returns:
Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
parent-less) with respect to the modules in the set itself. In other
words, these are the modules in ``modules`` that are not the child of
any other module in ``modules``.
"""
root_modules: Set[nn.Module] = set()
module_to_submodules = {module: set(module.modules()) for module in modules}
for candidate_module in modules:
is_root_module = True
for module, submodules in module_to_submodules.items():
is_child_module = (
candidate_module is not module and candidate_module in submodules
)
if is_child_module:
is_root_module = False
break
if is_root_module:
root_modules.add(candidate_module)
return root_modules
def _override_module_mixed_precision(
root: torch.nn.Module,
module_classes_to_override: Iterable[Type[nn.Module]],
wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006
) -> Set[Type[nn.Module]]:
module_classes_to_override = tuple(set(module_classes_to_override))
# Return a set of the actually overridden module classes
overridden_module_classes: Set[Type[nn.Module]] = set()
for mod in root.modules():
if isinstance(mod, module_classes_to_override):
overridden_module_classes.add(type(mod))
mod._wrap_overrides = wrap_override_dict # type: ignore[assignment]
# TODO: We need to run this mixed precision ignored module in fp32,
# but ensure subsequent modules, that may possibly be running with
# mixed precision, still receive the appropriate precision inputs
# without user having to adjust mixed precision config too much.
# As a result, we attach pre and post forward hooks to up / down
# cast. We should revisit this design.
def cast_fn(
dtype: torch.dtype, module: nn.Module, x: torch.Tensor
) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == dtype:
return x
_MODULE_TO_INP_DTYPE[module] = x.dtype
return x.to(dtype)
def forward_pre_hook(module, args):
return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)
def forward_post_hook(module, args, output):
# NOTE: If the forward did not have any floating-point tensors,
# then the dtype will not be set for this module, and we do not
# upcast the dtype.
if module in _MODULE_TO_INP_DTYPE:
old_dtype = _MODULE_TO_INP_DTYPE[module]
return _apply_to_tensors(
partial(cast_fn, old_dtype, module), output
)
# We intentionally append both of these hooks so that they run after
# all other hooks.
mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
mod.register_forward_hook(forward_post_hook, prepend=False)
return overridden_module_classes
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
# FIXME record_stream doesn't work with non-cuda/mtia tensors
if tensor.device.type not in [
"cuda",
"mtia",
torch._C._get_privateuse1_backend_name(),
]:
return
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
return
# from @ezyang:
# The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin
# Looking over the PR, it looks like this is because we don't actually support Stream arguments
# in torch dispatch, so it just chokes.
# If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False),
# a better version of this would just be to check if there are any modes before disabling dispatch.
# TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here.
tensor.record_stream(stream)
else:
with no_dispatch():
tensor.record_stream(stream)
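
A small illustrative sketch (not part of this diff) of two helpers defined above; the example FQN is hypothetical. ``clean_tensor_name()`` strips the FSDP wrapper prefix so keys match the original module hierarchy, and ``_get_root_modules()`` keeps only the parent-less modules of a set.

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _get_root_modules,
    clean_tensor_name,
    FSDP_PREFIX,
)

wrapped_fqn = f"layers.0.{FSDP_PREFIX}linear.weight"  # hypothetical wrapper-prefixed name
print(clean_tensor_name(wrapped_fqn))                 # -> "layers.0.linear.weight"

outer = nn.Sequential(nn.Linear(4, 4))
inner = outer[0]
print(_get_root_modules({outer, inner}) == {outer})   # True: `inner` is a child of `outer`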


@@ -0,0 +1,157 @@
# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple
import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
_apply_to_modules,
_get_module_fsdp_state,
clean_tensor_name,
)
logger = logging.getLogger(__name__)
class SimpleProfiler:
class Type(str, Enum):
ALL = "all"
ALLGATHER = "all_gather"
ALLGATHER_OBJ = "all_gather_object"
RESHARDING = "resharding"
H2D = "H2D"
D2H = "D2H"
results: Dict[str, float] = defaultdict(float)
profiling: Set[str] = set()
@classmethod
def reset(cls) -> None:
cls.results.clear()
cls.profiling.clear()
@classmethod
@contextmanager
def profile(cls, profile_type: str) -> Iterator[None]:
assert profile_type not in cls.profiling, (
f"{profile_type} is already being profiled. "
"SimpleProfiler does not support profiling multiple instances at "
"the same time. "
)
cls.profiling.add(profile_type)
begin = time.monotonic()
try:
yield
finally:
end = time.monotonic()
cls.results[profile_type] += end - begin
cls.profiling.remove(profile_type)
@classmethod
def dump_and_reset(cls, msg: str) -> None:
# This cannot be combined with the DETAIL distributed debug level
# as the profiling results would be very inaccurate.
if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
logger.info("%s %s", msg, cls.results)
cls.reset()
def _get_sharded_module_tree_with_module_name_to_fqns(
model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
"""
This is used for the composable fully_shard() code path. It returns
1. sharded module tree info: each line represents a submodule name that contains the
submodule's FQN and its submodule class name. If the submodule is sharded by `fully_shard`,
the submodule name is suffixed with ' FULLY SHARDED'. Each increased tree
level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
is like this:
[CompositeModel] FULLY SHARDED
l1[Linear]
u1[UnitModule] FULLY SHARDED
u1.l1[Linear]
u1.seq[Sequential]
u1.seq.0[ReLU]
u1.seq.1[Linear]
u1.seq.2[ReLU]
u1.l2[Linear]
u2[UnitModule] FULLY SHARDED
u2.l1[Linear]
u2.seq[Sequential]
u2.seq.0[ReLU]
u2.seq.1[Linear]
u2.seq.2[ReLU]
u2.l2[Linear]
l2[Linear]
2. a dict mapping from the concatenated module FQN and class name to a list of its managed
original parameters' FQNs. An example of the dict for the above toy sharded model is like this:
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
}
All FQNs are prefixed starting from ``model``.
Args:
model (torch.nn.Module): Root module (which may or may not be passed to
composable `fully_shard()`).
"""
def module_fn(
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
):
num_spaces = tree_level * 4
trimed_prefix = (
prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
)
prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
printed_prefixed_module_name = " " * num_spaces + prefixed_module_name
state = _get_module_fsdp_state(module)
if state is None:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
return
handle = state._fully_sharded_module_to_handle.get(module, None)
if handle:
sharded_tree_info[0] += (
printed_prefixed_module_name + " FULLY SHARDED" + "\n"
)
else:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
if handle:
param = handle.flat_param
assert isinstance(param, flat_param_file.FlatParameter)
global_fqns = [
clean_tensor_name(prefix + name) for name in param._fqns
] # prefixed from the top level `model` (i.e. including `prefix`)
if prefixed_module_name in sharded_module_name_to_fqns:
sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
else:
sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
return sharded_tree_info[0], sharded_module_name_to_fqns
# Use List to mutate its value in place while running the recursive functions
sharded_tree_info: List[str] = [
"",
]
sharded_module_name_to_fqns: Dict[str, List[str]] = {}
return _apply_to_modules(
model,
module_fn,
return_fn,
[key for key, _ in model.named_parameters()],
sharded_tree_info,
sharded_module_name_to_fqns,
)
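
A minimal sketch of the SimpleProfiler pattern defined above. ``dump_and_reset()`` only logs on rank 0 under ``TORCH_DISTRIBUTED_DEBUG=INFO``, so this example reads the accumulated results directly; the sleep stands in for a real all-gather or resharding region.

import time
from torch.distributed.fsdp._debug_utils import SimpleProfiler

with SimpleProfiler.profile(SimpleProfiler.Type.ALL):
    time.sleep(0.01)  # placeholder for the profiled region
print(dict(SimpleProfiler.results))  # accumulated seconds keyed by profile type
SimpleProfiler.reset()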


@@ -0,0 +1,46 @@
# mypy: allow-untyped-defs
from typing import Set
import torch.nn as nn
def _annotate_modules_for_dynamo(
module: nn.Module,
ignored_modules: Set[nn.Module],
use_orig_params: bool,
):
"""
Annotates the submodules in ``module`` 's tree, except those in
``ignored_modules``, indicating that the submodules are FSDP-managed and
saving the ``use_orig_params`` setting passed to the FSDP constructor.
"""
for submodule in module.modules():
if submodule not in ignored_modules:
"""[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since
it skips tracing all the torch.distributed.fsdp code.
- Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also
gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops.
- However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*),
and we need a way to indicate to dynamo which modules are wrapped by FSDP.
(*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough
guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming
their code is well-behaved.
One particular issue with specialized NNModules for FSDP is that the
views created for orig_params are captured into the compiled graph on the first iteration, and while
they are always going to point to the correct flatparameter and give correct results, their order
of creation influences the order of backward execution, preventing overlap of comm and computation
during backward. We need to _use_ the new parameter views created on each forward iteration, in
order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve
this by capturing the module code more 'functionally' and passing parameters in as inputs each time.
"""
submodule._is_fsdp_managed_module = True # type: ignore[assignment]
# Dynamo only supports FSDP with use_orig_params=True.
# This is hacky, but I could not think of another way to add an assertion to dynamo
# for this, since Dynamo skips all the FSDP code frames and thus can't inspect the
# FSDP module directly
submodule._fsdp_use_orig_params = use_orig_params # type: ignore[assignment]
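
An illustrative sketch of what the annotation leaves behind: after calling the helper (private API, shown here only for demonstration), every non-ignored submodule carries the two attributes that Dynamo later checks.

import torch.nn as nn
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
_annotate_modules_for_dynamo(model, ignored_modules=set(), use_orig_params=True)
assert all(m._is_fsdp_managed_module for m in model.modules())
assert all(m._fsdp_use_orig_params for m in model.modules())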


@@ -0,0 +1,365 @@
# mypy: allow-untyped-defs
import itertools
import warnings
from enum import auto, Enum
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns
from torch.distributed.fsdp._flat_param import FlatParamHandle
class _ExecOrderWarnStatus(Enum):
"""Used internally for execution order validation."""
NONE = auto() # no deviation yet
WARNING = auto() # deviated this iteration; currently issuing warnings
WARNED = auto() # deviated in a previous iteration
class _ExecOrderData:
"""
This contains the data structures to track the execution order. We track
the pre-forward order on the *first* iteration for forward prefetching
(which thus assumes static graph) and the post-forward order on *every*
iteration for backward prefetching (which thus does not assume static
graph but may provide an incorrect order).
"""
def __init__(
self,
debug_level: dist.DebugLevel,
backward_prefetch_limit: int,
forward_prefetch_limit: int,
) -> None:
# Tracks the (static) pre-forward order for execution order validation
# and forward prefetching
self.handles_pre_forward_order: List[FlatParamHandle] = []
# Tracks the post-forward order for pre-backward prefetching
self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
self._iter = 0
# Gives the max number of backward/forward prefetched all-gathers by a
# single module
self._backward_prefetch_limit = backward_prefetch_limit
self._forward_prefetch_limit = forward_prefetch_limit
# Data structures for execution order validation
self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
self.process_group: Optional[dist.ProcessGroup] = None
self.world_size: Optional[int] = None
self.all_handles: List[FlatParamHandle] = []
# Names are prefixed from the root module
self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
# Current index in the pre-forward execution order
self.current_order_index = 0
self.warn_status = _ExecOrderWarnStatus.NONE
def init(
self,
state: _FSDPState,
root_module: nn.Module,
process_group: dist.ProcessGroup,
) -> None:
"""
Initializes the data structures needed for checking the forward order.
This should be called after a root FSDP instance has been set during
lazy initialization.
"""
self.process_group = process_group
self.rank = process_group.rank()
self.world_size = process_group.size()
# Fix an order over the handles, which should be the same across ranks
for handle in traversal_utils._get_fsdp_handles(root_module):
index = len(self.all_handles)
self.all_handles.append(handle)
handle._handle_index = index
self.param_to_fqn = _get_param_to_fqns(root_module)
# TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
# to check that all ranks have the same handles in the same order.
# https://github.com/pytorch/pytorch/issues/79620
@property
def is_first_iter(self) -> bool:
return self._iter == 0
def get_handle_to_backward_prefetch(
self,
current_handle: FlatParamHandle,
) -> Optional[FlatParamHandle]:
"""
Returns the handle to backward prefetch given the current handle, walking
backward through the recorded post-forward order up to the backward
prefetch limit. Returns ``None`` if there is no valid handle to prefetch.
"""
current_index = current_handle._post_forward_index
if current_index is None:
return None
target_index = current_index - 1
target_handle: Optional[FlatParamHandle] = None
for _ in range(self._backward_prefetch_limit):
if target_index < 0:
break
target_handle = self.handles_post_forward_order[target_index]
target_index -= 1
return target_handle
def get_handle_to_forward_prefetch(
self,
current_handle: FlatParamHandle,
) -> Optional[FlatParamHandle]:
"""
Returns the handle to forward prefetch given the current handle, walking
forward through the recorded pre-forward order up to the forward prefetch
limit. Returns ``None`` if there is no valid handle to prefetch.
"""
current_index = current_handle._pre_forward_order_index
if current_index is None:
return None
target_index = current_index + 1
target_handle: Optional[FlatParamHandle] = None
for _ in range(self._forward_prefetch_limit):
if target_index >= len(self.handles_pre_forward_order):
break
target_handle = self.handles_pre_forward_order[target_index]
target_index += 1
return target_handle
def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None:
"""
Records ``handle`` in the post-forward order, where ``handle`` is the
handle used in the module's forward. If ``handle`` is ``None``, then it is
omitted.
Unlike :meth:`record_pre_forward`, this records the order *every*
iteration with the expectation that the recorded order is reset in
:meth:`next_iter`.
"""
if not handle:
return
# Only record the first usage of a handles key
if handle._post_forward_index:
self.handles_post_forward_order.append(handle)
return
index = len(self.handles_post_forward_order)
handle._post_forward_index = index
self.handles_post_forward_order.append(handle)
def record_pre_forward(
self, handle: Optional[FlatParamHandle], is_training: bool
) -> None:
"""
Records ``handle`` in the pre-forward order, where ``handle`` is the
handle used in the module's forward. If ``handle`` is ``None``, then it is
omitted.
On the first iteration, this checks the execution order across ranks.
See :meth:`_check_order` for details.
"""
if not handle:
return
self._check_order(handle, is_training)
# Fix the order after the first iteration and only record the first
# usage of a handles key
if not self.is_first_iter or handle._pre_forward_order_index is not None:
return
index = len(self.handles_pre_forward_order)
handle._pre_forward_order_index = index
self.handles_pre_forward_order.append(handle)
def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
"""
Checks the forward execution order as long as ``is_training`` is
``True`` since checking in eval mode is not supported. This only checks
if the distributed debug level is DETAIL.
- On the first iteration, this uses all-gathers to check that all ranks
are all-gathering the same handles and hence ``FlatParameter`` s,
raising an error if not.
- On subsequent iterations, this checks that each rank is locally
consistent with its own forward order from the first iteration, issuing
a warning if not. This issues a warning on the first deviating
iteration and stops warning thereafter.
"""
# Do not check order in eval mode since the post-backward callback does
# not run so it cannot be used to mark the end of an iteration
if not is_training or not self._checking_order:
return
if self.is_first_iter:
msg_prefix = "Forward order differs across ranks:"
optional_local_indices: Tuple[
Optional[int], ...
] = self._get_handle_indices(handle)
device = handle.device # guaranteed to be non-CPU
num_valid_indices = sum(
(index is not None) for index in optional_local_indices
)
tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
"dtype": torch.int32,
"device": device,
}
world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs) # type: ignore[arg-type, call-overload]
local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload]
dist.all_gather_into_tensor(
world_num_valid_indices,
local_num_valid_indices,
group=self.process_group,
)
# Copy entire tensor from D2H once to avoid per element D2H copies
world_num_valid_indices = world_num_valid_indices.cpu()
# Check that all ranks plan to all-gather the same number of
# parameters
# TODO (awgu): Since every module has at most one handle in the
# current implementation, this should never raise the error.
assert self.world_size is not None # mypy
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
# TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
# tensor comparison control flow.
# https://github.com/pytorch/pytorch/issues/107055
for (r1, n1), (r2, n2) in itertools.combinations(
(
(rank, world_num_valid_indices[rank])
for rank in range(self.world_size)
),
2,
):
if n1 != n2:
raise RuntimeError(
f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
f"while rank {r2} is all-gathering {n2} parameters"
)
world_indices = torch.zeros( # type: ignore[call-overload]
self.world_size * num_valid_indices, **tensor_kwargs
)
local_indices = torch.tensor(optional_local_indices, **tensor_kwargs) # type: ignore[arg-type]
dist.all_gather_into_tensor(
world_indices, local_indices, group=self.process_group
)
# Copy entire tensor from D2H once to avoid per element D2H copies
world_indices = world_indices.cpu()
# Check that all ranks plan to all-gather the same index parameters
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
# TODO(voz): Don't graph break on this - dynamo hates the i1 != i2
# tensor comparison control flow.
# https://github.com/pytorch/pytorch/issues/107055
for (r1, i1), (r2, i2) in itertools.combinations(
(
(
rank,
world_indices[
rank
* num_valid_indices : (rank + 1)
* num_valid_indices
],
)
for rank in range(self.world_size)
),
2,
):
if i1 != i2:
r1_param_names = self._get_names_from_handle_indices(i1)
r2_param_names = self._get_names_from_handle_indices(i2)
raise RuntimeError(
f"{msg_prefix} rank {r1} is all-gathering parameters "
f"for {r1_param_names} while rank {r2} is all-gathering "
f"parameters for {r2_param_names}"
)
else:
# Only issue warnings on the first deviating iteration and stop
# checking thereafter to avoid flooding the console
if self.warn_status == _ExecOrderWarnStatus.WARNED:
return
msg_prefix = None # non-`None` means we should warn
if self.current_order_index >= len(self.handles_pre_forward_order):
# This iteration sees extra all-gather(s) compared to the first
msg_prefix = (
"Expected to not all-gather any more parameters in the "
"forward but trying to all-gather parameters for "
)
else:
expected_handle = self.handles_pre_forward_order[
self.current_order_index
]
if expected_handle != handle:
expected_param_names = self._get_names_from_handles(expected_handle)
msg_prefix = (
f"Expected to all-gather for {expected_param_names} "
"but trying to all-gather parameters for "
)
if msg_prefix is not None:
param_names = self._get_names_from_handles(handle)
msg_suffix = (
f"{param_names}"
if param_names
else "a newly-added parameter since construction time"
)
warnings.warn(
"Forward order differs from that of the first iteration "
f"on rank {self.rank}. Collectives are unchecked and may "
f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
)
self.warn_status = _ExecOrderWarnStatus.WARNING
self.current_order_index += 1
def _get_handle_indices(
self,
handle: FlatParamHandle,
) -> Tuple[Optional[int], ...]:
"""
Returns the handle indices (i.e. indices into ``self.all_handles``)
corresponding to ``handle``. An entry in the
returned tuple is ``None`` if the handle is invalid.
"""
indices: List[Optional[int]] = []
if handle:
indices.append(handle._handle_index)
return tuple(indices)
def _get_names_from_handle_indices(
self,
handle_indices: Tuple[int, ...],
) -> List[List[str]]:
"""
Returns a list of FQNs for each handle in ``handle_indices``. If a
handle index is invalid, then its FQNs are omitted from the returned
list.
"""
fqns: List[List[str]] = []
for index in handle_indices:
if index is None or index < 0 or index >= len(self.all_handles):
continue
handle = self.all_handles[index]
flat_param = handle.flat_param
fqns.append(self.param_to_fqn[flat_param])
return fqns
def _get_names_from_handles(
self,
handle: FlatParamHandle,
) -> List[List[str]]:
"""
Returns a list of FQNs for ``handle``. If the handle is invalid, then its
FQNs are omitted from the returned list.
"""
fqns: List[List[str]] = []
if handle:
flat_param = handle.flat_param
if flat_param in self.param_to_fqn:
fqns.append(self.param_to_fqn[flat_param])
return fqns
def next_iter(self):
"""
Advances the internal data structures per iteration. This should be
called in the post-backward callback since that marks the true end of
an iteration.
"""
self._iter += 1
self.handles_post_forward_order.clear()
if self._checking_order:
self.current_order_index = 0
if self.warn_status == _ExecOrderWarnStatus.WARNING:
self.warn_status = _ExecOrderWarnStatus.WARNED
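
A standalone sketch (names hypothetical, not FSDP code) of the index walk that get_handle_to_forward_prefetch() performs: starting from the current handle's slot in the recorded pre-forward order, step forward up to the prefetch limit and return the last handle reached.

from typing import List, Optional

def pick_forward_prefetch(order: List[str], current: str, limit: int) -> Optional[str]:
    target: Optional[str] = None
    target_index = order.index(current) + 1
    for _ in range(limit):
        if target_index >= len(order):
            break
        target = order[target_index]
        target_index += 1
    return target

print(pick_forward_prefetch(["h0", "h1", "h2", "h3"], "h1", limit=2))  # -> "h3"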

File diff suppressed because it is too large


@@ -0,0 +1,179 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import (
_all_gather_dtensor,
_create_chunk_dtensor,
_create_chunk_sharded_tensor,
)
from torch.distributed.tensor import DeviceMesh, DTensor
class FSDPExtensions(ABC):
"""
This enables some customizable hooks to enable composability with tensor
parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
set a custom :class:`FSDPExtensions` that implements the hooks.
"""
@abstractmethod
def pre_flatten_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
"""E.g. converting ``DistributedTensor`` to local tensor."""
...
@abstractmethod
def post_unflatten_transform(
self,
tensor: torch.Tensor,
param_extension: Any,
) -> torch.Tensor:
"""E.g. converting local tensor to ``DistributedTensor``."""
...
@abstractmethod
def chunk_tensor(
self,
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Shards a tensor to chunks and returns the local chunk."""
...
@abstractmethod
def chunk_dtensor(
self,
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> torch.Tensor:
"""Shards a tensor/DTensor to DTensor and returns the local DTensor."""
...
@abstractmethod
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
"""
This is to be called before loading a *sharded* model state dict and
should return the tensor and list of shards from which to load data.
"""
...
@abstractmethod
def all_gather_dtensor(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""
This is to be called before loading a *sharded* DTensor state dict.
This all-gathers the tensor along the FSDP dimension and returns the local
tensor of the TP DTensor.
"""
...
_extensions: Optional[FSDPExtensions] = None
def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
global _extensions
_extensions = flattener
def _ext_pre_flatten_transform(
tensor: torch.Tensor,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
if fsdp_extension is not None:
new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
if param_extension is not None:
return new_tensor, param_extension
return tensor, None
def _ext_post_unflatten_transform(
tensor: torch.Tensor,
param_extension: Any,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
if fsdp_extension is not None and param_extension is not None:
return fsdp_extension.post_unflatten_transform(tensor, param_extension)
return tensor
def _ext_chunk_tensor(
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
chunk_tensor_fn = (
fsdp_extension.chunk_tensor
if fsdp_extension is not None
else _create_chunk_sharded_tensor
)
return chunk_tensor_fn(
tensor,
rank,
world_size,
num_devices_per_node,
pg,
)
def _ext_chunk_dtensor(
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
chunk_dtensor_fn = (
fsdp_extension.chunk_dtensor
if fsdp_extension is not None
else _create_chunk_dtensor
)
return chunk_dtensor_fn(
tensor,
rank,
device_mesh,
)
def _ext_pre_load_state_dict_transform(
tensor: torch.Tensor,
fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
if fsdp_extension is not None:
return fsdp_extension.pre_load_state_dict_transform(tensor)
assert type(tensor) is ShardedTensor
shards = tensor.local_shards()
return (tensor, shards)
def _ext_all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
all_gather_dtensor_fn = (
fsdp_extension.all_gather_dtensor
if fsdp_extension is not None
else _all_gather_dtensor
)
return all_gather_dtensor_fn(tensor, parent_mesh)
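
A hypothetical, minimal FSDPExtensions stub (an assumption for illustration, not part of this diff) showing which hooks a tensor-parallel integration would override. Every method here is a pass-through or a naive dim-0 chunk; a real extension converts between distributed tensors and their local shards.

from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions, _set_fsdp_extensions
from torch.distributed.tensor import DeviceMesh, DTensor


class _NoOpExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self, tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        return tensor, None  # nothing to strip before flattening

    def post_unflatten_transform(
        self, tensor: torch.Tensor, param_extension: Any
    ) -> torch.Tensor:
        return tensor  # nothing to rebuild after unflattening

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        # Naive dim-0 shard; assumes rank < number of chunks.
        return tensor.chunk(world_size, dim=0)[rank].clone()

    def chunk_dtensor(
        self, tensor: torch.Tensor, rank: int, device_mesh: DeviceMesh
    ) -> torch.Tensor:
        return tensor  # a real extension would return a DTensor shard

    def pre_load_state_dict_transform(
        self, tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Shard]]:
        return tensor, []

    def all_gather_dtensor(
        self, tensor: DTensor, parent_mesh: Optional[DeviceMesh]
    ) -> torch.Tensor:
        return tensor.full_tensor()  # materialize the full tensor locally


_set_fsdp_extensions(_NoOpExtensions())  # the _ext_* helpers above now route through the stub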

File diff suppressed because it is too large


@@ -0,0 +1,33 @@
import collections
from typing import Deque, Optional
import torch
class _FreeEventQueue:
"""
This tracks all pending frees corresponding to inflight all-gathers. The
queueing pattern is iterative enqueues with a single dequeue per iteration
once the limit ``_max_num_inflight_all_gathers`` is reached.
"""
def __init__(self) -> None:
self._queue: Deque[torch.Event] = collections.deque()
self._max_num_inflight_all_gathers = 2 # empirically chosen
def enqueue(self, free_event: torch.Event) -> None:
"""Enqueues a free event."""
self._queue.append(free_event)
def dequeue_if_needed(self) -> Optional[torch.Event]:
"""Dequeues a single event if the limit is reached."""
if len(self._queue) >= self._max_num_inflight_all_gathers:
return self._dequeue()
return None
def _dequeue(self) -> Optional[torch.Event]:
"""Dequeues a free event if possible."""
if self._queue:
event = self._queue.popleft()
return event
return None
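
A standalone sketch (not FSDP code) of the rate-limiting pattern _FreeEventQueue implements: each all-gather enqueues a "free" event, and once two all-gathers are inflight the caller dequeues the oldest event and waits on it before issuing the next all-gather, bounding allocator memory growth.

from collections import deque

inflight: deque = deque()
MAX_INFLIGHT = 2  # mirrors _max_num_inflight_all_gathers

for step in range(5):
    if len(inflight) >= MAX_INFLIGHT:      # dequeue_if_needed()
        oldest = inflight.popleft()
        print(f"wait on {oldest} before all-gather {step}")
    inflight.append(f"free-event-{step}")  # enqueue() after issuing the all-gather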

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,137 @@
# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
ShardedTensorMetadata,
TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
def _get_remote_device_str(rank, device_type, num_devices_per_node):
if device_type.lower() == "cpu":
return f"rank:{rank}/{device_type}"
elif device_type.lower() == "hpu":
return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
else:
return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"
def _create_chunk_sharded_tensor(
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
) -> ShardedTensor:
"""
Shard a tensor into chunks along the first dimension. The local rank will get its
corresponding chunk as the local shard to create a ShardedTensor.
"""
chunks = tensor.chunk(world_size, dim=0)
if len(chunks) > rank:
local_shard = chunks[rank].clone()
offsets = [0 for _ in tensor.size()]
offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
else:
local_shards = []
# Create a ShardedTensor without invoking communication.
chunk_sizes = [list(chunk.size()) for chunk in chunks]
dim0_offsets = [0] + list(
itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
)[:-1]
offsets = [0] * (len(chunk_sizes[0]) - 1)
chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
device_type = (
distributed_c10d._get_pg_default_device(pg).type
if device is None
else device.type
)
placements = [
_get_remote_device_str(
dist.get_global_rank(pg, r),
device_type,
num_devices_per_node,
)
for r in range(len(chunk_sizes))
]
assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
shard_metadata = [
ShardMetadata(offset, size, placement)
for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
]
sharded_tensor_metadata = ShardedTensorMetadata(
shards_metadata=shard_metadata,
size=tensor.size(),
tensor_properties=TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=False,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
)
def _create_chunk_dtensor(
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> DTensor:
"""
Shard a tensor into chunks along the first dimension. The local rank will get its
corresponding chunk as the local tensor to create a DTensor.
"""
# We need to explicitly call .detach() to return a new tensor detached from the current graph.
tensor = tensor.clone().detach()
# FSDP placements: [Shard(0)]
# HSDP placements: [Replicate(), Shard(0)]
replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements[-1] = DShard(0) # type: ignore[call-overload]
return DTensor.from_local(
tensor, device_mesh, replicate_placements, run_check=False
).redistribute(
placements=shard_placements,
)
def _all_gather_dtensor(
tensor: DTensor,
root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""
All gather a DTensor in its sharded dimension and return the local tensor.
"""
assert (
root_mesh == tensor.device_mesh
), "The device mesh of a tensor should be a root mesh."
placements = list(copy.deepcopy(tensor.placements))
# FSDP placements: [Shard(0)] -> [Replicate()]
# HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
placements[-1] = Replicate()
tensor = tensor.redistribute(
device_mesh=tensor.device_mesh,
placements=placements,
)
return tensor.to_local()
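
A worked sketch of the chunking math in _create_chunk_sharded_tensor(): a (10, 4) tensor split across world_size=4 ranks along dim 0 gives three 3-row shards and a 1-row remainder, with each rank's dim-0 offset equal to ceil(10 / 4) * rank.

import math

import torch

tensor = torch.arange(40).reshape(10, 4)
world_size = 4
chunks = tensor.chunk(world_size, dim=0)
for rank, chunk in enumerate(chunks):
    offset0 = math.ceil(tensor.size(0) / world_size) * rank
    print(rank, tuple(chunk.shape), offset0)
# 0 (3, 4) 0
# 1 (3, 4) 3
# 2 (3, 4) 6
# 3 (1, 4) 9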


@@ -0,0 +1,924 @@
# mypy: allow-untyped-defs
import contextlib
import logging
import math
import warnings
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterator,
List,
no_type_check,
Tuple,
)
import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
init_from_local_shards,
Shard,
ShardedTensor,
)
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import (
_FSDPState,
_get_module_fsdp_state_if_fully_sharded_module,
_has_fsdp_params,
_is_composable,
_module_handle,
clean_tensor_name,
FSDP_PREFIX,
FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._runtime_utils import (
_cast_buffers_to_dtype_and_device,
_get_orig_buffer_dtypes,
_lazy_init,
_reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp.api import (
FullStateDictConfig,
ShardingStrategy,
StateDictType,
)
from torch.distributed.tensor import DTensor
from torch.distributed.utils import _replace_by_prefix
from ._fsdp_extensions import (
_ext_all_gather_dtensor,
_ext_chunk_dtensor,
_ext_chunk_tensor,
_ext_post_unflatten_transform,
_ext_pre_load_state_dict_transform,
)
from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM
logger = logging.getLogger(__name__)
def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
return not (
fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
)
def _convert_to_wrapped_module_name(module_name: str) -> str:
module_name = module_name.replace(f"{FSDP_PREFIX}", "")
module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
if module_name:
module_name = f"{module_name}."
# `CheckpointWrapper` adds a prefix that has to be removed as well.
module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
return module_name
def _param_name_infos(
module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
if not _has_fsdp_params(fsdp_state, module):
return
for param_name, module_name in _module_handle(
fsdp_state, module
).param_module_names():
module_name = _convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
def _shared_param_name_infos(
module: nn.Module, fsdp_state
) -> Iterator[Tuple[str, str, str]]:
for param_name, module_name in _module_handle(
fsdp_state, module
).shared_param_module_names():
module_name = _convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
@no_type_check
def _enter_unshard_params_ctx(
module: nn.Module,
fsdp_state: _FSDPState,
writeback: bool = False,
rank0_only: bool = False,
offload_to_cpu: bool = False,
with_grads: bool = False,
) -> None:
"""
state_dict hooks cannot use the pure context call as the checkpoint flow
requires entering the context in the pre-hook but leaving the context in the
post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
"""
assert module not in fsdp_state._unshard_params_ctx, (
"Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
"is not None."
)
fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
module,
fsdp_state,
writeback=writeback,
rank0_only=rank0_only,
offload_to_cpu=offload_to_cpu,
with_grads=with_grads,
)
fsdp_state._unshard_params_ctx[module].__enter__()
@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
"""A helper function to exit ``_unshard_fsdp_state_params`` context."""
fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
fsdp_state._unshard_params_ctx.pop(module)
def _common_pre_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
) -> None:
"""Performs the pre-state_dict tasks shared by all state_dict types."""
if fsdp_state._device_handle.is_available():
fsdp_state._device_handle.synchronize()
# TODO: need to check if this is always correct for composable FSDP.
_lazy_init(fsdp_state, module)
if fsdp_state._is_root:
_reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)
def _common_unshard_pre_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
offload_to_cpu: bool,
rank0_only: bool,
) -> None:
"""
Performs the pre-state_dict tasks shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
"""
# For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases.
if not _should_unshard_params(fsdp_state):
return
_enter_unshard_params_ctx(
module,
fsdp_state,
writeback=False,
offload_to_cpu=offload_to_cpu,
rank0_only=rank0_only,
)
@no_type_check
def _common_unshard_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
param_hook: Callable,
) -> Dict[str, Any]:
"""
The post-state_dict flow that is shared by all state_dict types that require
``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this
hook.
"""
_replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
# Return early for trivial cases
if not state_dict or not _has_fsdp_params(fsdp_state, module):
if _should_unshard_params(fsdp_state):
_exit_unshard_params_ctx(module, fsdp_state)
return state_dict
# If a rank does not have unsharded parameters (when `rank0_only=True`
# and `rank != 0`), then the rank only needs to participate in the
# all-gather and does not need to save the state dict. We simply check
# rank0_only to handle this case.
rank0_only = (
fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
)
# no_fsdp_return means the state_dict returned by this rank should contain
# only non-FSDP controlled parameters and buffers.
no_fsdp_return = rank0_only and fsdp_state.rank != 0
if no_fsdp_return and not fsdp_state._use_orig_params:
for clean_key in fsdp_state._buffer_names:
# This is a hack to support activation checkpoint.
clean_key = clean_key.replace(
f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
)
state_dict.pop(f"{prefix}{clean_key}", None)
# Non-zero ranks have flat_param key when rank0_only=True, because rank0_only=True is
# passed in to unshard context, but nonzero ranks reshard early, causing this flat_param
# to appear in state_dict.
state_dict.pop(f"{prefix}{FLAT_PARAM}")
_exit_unshard_params_ctx(module, fsdp_state)
return state_dict
# Loop only the parameters saved in this instance's wrapped module to
# avoid processing buffers.
for fqn, param_name, module_name in _param_name_infos(module, fsdp_state):
fqn = f"{prefix}{fqn}"
if no_fsdp_return:
state_dict.pop(fqn)
continue
assert fqn in state_dict, (
f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
f"has {state_dict.keys()}. "
f"prefix={prefix}, module_name={module_name}, "
f"param_name={param_name} rank={fsdp_state.rank}."
)
param_hook(state_dict, prefix, fqn)
if _should_unshard_params(fsdp_state):
_exit_unshard_params_ctx(module, fsdp_state)
cpu_device = torch.device("cpu")
buffer_clean_fqns = []
buffers = []
for clean_key in fsdp_state._buffer_names:
# This is a hack to support activation checkpoint.
clean_key = clean_tensor_name(clean_key)
fqn = f"{prefix}{clean_key}"
if fqn not in state_dict:
# A buffer can be registered as non-persistent.
continue
if no_fsdp_return:
state_dict.pop(fqn)
else:
buffer = state_dict[fqn]
if (
fsdp_state._state_dict_config.offload_to_cpu
and buffer.device != cpu_device
):
state_dict[fqn] = buffer.to(cpu_device)
# skip upcasting for ignored buffers
if clean_key not in fsdp_state._ignored_buffer_names:
buffer_clean_fqns.append(clean_key)
buffers.append(state_dict[fqn])
if buffers:
mixed_precision_enabled_for_buffers = (
fsdp_state._mixed_precision_enabled_for_buffers()
if not _is_composable(fsdp_state)
else (fsdp_state.mixed_precision.buffer_dtype is not None)
)
if mixed_precision_enabled_for_buffers:
buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns)
_cast_buffers_to_dtype_and_device(
buffers, buffer_dtypes, fsdp_state.compute_device
)
for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
fqn = f"{prefix}{clean_fqn}"
logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype)
state_dict[fqn] = buffer.clone()
return state_dict
@no_type_check
def _full_pre_state_dict_hook(
fsdp_state: _FSDPState,
module: nn.Module,
*args,
**kwargs,
) -> None:
"""
    Hook that runs before model.state_dict() is called. A pre-state_dict hook
    is not actually supported by ``nn.Module``. As a result, this API is called
    from ``_full_post_state_dict_hook()`` to simulate the case. Once
    pre-state_dict hooks are supported in ``nn.Module``, this hook will be
    registered as a hook on ``nn.Module``.
"""
if getattr(fsdp_state, "_device_mesh", False):
root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
_common_pre_state_dict_hook(module, fsdp_state)
_common_unshard_pre_state_dict_hook(
module,
fsdp_state,
offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
)
@no_type_check
def _full_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
"""
Hook that runs after model.state_dict() is called before returning result to
user. For FSDP, we may have to clone the tensors in state_dict as params go
back to sharded version after _unshard_fsdp_state_params ends, and also remove
the ``FSDP_WRAPPED_MODULE`` prefix.
"""
def param_hook(
state_dict: Dict[str, Any],
prefix: str,
fqn: str,
) -> None:
clean_key = fqn
clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix from the key if needed since buffer and parameter
        # names do not include the prefix (they are not computed during the
        # `state_dict` call).
if clean_key.startswith(clean_prefix):
clean_key = clean_key[len(clean_prefix) :]
# Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
if not getattr(state_dict[fqn], "_has_been_cloned", False):
try:
state_dict[fqn] = state_dict[fqn].clone().detach()
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
except BaseException as e:
warnings.warn(
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
"This may mean that this state_dict entry could point to invalid "
"memory regions after returning from state_dict() call if this "
"parameter is managed by FSDP. Please check clone "
f"implementation of {fqn}. Error: {str(e)}"
)
return _common_unshard_post_state_dict_hook(
module, fsdp_state, state_dict, prefix, param_hook
)
def _full_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> None:
_lazy_init(fsdp_state, module)
if _should_unshard_params(fsdp_state):
with SimpleProfiler.profile("_enter_unshard_params_ctx"):
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
# Add FSDP_PREFIX only for wrapper-based FSDP.
if not _is_composable(fsdp_state):
_replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
def _full_post_load_state_dict_hook(
module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
if _should_unshard_params(fsdp_state):
with SimpleProfiler.profile("_exit_unshard_params_ctx"):
_exit_unshard_params_ctx(module, fsdp_state)
def _local_pre_state_dict_hook(
fsdp_state: _FSDPState,
module: nn.Module,
*args,
**kwargs,
) -> None:
"""
    Hook that runs before model.state_dict() is called. Right now, a
    pre-state_dict hook is not supported by the PyTorch core, so this API is
    called from `_local_post_state_dict_hook()` to simulate the case.
"""
if (
_has_fsdp_params(fsdp_state, module)
and not _module_handle(fsdp_state, module).uses_sharded_strategy
):
raise RuntimeError(
"``local_state_dict`` can only be used when parameters are flatten "
"and sharded."
)
_common_pre_state_dict_hook(module, fsdp_state)
@no_type_check
def _local_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
"""
    This hook creates a ShardedTensor from the local flat_param and replaces
    state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
    happens; the underlying storage is the same.
"""
_replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
if not _has_fsdp_params(fsdp_state, module):
return state_dict
# state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
# value as the flat_param but it is a pure Tensor because
# nn.Module.state_dict() will detach the parameter. Therefore, we need
# to get flat_param to get the metadata.
assert _module_handle(fsdp_state, module), "Should have returned early"
flat_param = _module_handle(fsdp_state, module).flat_param
# Constructs a ShardedTensor from the flat_param "without" padding.
# Removing the padding allows users to change the number of ranks
# when loading the local_state_dict.
full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined]
shard_offset = flat_param.numel() * fsdp_state.rank
valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
if valid_data_size > 0:
        # If a FlatParameter is returned, FlatParameter._local_shard causes a
        # pickling issue (it can be saved with torch.save but not loaded with
        # torch.load). Since there is no benefit to returning the actual
        # FlatParameter class from state_dict, a view (a plain tensor) of the
        # FlatParameter is returned instead.
flat_param = flat_param[:valid_data_size].view(valid_data_size)
local_shards = [
Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
]
else:
local_shards = []
sharded_tensor = init_from_local_shards(
local_shards, full_numel, process_group=fsdp_state.process_group
) # type: ignore[assignment]
# TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
if fsdp_state._state_dict_config.offload_to_cpu:
sharded_tensor = sharded_tensor.cpu()
state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
return state_dict
def _local_post_load_state_dict_hook(
module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
pass
def _local_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> None:
"""
This hook finds the local flat_param for this FSDP module from the
state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is required.
"""
_lazy_init(fsdp_state, module)
_replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
if fqn not in state_dict:
assert not _has_fsdp_params(fsdp_state, module), (
"No `FlatParameter` in `state_dict` for this FSDP instance "
"but it has parameters"
)
return
load_tensor = state_dict[fqn]
assert isinstance(
load_tensor, ShardedTensor
), "Tensors in local_state_dict should be ShardedTensor."
# Convert the ShardedTensor to a Tensor.
flat_param = _module_handle(fsdp_state, module).flat_param
assert flat_param is not None
valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
shards = load_tensor.local_shards()
if valid_data_size > 0:
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
load_tensor = shards[0].tensor
# Get the metadata of the flat_param to decide whether to pad the loaded
# tensor.
if flat_param._shard_numel_padded > 0:
assert load_tensor.numel() < flat_param.numel(), (
f"Local shard size = {flat_param.numel()} and the tensor in "
f"the state_dict is {load_tensor.numel()}."
)
load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
else:
load_tensor = flat_param
# TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
state_dict[fqn] = load_tensor
def _sharded_pre_state_dict_hook(
fsdp_state: _FSDPState,
module: nn.Module,
*args,
**kwargs,
) -> None:
"""
    Hook that runs before model.state_dict() is called. See
    ``_full_pre_state_dict_hook`` for details.
"""
if (
_has_fsdp_params(fsdp_state, module)
and not _module_handle(fsdp_state, module).uses_sharded_strategy
):
raise RuntimeError(
"``sharded_state_dict`` can only be used when parameters are flatten "
"and sharded."
)
_common_pre_state_dict_hook(module, fsdp_state)
# Setting offload_to_cpu here does not work even if offload_to_cpu is True.
# We have to create ShardedTensor first then move it to CPU.
_common_unshard_pre_state_dict_hook(
module,
fsdp_state,
offload_to_cpu=False,
rank0_only=False,
)
@no_type_check
def _sharded_post_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> Dict[str, Any]:
"""
The hook replaces the unflattened, unsharded parameter in the state_dict
    with an unflattened, sharded parameter (a ShardedTensor).
"""
def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
param = state_dict[fqn]
if not fsdp_state._state_dict_config._use_dtensor:
sharded_tensor = _ext_chunk_tensor(
tensor=param,
rank=fsdp_state.rank,
world_size=fsdp_state.world_size,
num_devices_per_node=fsdp_state._device_handle.device_count(),
pg=fsdp_state.process_group,
fsdp_extension=fsdp_state._fsdp_extension,
)
else:
sharded_tensor = _ext_chunk_dtensor(
tensor=param,
rank=fsdp_state.rank,
device_mesh=fsdp_state._device_mesh,
fsdp_extension=fsdp_state._fsdp_extension,
)
if fsdp_state._state_dict_config.offload_to_cpu:
sharded_tensor = sharded_tensor.cpu()
state_dict[fqn] = sharded_tensor
return _common_unshard_post_state_dict_hook(
module, fsdp_state, state_dict, prefix, param_hook
)
@no_type_check
def _sharded_post_load_state_dict_hook(
module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
if _has_fsdp_params(fsdp_state, module):
with SimpleProfiler.profile("_exit_unshard_params_ctx"):
_exit_unshard_params_ctx(module, fsdp_state)
@no_type_check
def _sharded_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: _FSDPState,
state_dict: Dict[str, Any],
prefix: str,
) -> None:
"""
    The hook combines the unflattened, sharded parameters (ShardedTensors) into
    a new FlatParameter and shards the new FlatParameter to the local chunk.
"""
_lazy_init(fsdp_state, module)
if not _is_composable(fsdp_state):
_replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
if not _has_fsdp_params(fsdp_state, module):
return
handle = _module_handle(fsdp_state, module)
if not handle.uses_sharded_strategy:
raise RuntimeError(
"load_sharded_state_dict can only be called when parameters "
"are flattened and sharded."
)
fqn_to_param_ext = dict(
zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
)
for fqn, _, _ in _param_name_infos(module, fsdp_state):
if not _is_composable(fsdp_state):
fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}"
else:
fqn_from_global_root = f"{prefix}{fqn}"
try:
param = state_dict.pop(fqn_from_global_root)
except KeyError:
logger.warning(
f"Did not find param with FQN {fqn_from_global_root}, skipping it. " # noqa: G004
"The weight will not be filled if you expect it to be."
)
continue # TODO: Improve unittesting for state_dict finetuning
# cases: https://github.com/pytorch/pytorch/issues/109134
if not fsdp_state._state_dict_config._use_dtensor:
# All-gather the param (ShardedTensor)
param, shards = _ext_pre_load_state_dict_transform(
param, fsdp_state._fsdp_extension
)
assert len(shards) < 2, (
"Expects 0 or 1 shard per rank "
f"but got {len(shards)} shards on rank {fsdp_state.rank}."
)
param_numel = param.size().numel()
dim_0_size = param.size()[0]
chunk_size = (
math.ceil(dim_0_size / fsdp_state.world_size)
* param_numel
// dim_0_size
)
if len(shards) == 1:
local_tensor = shards[0].tensor.flatten()
with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
local_tensor = local_tensor.to(fsdp_state.compute_device)
num_padding = chunk_size - local_tensor.numel()
if num_padding > 0:
local_tensor = F.pad(local_tensor, [0, num_padding])
else:
local_tensor = torch.zeros(
chunk_size, dtype=param.dtype, device=fsdp_state.compute_device
)
tensor = torch.empty(
chunk_size * fsdp_state.world_size,
dtype=local_tensor.dtype,
device=fsdp_state.compute_device,
)
with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
dist.all_gather_into_tensor(
tensor, local_tensor, group=fsdp_state.process_group
)
tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
state_dict[fqn_from_global_root] = tensor
else:
if param.device != fsdp_state._device_mesh.device_type:
param = param.to(fsdp_state._device_mesh.device_type)
root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
local_tensor = _ext_all_gather_dtensor(
param, root_mesh, fsdp_state._fsdp_extension
)
if fqn_to_param_ext.get(fqn) is not None:
ext = fqn_to_param_ext[fqn]
local_tensor = _ext_post_unflatten_transform(
local_tensor, ext, fsdp_state._fsdp_extension
)
state_dict[fqn_from_global_root] = local_tensor
with SimpleProfiler.profile("_enter_unshard_params_ctx"):
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
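# Illustrative sketch (not used by FSDP itself): the per-rank chunk size
# computed in `_sharded_pre_load_state_dict_hook` above pads dim 0 up to a
# multiple of the world size before the all-gather. For example, a (5, 4)
# parameter with world_size=2 gives ceil(5 / 2) * 20 // 5 == 12 elements per
# rank, where rank 1 holds 8 real elements padded to 12.
def _example_chunk_size_arithmetic() -> int:
    param_numel, dim_0_size, world_size = 20, 5, 2
    chunk_size = math.ceil(dim_0_size / world_size) * param_numel // dim_0_size
    assert chunk_size == 12
    return chunk_size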
@contextlib.contextmanager
def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
old_state_dict_config = fsdp_state._state_dict_config
old_state_dict_type = fsdp_state._state_dict_type
fsdp_state._state_dict_config = FullStateDictConfig()
fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
    try:
        yield
    finally:
        # Restore the original state_dict config/type even if the body raises.
        fsdp_state._state_dict_config = old_state_dict_config
        fsdp_state._state_dict_type = old_state_dict_type
@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
"""
_post_state_dict_hook() is called after the state_dict() of this
FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
what postprocessing will be done.
"""
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned."
)
else:
context = contextlib.nullcontext()
with context:
_post_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
}
processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
module, fsdp_state, state_dict, prefix
)
if fsdp_state._is_root:
logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
for key, tensor in sorted(processed_state_dict.items()):
if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
local_shape = tensor.shape
if isinstance(tensor, ShardedTensor):
local_shape = None
shards = tensor.local_shards()
if shards:
local_shape = shards[0].tensor.shape
elif isinstance(tensor, DTensor):
local_shape = tensor.to_local().shape
logger.info(
"FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
key,
type(tensor),
tensor.shape,
local_shape,
tensor.dtype,
tensor.device,
)
return processed_state_dict
@no_type_check
@torch.no_grad()
def _pre_state_dict_hook(
module: nn.Module,
*args,
**kwargs,
) -> None:
"""
This is called before the core state dict saving logic of ``module``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
"""
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned."
)
else:
_set_use_dtensor(fsdp_state)
context = contextlib.nullcontext()
with context:
_pre_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
}
_pre_state_dict_hook_fn[fsdp_state._state_dict_type](
fsdp_state,
module,
*args,
**kwargs,
)
@no_type_check
def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
    # If device_mesh is passed in when initializing FSDP, we automatically set
    # the _use_dtensor flag to True for ShardedStateDictConfig().
if getattr(fsdp_state, "_device_mesh", None):
state_dict_type = fsdp_state._state_dict_type
if state_dict_type == StateDictType.LOCAL_STATE_DICT:
raise RuntimeError(
"Found state_dict_type LOCAL_STATE_DICT",
"DeviceMesh is not compatible with LOCAL_STATE_DICT.",
"Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
)
else:
fsdp_state._state_dict_config._use_dtensor = True
@no_type_check
@torch.no_grad()
def _pre_load_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> None:
"""
This is called before ``module._load_from_state_dict()``.
``fsdp_state._state_dict_type`` is used to decide what preprocessing will
be done.
"""
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned."
)
else:
_set_use_dtensor(fsdp_state)
context = contextlib.nullcontext()
_lazy_init(fsdp_state, module)
if fsdp_state._is_root:
SimpleProfiler.reset()
with context:
_pre_load_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
}
# Code that is common for all state_dict impls
if fsdp_state._device_handle.is_available():
fsdp_state._device_handle.synchronize()
# Dispatch into state_dict specific implementation of pre-hook.
_pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
module, fsdp_state, state_dict, prefix
)
@no_type_check
@torch.no_grad()
def _post_load_state_dict_hook(
module: nn.Module,
incompatible_keys: Tuple[List[str], List[str]],
*args: Any,
) -> None:
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
context = _replace_with_full_state_dict_type(fsdp_state)
warnings.warn(
"When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
"be returned."
)
else:
context = contextlib.nullcontext()
with context:
_post_load_state_dict_hook_fn = {
StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
}
# Code that is common for all state_dict impls
# Dispatch into state_dict type specific implementation of post-hook for
# loading state_dict.
_post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
# When reporting incompatible keys, trim FSDP prefixes.
missing_keys = incompatible_keys[0]
unexpected_keys = incompatible_keys[1]
for i in range(len(missing_keys)):
missing_keys[i] = clean_tensor_name(missing_keys[i])
for i in range(len(unexpected_keys)):
unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])
if fsdp_state._is_root:
SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")
def _register_all_state_dict_hooks(state: _FSDPState):
"""
Registers pre-save, post-save, pre-load, and post-load state dict hooks.
"""
for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
("_register_state_dict_hook", _post_state_dict_hook, {}),
(
"_register_load_state_dict_pre_hook",
_pre_load_state_dict_hook,
{"with_module": True},
),
("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
):
_register_state_dict_hooks_base(
state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
)
@no_type_check
def _register_state_dict_hooks_base(
state: _FSDPState,
hook_registration_fn_name: str,
hook: Callable,
hook_registration_fn_kwargs: Dict[str, Any],
) -> None:
"""Registers ``hook`` using ``hook_registration_fn``."""
if not _is_composable(state):
getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs)
else:
handle = state._handle
if handle:
getattr(handle._fully_sharded_module, hook_registration_fn_name)(
hook, **hook_registration_fn_kwargs
)
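# Illustrative sketch (not used by FSDP itself): the hooks registered above are
# normally exercised indirectly through the public `state_dict()` /
# `load_state_dict()` APIs under a `state_dict_type` context. This assumes
# `fsdp_model` is an already-constructed `FullyShardedDataParallel` instance.
def _example_sharded_state_dict_round_trip(fsdp_model: nn.Module) -> None:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        # Saving dispatches into the sharded pre/post state_dict hooks above.
        sharded_sd = fsdp_model.state_dict()
        # Loading dispatches into the sharded pre/post load hooks above.
        fsdp_model.load_state_dict(sharded_sd)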

View File

@ -0,0 +1,238 @@
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
import torch
import torch.nn as nn
@dataclass
class TracingConfig:
"""
This represents a symbolic tracing configuration.
Args:
tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
use for symbolic tracing. The default value is the native
:class:`torch.fx.Tracer` constructed with default arguments.
However, the user may want to pass a different value such as the
``HFTracer`` for models in the HuggingFace Transformers_ library.
.. _Transformers: https://huggingface.co/docs/transformers/index
concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
should not be treated as ``torch.fx.Proxy`` when tracing the
module ``forward()``. Passing ``concrete_args`` allows partially
specializing the forward, e.g. to remove control flow or data
structures. This ``concrete_args`` here is the same argument used
in :meth:`~torch.fx.Tracer.trace`.
"""
tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
concrete_args: Optional[Dict[str, Any]] = None
class _ParamUsageInfo(NamedTuple):
"""
This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
execution information. The ``dict`` maps modules to a list of these
``_ParamUsageInfo`` instances, where each instance represents a group of
parameters used together.
Specifically, for each module key in the ``dict``, each instance of this
class represents either:
(1) the module and some sublist of its ``named_parameters()`` used
together in execution (see ``_patched_create_proxy()``), or
(2) a submodule and all of ``submodule.named_parameters()`` (see
``_patched_call_module()``).
Type (1) corresponds to directly using parameters in ops without calling
``forward()``, and type (2) corresponds to calling ``forward()``. The
mapped-to lists in the ``dict`` follow the execution order.
"""
module: nn.Module
named_params: List[Tuple[str, nn.Parameter]]
class _ExecutionInfo:
"""
This represents the execution order information from the forward pass.
Attributes:
curr_module (nn.Module): Current module being traced.
module_forward_order (List[nn.Module]): The modules in (pre-)forward
order, i.e. the order in which their ``forward()`` methods are
called. Each call to a module's ``forward()`` corresponds to one
element in the list.
module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
Maps a module to a list of module execution infos. See
:class:`_ParamUsageInfo` for details.
param_forward_order (List[nn.Parameter]): The parameters in forward
execution order, where only a parameter's first participation is
included.
visited_params (Set[nn.Parameter]): The parameters visited so far
during the trace. This is only used during tracing for fast
membership check. Invariant: The parameters in
``param_forward_order`` are exactly those in ``visited_params``.
"""
def __init__(self, root_module: nn.Module) -> None:
self.curr_module: nn.Module = root_module
self.module_forward_order: List[nn.Module] = [root_module]
self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
root_module: []
}
self.param_forward_order: List[nn.Parameter] = []
self.visited_params: Set[nn.Parameter] = set()
class _ExecOrderTracer:
def __init__(self) -> None:
self.exec_info: Optional[_ExecutionInfo] = None
@contextmanager
def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
self.exec_info = _ExecutionInfo(root_module)
orig_call_module = tracer.call_module
orig_create_proxy = tracer.create_proxy
tracer.call_module = functools.partial( # type: ignore[method-assign]
self._patched_call_module, orig_call_module, self.exec_info
)
fqn_to_param = dict(root_module.named_parameters())
tracer.create_proxy = functools.partial( # type: ignore[method-assign]
self._patched_create_proxy,
orig_create_proxy,
self.exec_info,
fqn_to_param,
)
try:
yield
finally:
tracer.call_module = orig_call_module # type: ignore[method-assign]
tracer.create_proxy = orig_create_proxy # type: ignore[method-assign]
def _patched_call_module(
self,
call_module: Callable,
exec_info: _ExecutionInfo,
# Below are the expected arguments to `call_module()`
module: nn.Module,
forward: Callable,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
"""
Overrides ``call_module`` to save execution information to
``exec_info``. Note that ``call_module`` is called during symbolic
tracing for each non-root module.
Args:
call_module (Callable): Original ``call_module`` to override.
exec_info (_ExecutionInfo): Used to record execution information.
module (nn.Module): Module corresponding to this ``call_module``.
forward (Callable): ``forward()`` method of ``module`` to be called
for this ``call_module``.
args (Tuple[Any, ...]): Positional arguments for ``forward``.
kwargs (Dict[str, Any]): Keyword arguments for ``forward``.
Returns:
Same return value as ``call_module``.
"""
exec_info.module_forward_order.append(module)
named_params = list(module.named_parameters())
curr_module = exec_info.curr_module
if named_params:
assert (
curr_module in exec_info.module_to_param_usage_infos
), "The current module should have already been processed by a patched `call_module`"
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
_ParamUsageInfo(module, named_params)
)
prev_curr_module = curr_module
exec_info.curr_module = module
exec_info.module_to_param_usage_infos[module] = []
output = call_module(module, forward, args, kwargs)
exec_info.curr_module = prev_curr_module
return output
def _patched_create_proxy(
self,
create_proxy: Callable,
exec_info: _ExecutionInfo,
fqn_to_param: Dict[str, nn.Parameter],
# Below are the expected arguments to `create_proxy()`
kind: str,
target: torch.fx.node.Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
) -> torch.fx.Proxy:
"""
Overrides ``create_proxy`` to save execution information to
``exec_info``. Note that ``create_proxy`` is called during symbolic
tracing for each leaf function/method/module.
Args:
create_proxy (Callable): Original ``create_proxy`` to override.
exec_info (_ExecutionInfo): Used to record execution information.
fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
root module's ``named_parameters()`` with FQN as key and
parameter as value.
kind (str): Kind of the target method ('call_function',
'call_method', 'get_attr', 'call_module', 'placeholder', or
'output'). See :class:`torch.fx.Graph` for details. This is
passed to ``create_proxy``.
target (torch.fx.node.Target): Contains the string name of the
function/method/module. This is passed to ``create_proxy``.
args (Tuple[Any, ...]): Positional arguments for the function/
method/module. This is passed to ``create_proxy``.
kwargs (Dict[str, Any]): Keyword arguments for the function/method/
module. This is passed to ``create_proxy``
name (Optional[str]): An optional string name for the ``Node``
created in ``create_proxy``. This is passed to
``create_proxy``.
type_expr (Optional[Any]): An optional type annotation representing
the Python type that the output of the node has. This is passed
to ``create_proxy``.
proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
An alternative proxy constructor used in ``create_proxy``. This
is passed to ``create_proxy``.
Returns:
torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
"""
proxy = create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
)
curr_module = exec_info.curr_module
if kind in ("call_function", "call_method"):
if args is not None:
named_params: List[Tuple[str, nn.Parameter]] = []
for arg in args:
if (
isinstance(arg, torch.fx.Proxy)
and arg.node.target in fqn_to_param
):
param = fqn_to_param[arg.node.target] # type: ignore[index]
named_params.append((arg.node.target, param)) # type: ignore[arg-type]
if param not in exec_info.visited_params:
exec_info.visited_params.add(param)
exec_info.param_forward_order.append(param)
if named_params:
exec_info.module_to_param_usage_infos[curr_module].append(
_ParamUsageInfo(curr_module, named_params)
)
elif kind == "call_module":
named_params = list(curr_module.named_parameters())
if named_params:
exec_info.module_to_param_usage_infos[curr_module].append(
_ParamUsageInfo(curr_module, named_params)
)
for _, param in named_params:
if param not in exec_info.visited_params:
exec_info.visited_params.add(param)
exec_info.param_forward_order.append(param)
return proxy
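# Illustrative sketch (not used by FSDP itself): how the patched tracer is
# intended to be driven to collect execution-order information. This assumes
# `model` is symbolically traceable by the default `torch.fx.Tracer`.
def _example_collect_exec_info(model: nn.Module) -> _ExecutionInfo:
    tracer = torch.fx.Tracer()
    exec_order_tracer = _ExecOrderTracer()
    with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model):
        # Tracing under the patched tracer records module and parameter order.
        tracer.trace(model)
    assert exec_order_tracer.exec_info is not None
    return exec_order_tracer.exec_info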

View File

@ -0,0 +1,113 @@
"""
NOTE: This file must be imported like
``import torch.distributed.fsdp._traversal_utils`` and not like
``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular
imports. For brevity, we may import the file as ``traversal_utils``.
"""
import collections
from typing import Deque, List, Set, Tuple
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state
"""
[Note: FSDP State Traversal]
For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel``
module wrapping a fully sharded module, and for the non-wrapper code path,
``_FSDPState`` is an object that gets embedded on a fully sharded module.
See [Note: Fully Sharded Module] for the definition.
There are three common traversal idioms: Given a root module,
- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree.
- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the
tree (i.e. those with ``_is_root == True``).
- ``_get_fsdp_handles()`` returns all ``FlatParamHandle`` s in the tree.
All of these methods must take in the root module (i.e. an ``nn.Module``) and
not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph
traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal.
"""
def _composable(module: nn.Module) -> bool:
"""
    Returns whether ``module`` can compose with ``fully_shard``.
"""
# TODO: Add any other composable APIs that are mutually exclusive.
registry = _get_registry(module)
if registry is None:
return True
return "replicate" not in registry
# TODO (awgu): We may be able to remove this function if we retire the
# `use_orig_params=False` code path since so far we only need the module for
# `FlatParameter` registration, which is not needed for `use_orig_params=True`.
def _get_fsdp_states_with_modules(
module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
"""
Returns a tuple containing:
1. A list of the ``_FSDPState`` instances in the module tree rooted at
``module`` without any duplicates and following the ``module.modules()``
traversal order (which is assumed to be depth-first).
2. A corresponding list of the modules owning the states in the first list.
For the wrapper code path, both returned lists are the same, each
containing all ``FullyShardedDataParallel`` instances. For the composable
code path, this returns a list of all composable state instances and a list
of the corresponding fully sharded modules. See [Note: Fully Sharded
Module].
NOTE: The traversal does not proceed into any module annotated by an
incompatible API (e.g. ``replicate``).
"""
fsdp_states: List[_FSDPState] = []
fsdp_modules: List[nn.Module] = []
# Track the visited FSDP states since multiple modules may share the same
# one and we want to return a de-duplicated list
visited_fsdp_states: Set[_FSDPState] = set()
# Track the visited modules in case of shared modules, which implies the
# module graph is no longer a tree
visited_modules: Set[nn.Module] = set()
# Perform depth-first search from `module` to ensure that we do not
# traverse into an incompatible API's subtree (use DFS instead of BFS to
# match `.modules()` order)
deque: Deque[nn.Module] = collections.deque([module])
while deque:
submodule = deque.popleft()
visited_modules.add(submodule)
if not _composable(submodule):
continue
for child_module in reversed(list(submodule.children())):
if child_module not in visited_modules:
deque.appendleft(child_module)
optional_state = _get_module_fsdp_state(submodule)
if optional_state is not None and optional_state not in visited_fsdp_states:
visited_fsdp_states.add(optional_state)
fsdp_states.append(optional_state)
fsdp_modules.append(submodule)
return fsdp_states, fsdp_modules
def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
"""See :func:`_get_fsdp_states_with_modules`."""
fsdp_states, _ = _get_fsdp_states_with_modules(module)
return fsdp_states
def _get_fsdp_handles(module: nn.Module) -> List:
"""
Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
    following the rules in :func:`_get_fsdp_states`.
"""
handles = [
fsdp_state._handle
for fsdp_state in _get_fsdp_states(module)
if fsdp_state._handle is not None
]
return handles
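# Illustrative sketch (not used by FSDP itself): a small helper built on the
# traversal idioms above. It assumes FSDP (wrapper or composable) has already
# been applied somewhere in the tree rooted at `root_module`.
def _example_summarize_fsdp_tree(root_module: nn.Module) -> Tuple[int, int]:
    num_states = len(_get_fsdp_states(root_module))
    num_handles = len(_get_fsdp_handles(root_module))
    return num_states, num_handles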

View File

@ -0,0 +1,336 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from typing import cast, Generator
import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
_FSDPState,
_get_module_fsdp_state,
_has_fsdp_params,
_module_handle,
HandleTrainingState,
TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
_lazy_init,
_reset_flat_param_grad_info_if_needed,
_reshard,
_reshard_grads,
_unshard,
_unshard_grads,
)
from torch.distributed.utils import _p_assert
from ._flat_param import FlatParamHandle
FLAT_PARAM = "_flat_param"
@torch.no_grad()
def _writeback_to_local_shard(
handle: FlatParamHandle,
writeback_grad: bool,
):
"""
    For the handle, writes back this rank's shard of the unsharded
flattened parameter to the sharded flattened parameter. If
``writeback_grad=True``, then writes back to the sharded gradient as
well.
Precondition: The handle's ``FlatParameter`` 's data points to the
padded unsharded flattened parameter.
"""
def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
if handle.uses_sharded_strategy:
# For sharded strategies, get the *unpadded* shard instead of
# the *padded* shard to persist user changes to the padding
# (though FSDP does not explicitly support this)
shard, _ = FlatParamHandle._get_unpadded_shard(
flat_param_or_grad,
handle.rank,
handle.world_size,
)
return shard
# For `NO_SHARD`, the `flat_param` or its gradient may be modified,
# so we write it back directly
return flat_param_or_grad
param_shard = _get_shard(handle.flat_param)
handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined]
if writeback_grad:
existing_grad = handle.sharded_grad
if existing_grad is not None:
assert handle.flat_param.grad is not None
grad_shard = _get_shard(handle.flat_param.grad)
existing_grad[: grad_shard.numel()].copy_(grad_shard)
def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
"""
De-registers the flattened parameter from the wrapped module, hiding it
from ``nn.Module`` methods.
We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
attribute but dynamically change whether it is visible to ``nn.Module``
methods.
"""
if _has_fsdp_params(state, module):
# TODO: figure out the case for the composable APIs.
cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)
def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
"""
Registers the flattened parameter to the wrapped module, making it
visible to ``nn.Module`` methods.
We do not use :meth:`nn.Module.register_parameter` because we want
``FLAT_PARAM`` to always be an attribute but dynamically change whether
it is visible to ``nn.Module`` methods.
"""
handle = _module_handle(state, module)
if _has_fsdp_params(state, module):
# TODO: figure out the case for the composable APIs.
cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param
@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
"""
Assumes that the flattened parameter is unsharded. When in the context,
de-registers the flattened parameter and unflattens the original
parameters as ``nn.Parameter`` views into the flattened parameter.
After the context, re-registers the flattened parameter and restores
the original parameters as ``Tensor`` views into the flattened
parameter.
"""
handle = _module_handle(state, module)
if not handle:
yield
else:
_deregister_flat_param(state, module)
try:
with handle.unflatten_as_params():
yield
finally:
if not handle._use_orig_params:
_register_flat_param(state, module)
def _validate_unshard_params_args(
state: _FSDPState,
writeback: bool,
rank0_only: bool,
offload_to_cpu: bool,
with_grads: bool,
) -> None:
if with_grads and (offload_to_cpu or not state._use_orig_params):
raise NotImplementedError(
f"with_grads={with_grads}, "
f"use_orig_params={state._use_orig_params}, "
f"offload_to_cpu={offload_to_cpu} "
f"is not supported yet"
)
if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
raise NotImplementedError(
"offload_to_cpu=True and NO_SHARD is not supported yet"
)
if writeback and rank0_only:
# TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
# persist the changes.
raise NotImplementedError(
"writeback=True and rank0_only=True is not supported yet"
)
if offload_to_cpu and not rank0_only:
warnings.warn(
"offload_to_cpu=True and rank0_only=False may result in the"
"unsharded parameters being redundantly copied to CPU memory for "
"GPUs sharing the same CPU memory, which risks CPU OOM. We "
"recommend using offload_to_cpu=True with rank0_only=True."
)
@contextlib.contextmanager
def _unshard_fsdp_state_params(
module: nn.Module,
state: _FSDPState,
writeback: bool,
rank0_only: bool,
offload_to_cpu: bool,
with_grads: bool,
):
"""
This unshards the parameters for a single FSDP state ``state`` that
corresponds to ``module``.
"""
_validate_unshard_params_args(
state, writeback, rank0_only, offload_to_cpu, with_grads
)
state._device_handle.synchronize()
# If handles are shared by other module(s), the handle may be already unsharded.
maybe_handle = _module_handle(state, module)
handle = None
if (
maybe_handle
and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
):
handle = maybe_handle
if not handle:
yield
return
assert (
handle._training_state == HandleTrainingState.IDLE
), f"Expects the handle training to be IDLE but got {handle._training_state}"
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
_reset_flat_param_grad_info_if_needed(handle)
free_unsharded_flat_param = handle.needs_unshard()
# No need to call `wait_stream()` since we unshard in the computation
# stream directly
computation_stream = state._device_handle.current_stream()
_unshard(state, handle, computation_stream, computation_stream)
if with_grads:
_unshard_grads(handle)
if rank0_only and state.rank != 0:
# Free the unsharded flattened parameter early
_reshard(state, handle, free_unsharded_flat_param)
if with_grads:
_reshard_grads(handle)
try:
yield
finally:
handle._training_state = HandleTrainingState.IDLE
else:
# Unflatten the unsharded flattened parameters
with contextlib.ExitStack() as stack:
# Invariant: rank == 0 or !rank0_only
if offload_to_cpu and handle.uses_sharded_strategy:
stack.enter_context(handle.to_cpu())
# NOTE: Since PyTorch enforces that a parameter and its
# gradients need to match metadata (e.g. device), we must
# move gradients to CPU *after* we move parameters.
# NOTE: This assumes 1 `FlatParameter`
if not state._use_orig_params:
stack.enter_context(_unflatten_as_params(state, module))
try:
yield
finally:
stack.close()
if writeback:
_writeback_to_local_shard(handle, with_grads)
_reshard(state, handle, free_unsharded_flat_param)
if with_grads:
_reshard_grads(handle)
handle._training_state = HandleTrainingState.IDLE
@contextlib.contextmanager
def _unshard_params_for_summon(
module: nn.Module,
state: _FSDPState,
writeback: bool,
rank0_only: bool,
offload_to_cpu: bool,
with_grads: bool,
):
_validate_unshard_params_args(
state, writeback, rank0_only, offload_to_cpu, with_grads
)
_lazy_init(state, module)
if state.training_state == TrainingState.FORWARD_BACKWARD:
raise AssertionError(
"Cannot manually unshard parameters during forward/backward"
)
elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
raise AssertionError(
"Cannot manually unshard parameters when already unsharding parameters"
)
with _unshard_fsdp_state_params(
module=module,
state=state,
writeback=writeback,
rank0_only=rank0_only,
offload_to_cpu=offload_to_cpu,
with_grads=with_grads,
):
try:
state.training_state = TrainingState.SUMMON_FULL_PARAMS
yield
finally:
state.training_state = TrainingState.IDLE
@contextlib.contextmanager
def _unshard_params(
module: nn.Module,
recurse: bool,
writeback: bool,
rank0_only: bool,
offload_to_cpu: bool,
with_grads: bool,
):
"""
This unshards FSDP-managed parameters for all modules with FSDP applied in
the module tree rooted at ``module``.
"""
if not recurse:
optional_state = _get_module_fsdp_state(module)
if optional_state is None:
with contextlib.nullcontext():
yield
return
states_and_modules = ([optional_state], [module])
else:
states_and_modules = traversal_utils._get_fsdp_states_with_modules(module)
with contextlib.ExitStack() as stack:
for state, module in zip(*states_and_modules):
stack.enter_context(
_unshard_params_for_summon(
module=module,
state=state,
writeback=writeback,
rank0_only=rank0_only,
offload_to_cpu=offload_to_cpu,
with_grads=with_grads,
)
)
yield
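# Illustrative sketch (not used by FSDP itself): read-only inspection of the
# unsharded parameters via `_unshard_params`. The public entry point for users
# is `FullyShardedDataParallel.summon_full_params`; this assumes FSDP has been
# applied to `model` and that we are outside forward/backward.
def _example_inspect_unsharded_params(model: nn.Module) -> None:
    with _unshard_params(
        module=model,
        recurse=True,
        writeback=False,
        rank0_only=False,
        offload_to_cpu=False,
        with_grads=False,
    ):
        for name, param in model.named_parameters():
            # Parameters appear with their original (unsharded) shapes here.
            print(name, tuple(param.shape))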
def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
"""
Deregisters the original parameters; registers the ``FlatParameter``.
"""
handle = _module_handle(state, module)
if not handle:
return
_p_assert(
handle._use_orig_params,
f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
f"handle: {handle._use_orig_params}",
)
handle._deregister_orig_params()
_register_flat_param(state, module)
def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
"""
Deregisters the ``FlatParameter``; registers the original parameters.
"""
handle = _module_handle(state, module)
if not handle:
return
_deregister_flat_param(state, module)
if handle.is_sharded(handle.flat_param):
handle._use_sharded_views()
handle._use_sharded_grad_views()
else:
handle._use_unsharded_views(as_params=True)

View File

@ -0,0 +1,262 @@
# mypy: allow-untyped-defs
import collections
import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
_get_module_fsdp_state,
_override_module_mixed_precision,
)
from torch.distributed.fsdp.wrap import (
_construct_wrap_fn,
_or_policy,
_Policy,
_post_order_apply,
_recursive_wrap,
_run_mixed_precision_override_policy,
_wrap_module_cls_individually,
)
def _auto_wrap(
root_module: nn.Module,
policy: Union[Callable, _Policy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
root_kwargs: Dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
"""
Auto wraps modules in ``root_module`` 's tree according to ``policy``
following a post-order traversal.
Precondition: ``root_kwargs`` should contain all arguments except
``module``. This function accepts the kwargs dict directly since it gets
forwarded into the post-order traversal function.
"""
mixed_precision = root_kwargs["mixed_precision"]
is_wrapper = inspect.isclass(fsdp_fn)
# TODO: We may relax this no-nested-wrapping constraint to support manual
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)
if isinstance(policy, _Policy):
root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, root_kwargs
)
if mixed_precision is not None:
target_module_to_kwargs = _run_mixed_precision_override_policy(
root_module,
mixed_precision._module_classes_to_ignore,
ignored_modules,
root_kwargs,
target_module_to_kwargs,
)
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
_warn_on_overridden_mixed_precision(overridden_module_classes)
use_orig_params = root_kwargs.get("use_orig_params", False)
_validate_frozen_params(
root_module,
set(target_module_to_kwargs.keys()),
ignored_params,
use_orig_params,
)
wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
_post_order_apply(root_module, wrap_fn)
return
recursive_wrap_kwargs = {
"module": root_module,
"auto_wrap_policy": policy,
"wrapper_cls": fsdp_fn,
"ignored_modules": ignored_modules,
"ignored_params": ignored_params,
"only_wrap_children": True,
}
if mixed_precision is not None:
# Wrap modules of the ignored types separately and register forward
# hooks to cast to fp32 and back to the original dtype, respectively
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
policy = functools.partial(
_or_policy,
policies=[
policy,
partial(
_wrap_module_cls_individually,
module_classes=mixed_precision._module_classes_to_ignore,
),
],
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
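# Illustrative sketch (not used by FSDP itself): users reach `_auto_wrap`
# indirectly by passing a policy to the FSDP constructor. This assumes a
# distributed process group has already been initialized; `nn.Linear` stands in
# for real submodules worth wrapping.
def _example_auto_wrap_with_policy() -> None:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    # Each `nn.Linear` becomes its own FSDP instance via the policy.
    FSDP(model, auto_wrap_policy=ModuleWrapPolicy({nn.Linear}))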
def _check_nested_wrapping(root_module: nn.Module):
for module_name, module in root_module.named_modules():
if _get_module_fsdp_state(module) is not None:
raise ValueError(
"FSDP auto wrapping requires modules to not already have "
f"FSDP applied but found {module_name} in\n{root_module}"
)
def _warn_on_overridden_mixed_precision(
overridden_module_classes: Set[Type[nn.Module]],
):
if len(overridden_module_classes) == 0:
return
warnings.warn(
"Both mixed precision and an auto_wrap_policy were specified to FSDP, "
f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
"These modules will be wrapped as separate FSDP instacnes with mixed "
"precision disabled."
)
def _validate_frozen_params(
root_module: nn.Module,
modules_to_wrap: Set[nn.Module],
ignored_params: Set[nn.Parameter],
use_orig_params: bool,
):
"""
This checks that, given ``modules_to_wrap``, each module would manage
parameters that are uniformly frozen or non-frozen. This uniformity
requirement is strict for ``use_orig_params=False`` (hard error) and highly
recommended for ``use_orig_params=True`` (user warning).
"""
post_order_named_modules = _get_post_order_named_modules(root_module)
visited_modules: Set[nn.Module] = set()
for module_name, module in post_order_named_modules:
if module in modules_to_wrap:
param_to_fqn = _get_managed_param_to_fqn(
module, ignored_params, visited_modules, module_name
)
frozen_param_fqns: List[str] = []
frozen_param_numel = 0
nonfrozen_param_fqns: List[str] = []
nonfrozen_param_numel = 0
for param, fqn in param_to_fqn.items():
if param.requires_grad:
nonfrozen_param_fqns.append(fqn)
nonfrozen_param_numel += param.numel()
else:
frozen_param_fqns.append(fqn)
frozen_param_numel += param.numel()
if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
msg = f"{module_name} has both parameters with requires_grad=True and False."
if use_orig_params:
total_param_numel = frozen_param_numel + nonfrozen_param_numel
msg += (
" We do not recommend wrapping such modules since "
"the gradient memory usage will be higher than expected "
f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
"before sharding via reduce-scatter). "
)
else:
msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
msg += (
f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
)
if use_orig_params:
warnings.warn(msg)
else:
raise ValueError(msg)
def _get_post_order_named_modules(
root_module: nn.Module,
) -> List[Tuple[str, nn.Module]]:
"""
This returns the named modules following a post-order traversal, which is a
valid reverse topological sort. We achieve this using the reverse of a
stack-based DFS order instead of reversing ``root_module.named_modules()``
since the former gives the modules in registration order at each level in
the module tree (as opposed to the reverse), which allows us to error/warn
on the first registered module that violates the condition.
For example, consider the following module structure:
M(
S1(),
S2(
SS1(),
SS2(),
),
S3(),
)
The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
"""
visited_modules = {root_module}
stack = [("", root_module)]
# Append and reverse at the end for linear-time algorithm
reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
while stack:
module_name, module = stack.pop()
reverse_post_order_named_modules.append((module_name, module))
for child_module_name, child_module in module.named_children():
if child_module is None: # only for overrides of `named_children()`
continue
if child_module not in visited_modules:
visited_modules.add(child_module)
if module_name != "":
child_module_name = module_name + "." + child_module_name
stack.append((child_module_name, child_module))
post_order_named_modules = list(reversed(reverse_post_order_named_modules))
return post_order_named_modules
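# Illustrative sketch (not used by FSDP itself): reproduce the docstring
# example above with stand-in `nn.Sequential` containers; the module names are
# hypothetical placeholders.
def _example_post_order_names() -> List[str]:
    m = nn.Sequential()
    m.add_module("S1", nn.Linear(2, 2))
    s2 = nn.Sequential()
    s2.add_module("SS1", nn.Linear(2, 2))
    s2.add_module("SS2", nn.Linear(2, 2))
    m.add_module("S2", s2)
    m.add_module("S3", nn.Linear(2, 2))
    # Expected order: ["S1", "S2.SS1", "S2.SS2", "S2", "S3", ""] (root is "").
    return [name for name, _ in _get_post_order_named_modules(m)]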
def _get_managed_param_to_fqn(
module_to_wrap: nn.Module,
ignored_params: Set[nn.Parameter],
visited_modules: Set[nn.Module],
root_prefix: str,
) -> Dict[nn.Parameter, str]:
"""
This returns a dict that maps managed parameter to its FQN for the given
``module_to_wrap``. The dict's keys are exactly the parameters that would
be managed by the module, where this is achieved by calling this function
on the modules to wrap in reverse topological order, destructively updating
``visited_modules``, and not traversing into those modules. The FQNs are
prefixed from the root (via ``root_prefix``) to be more informative.
NOTE: This function is meant to be called pre-wrapping and iteratively in
reverse topological order to cover the full module tree. This differs from
the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
on the full module tree in one shot. Given those differences, we do not try
to unify the two.
"""
param_to_fqn: Dict[nn.Parameter, str] = {}
# Run BFS (or any tree traversal works)
queue = collections.deque([(module_to_wrap, root_prefix)])
visited_modules.add(module_to_wrap)
while queue:
module, prefix = queue.popleft()
for param_name, param in module.named_parameters(recurse=False):
if param not in ignored_params:
fqn = param_name if prefix == "" else prefix + "." + param_name
param_to_fqn[param] = fqn
for child_module_name, child_module in module.named_children():
if child_module is None: # only for overrides of `named_children()`
continue
if child_module not in visited_modules:
visited_modules.add(child_module)
child_prefix = (
child_module_name
if prefix == ""
else prefix + "." + child_module_name
)
queue.append((child_module, child_prefix))
return param_to_fqn

View File

@ -0,0 +1,410 @@
"""
This file includes public APIs for FSDP such as the classes used for the
constructor arguments.
"""
from dataclasses import dataclass
from enum import auto, Enum
from typing import Optional, Sequence, Type
import torch
from torch.nn.modules.batchnorm import _BatchNorm
__all__ = [
"ShardingStrategy",
"BackwardPrefetch",
"MixedPrecision",
"CPUOffload",
"StateDictType",
"StateDictConfig",
"FullStateDictConfig",
"LocalStateDictConfig",
"ShardedStateDictConfig",
"OptimStateDictConfig",
"FullOptimStateDictConfig",
"LocalOptimStateDictConfig",
"ShardedOptimStateDictConfig",
"StateDictSettings",
]
class ShardingStrategy(Enum):
"""
This specifies the sharding strategy to be used for distributed training by
:class:`FullyShardedDataParallel`.
- ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
For the parameters, this strategy unshards (via all-gather) before the
forward, reshards after the forward, unshards before the backward
computation, and reshards after the backward computation. For gradients,
it synchronizes and shards them (via reduce-scatter) after the backward
computation. The sharded optimizer states are updated locally per rank.
- ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
computation, and additionally, parameters are sharded outside
computation. For the parameters, this strategy unshards before the
forward, does not reshard them after the forward, and only reshards them
after the backward computation. The sharded optimizer states are updated
locally per rank. Inside ``no_sync()``, the parameters are not resharded
after the backward computation.
- ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
but instead replicated across ranks similar to PyTorch's
:class:`DistributedDataParallel` API. For gradients, this strategy
synchronizes them (via all-reduce) after the backward computation. The
unsharded optimizer states are updated locally per rank.
- ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
nodes. This results in reduced communication volume as expensive all-gathers and
      reduce-scatters are only done within a node, which can be more performant
      for medium-sized models.
- ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
since the unsharded parameters are not freed after the forward pass, saving the
all-gathers in the pre-backward.
"""
FULL_SHARD = auto()
SHARD_GRAD_OP = auto()
NO_SHARD = auto()
HYBRID_SHARD = auto()
_HYBRID_SHARD_ZERO2 = auto()
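# Illustrative sketch (not part of the public API): the strategy is passed to
# the FSDP constructor. This assumes a distributed process group has already
# been initialized; `nn.Linear` stands in for a real model.
def _example_sharding_strategy_usage() -> None:
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    FSDP(nn.Linear(8, 8), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)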
class BackwardPrefetch(Enum):
"""
This configures explicit backward prefetching, which improves throughput by
enabling communication and computation overlap in the backward pass at the
cost of slightly increased memory usage.
- ``BACKWARD_PRE``: This enables the most overlap but increases memory
usage the most. This prefetches the next set of parameters *before* the
current set of parameters' gradient computation. This overlaps the *next
all-gather* and the *current gradient computation*, and at the peak, it
holds the current set of parameters, next set of parameters, and current
set of gradients in memory.
- ``BACKWARD_POST``: This enables less overlap but requires less memory
usage. This prefetches the next set of parameters *after* the current
set of parameters' gradient computation. This overlaps the *current
reduce-scatter* and the *next gradient computation*, and it frees the
current set of parameters before allocating memory for the next set of
parameters, only holding the next set of parameters and current set of
gradients in memory at the peak.
- FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
the backward prefetching altogether. This has no overlap and does not
increase memory usage. In general, we do not recommend this setting since
it may degrade throughput significantly.
For more technical context: For a single process group using NCCL backend,
any collectives, even if issued from different streams, contend for the
same per-device NCCL stream, which implies that the relative order in which
the collectives are issued matters for overlapping. The two backward
prefetching values correspond to different issue orders.
"""
# NOTE: For both modes, the ordering that defines "current" and "next" is
# not always exact in the current implementation. A mistargeted prefetch
# simply means that the parameter memory is allocated earlier than needed,
# possibly increasing peak memory usage, but does not affect correctness.
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
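# Illustrative sketch (not part of the public API): selecting a backward
# prefetching mode, or disabling prefetching with ``None``. Same assumptions as
# the sketch above.
def _example_backward_prefetch_usage() -> None:
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # BACKWARD_PRE maximizes overlap at the cost of higher peak memory.
    FSDP(nn.Linear(8, 8), backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
    # Passing ``None`` disables backward prefetching entirely.
    FSDP(nn.Linear(8, 8), backward_prefetch=None)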
@dataclass
class MixedPrecision:
"""
This configures FSDP-native mixed precision training.
Attributes:
param_dtype (Optional[torch.dtype]): This specifies the dtype for model
parameters during forward and backward and thus the dtype for
forward and backward computation. Outside forward and backward, the
*sharded* parameters are kept in full precision (e.g. for the
optimizer step), and for model checkpointing, the parameters are
always saved in full precision. (Default: ``None``)
reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
gradient reduction (i.e. reduce-scatter or all-reduce). If this is
``None`` but ``param_dtype`` is not ``None``, then this takes on
the ``param_dtype`` value, still running gradient reduction in low
precision. This is permitted to differ from ``param_dtype``, e.g.
to force gradient reduction to run in full precision. (Default:
``None``)
buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
buffers. FSDP does not shard buffers. Rather, FSDP casts them to
``buffer_dtype`` in the first forward pass and keeps them in that
dtype thereafter. For model checkpointing, the buffers are saved
in full precision except for ``LOCAL_STATE_DICT``. (Default:
``None``)
keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
gradients to full precision after the backward pass in preparation
for the optimizer step. If ``True``, then FSDP keeps the gradients
in the dtype used for gradient reduction, which can save memory if
using a custom optimizer that supports running in low precision.
(Default: ``False``)
cast_forward_inputs (bool): If ``True``, then this FSDP module casts
its forward args and kwargs to ``param_dtype``. This is to ensure
that parameter and input dtypes match for forward computation, as
required by many ops. This may need to be set to ``True`` when only
applying mixed precision to some but not all FSDP modules, in which
case a mixed-precision FSDP submodule needs to recast its inputs.
(Default: ``False``)
cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
casts its forward args and kwargs to ``param_dtype``, overriding
the value of ``cast_forward_inputs``. For non-root FSDP modules,
this does not do anything. (Default: ``True``)
_module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
module classes to ignore for mixed precision when using an
``auto_wrap_policy``: Modules of these classes will have FSDP
applied to them separately with mixed precision disabled (meaning
that the final FSDP construction would deviate from the specified
policy). If ``auto_wrap_policy`` is not specified, then this does
not do anything. This API is experimental and subject to change.
(Default: ``(_BatchNorm,)``)
.. note:: This API is experimental and subject to change.
.. note:: Only floating point tensors are cast to their specified dtypes.
.. note:: In ``summon_full_params``, parameters are forced to full
precision, but buffers are not.
.. note:: Layer norm and batch norm accumulate in ``float32`` even when
their inputs are in a low precision like ``float16`` or ``bfloat16``.
Disabling FSDP's mixed precision for those norm modules only means that
the affine parameters are kept in ``float32``. However, this incurs
separate all-gathers and reduce-scatters for those norm modules, which
may be inefficient, so if the workload permits, the user should prefer
to still apply mixed precision to those modules.
.. note:: By default, if the user passes a model with any ``_BatchNorm``
modules and specifies an ``auto_wrap_policy``, then the batch norm
modules will have FSDP applied to them separately with mixed precision
disabled. See the ``_module_classes_to_ignore`` argument.
.. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
``cast_forward_inputs=False`` by default. For the root FSDP instance,
its ``cast_root_forward_inputs`` takes precedence over its
``cast_forward_inputs``. For non-root FSDP instances, their
``cast_root_forward_inputs`` values are ignored. The default setting is
sufficient for the typical case where each FSDP instance has the same
``MixedPrecision`` configuration and only needs to cast inputs to the
``param_dtype`` at the beginning of the model's forward pass.
.. note:: For nested FSDP instances with different ``MixedPrecision``
configurations, we recommend setting individual ``cast_forward_inputs``
values to configure casting inputs or not before each instance's
forward. In such a case, since the casts happen before each FSDP
instance's forward, a parent FSDP instance should have its non-FSDP
submodules run before its FSDP submodules to avoid the activation dtype
being changed due to a different ``MixedPrecision`` configuration.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>> model[1],
>>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>> model,
>>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )
The above shows a working example. On the other hand, if ``model[1]``
were replaced with ``model[0]``, meaning that the submodule using
different ``MixedPrecision`` ran its forward first, then ``model[1]``
would incorrectly see ``float16`` activations instead of ``bfloat16``
ones.
"""
param_dtype: Optional[torch.dtype] = None
reduce_dtype: Optional[torch.dtype] = None
buffer_dtype: Optional[torch.dtype] = None
keep_low_precision_grads: bool = False
cast_forward_inputs: bool = False
cast_root_forward_inputs: bool = True
_module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
@dataclass
class CPUOffload:
"""
This configures CPU offloading.
Attributes:
offload_params (bool): This specifies whether to offload parameters to
CPU when not involved in computation. If ``True``, then this
offloads gradients to CPU as well, meaning that the optimizer step
runs on CPU.
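Example (an illustrative sketch; ``model`` is assumed to be an existing
``nn.Module``)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import CPUOffload
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> # `model` is assumed to be defined elsewhere
>>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))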
"""
offload_params: bool = False
class StateDictType(Enum):
"""
This enum indicates which type of ``state_dict`` the FSDP module is
currently processing (returning or loading).
The default value is FULL_STATE_DICT to comply with the PyTorch convention.
.. note::
FSDP currently supports three types of ``state_dict``:
1. ``state_dict/load_state_dict``: this pair of APIs return and load
the non-sharded, unflattened parameters. The semantics are the
same as using DDP.
2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
and load local sharded, flattened parameters. The values returned
by ``_local_state_dict`` can be directly used by FSDP and are only
meaningful to FSDP (because parameters are flattened). Note that
these APIs are meant for use via the :func:`state_dict_type`
context manager as follows:
>>> # xdoctest: +SKIP("undefined variables")
>>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
... state = fsdp.state_dict() # returns the local state dict
3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
return and load sharded, unflattened parameters. The ``state_dict``
returned by ``_sharded_state_dict`` can be used by all other parallel
schemes (resharding may be required).
"""
FULL_STATE_DICT = auto()
LOCAL_STATE_DICT = auto()
SHARDED_STATE_DICT = auto()
@dataclass
class StateDictConfig:
"""
``StateDictConfig`` is the base class for all ``state_dict`` configuration
classes. Users should instantiate a child class (e.g.
``FullStateDictConfig``) in order to configure settings for the
corresponding ``state_dict`` type supported by FSDP.
Attributes:
offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
values to CPU, and if ``False``, then FSDP keeps them on GPU.
(Default: ``False``)
"""
offload_to_cpu: bool = False
@dataclass
class FullStateDictConfig(StateDictConfig):
"""
``FullStateDictConfig`` is a config class meant to be used with
``StateDictType.FULL_STATE_DICT``. We recommend enabling both
``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
dicts to save GPU memory and CPU memory, respectively. This config class
is meant to be used via the :func:`state_dict_type` context manager as
follows:
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>> state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>> state_dict = torch.load("my_checkpoint.pt")
>>> model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
Attributes:
rank0_only (bool): If ``True``, then only rank 0 saves the full state
dict, and nonzero ranks save an empty dict. If ``False``, then all
ranks save the full state dict. (Default: ``False``)
"""
rank0_only: bool = False
@dataclass
class LocalStateDictConfig(StateDictConfig):
pass
@dataclass
class ShardedStateDictConfig(StateDictConfig):
"""
``ShardedStateDictConfig`` is a config class meant to be used with
``StateDictType.SHARDED_STATE_DICT``.
Attributes:
_use_dtensor (bool): If ``True``, then FSDP saves the state dict values
as ``DTensor``, and if ``False``, then FSDP saves them as
``ShardedTensor``. (Default: ``False``)
.. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
and it is used by FSDP to determine the type of state dict values. Users should not
manually modify ``_use_dtensor``.
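Example (an illustrative sketch; ``model`` is assumed to already be wrapped
with FSDP)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType
>>> with FSDP.state_dict_type(
>>>     model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=True)
>>> ):
>>>     sharded_sd = model.state_dict()  # per-rank sharded state dict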
"""
_use_dtensor: bool = False
@dataclass
class OptimStateDictConfig:
"""
``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
configuration classes. Users should instantiate a child class (e.g.
``FullOptimStateDictConfig``) in order to configure settings for the
corresponding ``optim_state_dict`` type supported by FSDP.
Attributes:
offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
tensor values to CPU, and if ``False``, then FSDP keeps them on the
original device (which is GPU unless parameter CPU offloading is
enabled). (Default: ``True``)
"""
offload_to_cpu: bool = True
@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
"""
Attributes:
rank0_only (bool): If ``True``, then only rank 0 saves the full state
dict, and nonzero ranks save an empty dict. If ``False``, then all
ranks save the full state dict. (Default: ``False``)
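Example (an illustrative sketch; ``model`` is assumed to be FSDP-wrapped and
``optim`` to be its optimizer)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import (
>>>     FullOptimStateDictConfig, FullStateDictConfig, StateDictType,
>>> )
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)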
"""
rank0_only: bool = False
@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
offload_to_cpu: bool = False
@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
"""
``ShardedOptimStateDictConfig`` is a config class meant to be used with
``StateDictType.SHARDED_STATE_DICT``.
Attributes:
_use_dtensor (bool): If ``True``, then FSDP saves the state dict values
as ``DTensor``, and if ``False``, then FSDP saves them as
``ShardedTensor``. (Default: ``False``)
.. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
and it is used by FSDP to determine the type of state dict values. Users should not
manually modify ``_use_dtensor``.
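Example (an illustrative sketch; ``model`` is assumed to be FSDP-wrapped and
``optim`` to be its optimizer)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import (
>>>     ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType,
>>> )
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     ShardedStateDictConfig(),
>>>     ShardedOptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)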
"""
_use_dtensor: bool = False
@dataclass
class StateDictSettings:
state_dict_type: StateDictType
state_dict_config: StateDictConfig
optim_state_dict_config: OptimStateDictConfig

File diff suppressed because it is too large

View File

@@ -0,0 +1,396 @@
# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union
import torch
import torch.distributed as dist
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup
logger = logging.getLogger(__name__)
def _refresh_per_optimizer_state() -> Dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
def _is_supported_device(tensor: torch.Tensor) -> bool:
return tensor.is_cuda or tensor.device.type in (
"xla",
"cpu",
"hpu",
"mtia",
torch._C._get_privateuse1_backend_name(),
)
class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
"""
Lazily serves tensors to the requesting device. This class extends
_MultiDeviceReplicator to allow support for "cpu" as a device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert _is_supported_device(master_tensor)
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
class ShardedGradScaler(GradScaler):
"""
ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
functionality from GradScaler:
* Supports PyTorch DDP and FSDP implementations
* Supports CPU-offloaded tensors (as used in fully sharded data parallel [FSDP])
* Supports the custom mixed precision loss dtype (fp16, bf16) that FSDP returns
* Syncs inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across
nodes
Example::
# Creates a ShardedGradScaler once at the beginning of training.
scaler = ShardedGradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.
Args:
device (str, optional, default="cuda"): Device type to use for the scale and
growth-tracker tensors ("cuda" or "cpu").
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
process group for sharding
"""
def __init__(
self,
device: str = "cuda",
init_scale: float = 2.0**16,
backoff_factor: float = 0.5,
growth_factor: float = 2.0,
growth_interval: int = 2000,
enabled: bool = True,
process_group: Optional[ProcessGroup] = dist.group.WORLD,
) -> None:
super().__init__(
device,
init_scale=init_scale,
backoff_factor=backoff_factor,
growth_factor=growth_factor,
growth_interval=growth_interval,
enabled=enabled,
)
if self._enabled:
self.process_group = process_group
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
if not self._enabled:
return outputs
if isinstance(outputs, torch.Tensor):
assert _is_supported_device(outputs)
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
scaled_output = outputs * self._scale.to(
device=outputs.device, non_blocking=True
)
# Here we ensure the return dtype is the same as the outputs dtype.
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_output.type(outputs.dtype)
stash: List[_GeneralMultiDeviceReplicator] = []
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
assert _is_supported_device(val)
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_GeneralMultiDeviceReplicator(self._scale))
scaled_val = val * stash[0].get(val.device)
# Here we ensure the return dtype is the same as the outputs dtype.
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_val.type(val.dtype)
if isinstance(val, abc.Iterable):
iterator = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterator)
return iterator
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _foreach_non_finite_check_and_unscale_cpu_(
self,
grads: Sequence[torch.Tensor],
found_inf: torch.Tensor,
inv_scale: torch.Tensor,
) -> None:
if len(grads) == 0:
return
assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."
for grad in grads:
if grad.device.type != "cpu":
logger.error(
"tensor device is %s but was expected to be ``cpu``",
grad.device,
)
raise ValueError(
"Gradients were found on a non-CPU device when"
" expected to be on CPU."
)
if (
torch.isinf(grad).any().item()
or torch.isnan(grad).any().item()
):
found_inf.data = torch.tensor([1.0])
break
else:
grad.data *= inv_scale.item()
def _unscale_grads_(
self,
optimizer: torch.optim.Optimizer,
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool = True,
) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be thousands of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
# coalesce is not supported in torch.float16
param_grad_fp32 = param.grad.type(torch.float32).coalesce()
param.grad = param_grad_fp32.type(torch.float16)
to_unscale = param.grad._values()
else:
to_unscale = param.grad
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
if grads[0].device.type == "cpu":
self._foreach_non_finite_check_and_unscale_cpu_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
else:
torch._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
# There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
# ranks may have no (non-zero sized) parameter shards, necessitating the
# initialization of `per_device_found_inf._per_device_tensors` here
if not per_device_found_inf._per_device_tensors:
assert self._scale is not None
per_device_found_inf.get(self._scale.device)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, True
)
optimizer_state["stage"] = OptState.UNSCALED
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
works = []
found_inf_on_cpus = []
found_inf_on_devices = []
for found_inf in optimizer_state["found_inf_per_device"].values():
if self._device != "cpu" and found_inf.device.type == "cpu":
found_inf_on_cpus.append(found_inf)
found_inf_on_device = found_inf.to(self._device)
found_inf_on_devices.append(found_inf_on_device)
works.append(
dist.all_reduce(
found_inf_on_device, async_op=True, group=self.process_group
)
)
else:
works.append(
dist.all_reduce(found_inf, async_op=True, group=self.process_group)
)
for work in works:
work.wait()
if found_inf_on_cpus:
torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)
def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
"""
If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
"""
assert self._scale is not None and self._growth_tracker is not None
if found_inf.item() >= 1.0:
self._scale *= self._backoff_factor
self._growth_tracker.fill_(0)
else:
successful = self._growth_tracker + 1
if successful == self._growth_interval:
self._scale *= self._growth_factor
self._growth_tracker.fill_(0)
else:
self._growth_tracker = successful
def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly; it is used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale that GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated]
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
torch.FloatTensor with requires_grad=False."
assert new_scale.device.type == self._device, reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
if _scale.device.type == "cpu":
self._amp_update_scale_cpu_(found_inf_combined)
else:
torch._amp_update_scale_(
self._scale, # type: ignore[arg-type]
self._growth_tracker, # type: ignore[arg-type]
found_inf_combined,
self._growth_factor, # type: ignore[arg-type]
self._backoff_factor, # type: ignore[arg-type]
self._growth_interval, # type: ignore[arg-type]
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

View File

@@ -0,0 +1,608 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
import torch.nn as nn
__all__ = [
"always_wrap_policy",
"lambda_auto_wrap_policy",
"transformer_auto_wrap_policy",
"size_based_auto_wrap_policy",
"enable_wrap",
"wrap",
"CustomPolicy",
"ModuleWrapPolicy",
]
# NOTE: We intentionally keep this function simple and isolate the complexity
# to `fn` to enable using this function generically. We may move this to a
# non-FSDP-specific folder and/or make it public in the future.
def _post_order_apply(
root_module: nn.Module,
fn: Callable[[nn.Module], Optional[nn.Module]],
):
"""
This applies ``fn`` to every module in the module tree of ``root_module``
following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
then this replaces the original module with the newly returned one in the
tree. Otherwise, ``fn`` should return ``None``, in which case the module is
not changed.
"""
# Track visited modules to avoid visiting shared modules multiple times
visited_modules: Set[nn.Module] = {root_module}
def _post_order_apply_inner(
module: nn.Module,
module_name: str,
parent_module: Optional[nn.Module],
):
for child_module_name, child_module in module.named_children():
if child_module not in visited_modules:
visited_modules.add(child_module)
_post_order_apply_inner(child_module, child_module_name, module)
optional_module = fn(module)
if optional_module is not None:
assert isinstance(parent_module, nn.Module), (
"Non-root modules should have their parent module set but got "
f"{parent_module} for {module}"
)
assert module_name, (
"Non-root modules should have their module name set but got "
f"an empty module name for {module}"
)
assert isinstance(
optional_module, nn.Module
), f"fn should return None or an nn.Module but got {optional_module}"
setattr(parent_module, module_name, optional_module)
_post_order_apply_inner(root_module, "", None)
def _construct_wrap_fn(
root_module: nn.Module,
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
fsdp_fn: Callable,
) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
This constructs the "wrap" function to pass to :func:`_post_order_apply`
based on ``target_module_to_kwargs``, which should be constructed from the
wrapping policy.
"""
def fn(module: nn.Module) -> Optional[nn.Module]:
# Explicitly avoid wrapping the root module since for FSDP, it is
# handled by the caller
if module in target_module_to_kwargs and module is not root_module:
kwargs = target_module_to_kwargs[module]
return fsdp_fn(module, **kwargs)
return None
return fn
def _run_mixed_precision_override_policy(
root_module: nn.Module,
module_classes: Iterable[Type[nn.Module]],
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
):
module_classes_tuple = tuple(set(module_classes))
for module in root_module.modules():
if module in ignored_modules:
continue
elif isinstance(module, module_classes_tuple):
# This policy overrides any existing policy
if module not in target_module_to_kwargs:
# Only inherit from the root kwargs if not already specified
target_module_to_kwargs[module] = root_kwargs
target_module_to_kwargs[module]["mixed_precision"] = None
return target_module_to_kwargs
def always_wrap_policy(*args, **kwargs) -> bool:
"""
A simple recursive wrap policy that always returns ``True``. This means
that every submodule is wrapped by the wrapper class in
:func:`_recursive_wrap`.
"""
return True
class _Policy(ABC):
"""
This defines an abstract base class that represents a policy for applying
a module-level API.
"""
@abstractmethod
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
"""
This should return a dict ``target_module_to_kwargs`` that maps from
each target module to wrap to its kwargs.
"""
...
def _module_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
module_classes: Set[Type[nn.Module]],
) -> bool:
"""
This auto wrap policy wraps every module that is an instance of any type in
``module_classes`` as its own FSDP instance. The root module given by
``module`` is always wrapped as an FSDP instance regardless. Since the
wrapping proceeds bottom up, each FSDP instance manages the parameters in
its subtree excluding any already managed by a child FSDP instance.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
module_classes (Set[Type[nn.Module]]): Set of module classes that are
wrapped as FSDP instances.
Returns:
``True`` if ``recurse=True``, and whether ``module`` should be wrapped
if ``recurse=False``.
"""
if recurse:
return True # always recurse
return isinstance(module, tuple(module_classes))
class ModuleWrapPolicy(_Policy):
"""
This policy applies to every module of the specified module classes,
passing in the kwargs given to the root.
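Example (an illustrative sketch; ``TransformerBlock`` is a placeholder for
the user's own block class and ``model`` is assumed to be defined elsewhere)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import ModuleWrapPolicy
>>> # `TransformerBlock` stands in for the model's own layer class
>>> policy = ModuleWrapPolicy({TransformerBlock})
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)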
"""
def __init__(self, module_classes: Iterable[Type[nn.Module]]):
module_classes_set = set(module_classes)
self._module_classes = module_classes_set
self._module_classes_str = str(module_classes_set)
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
module_classes = tuple(self._module_classes)
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
elif isinstance(module, module_classes):
# Shallow copy to avoid coupling changes across modules
target_module_to_kwargs[module] = copy.copy(root_kwargs)
return target_module_to_kwargs
def __call__(self, module, recurse, *args, **kwargs):
# nonwrapped_numel is not used.
return _module_wrap_policy(
module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
)
def __repr__(self) -> str:
return super().__repr__() + f"({self._module_classes_str})"
class CustomPolicy(_Policy):
"""
This policy takes in a lambda function that maps a given ``nn.Module`` to
either ``False``, ``True``, or a kwarg dictionary.
- If the function returns ``False`` or an empty dictionary, then the module
does not have the API applied.
- If the function returns ``True``, then the module has the API applied
with the root's kwargs.
- If the function returns a non-empty dictionary, then the module has the
API applied, and the dictionary overrides the root's kwargs.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> model = init_transformer_model(...)
>>> def lambda_fn(module: nn.Module):
>>> if module is model.lm_head:
>>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
>>> elif isinstance(module, TransformerBlock):
>>> return True
>>> return False
>>> policy = CustomPolicy(lambda_fn)
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
"""
def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
self._lambda_fn = lambda_fn
def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
res = self._lambda_fn(module)
if not isinstance(res, (dict, bool)):
raise ValueError(
"The lambda_fn passed to CustomPolicy should return "
f"False/True or a kwarg dict, but it returned {res}"
)
if not res:
continue
kwargs = copy.copy(root_kwargs)
if isinstance(res, dict):
# Override the root kwargs with the ones specified by the
# lambda function
kwargs.update(res)
target_module_to_kwargs[module] = kwargs
return target_module_to_kwargs
def lambda_auto_wrap_policy(
module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
"""
A convenient auto wrap policy to wrap submodules based on an arbitrary user
function. If ``lambda_fn(submodule) == True``, then the submodule will be wrapped as
a ``wrapper_cls`` unit.
Returns whether a module should be wrapped during auto wrapping.
The first three parameters are required by :func:`_recursive_wrap`.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
this module will be wrapped.
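Example (an illustrative sketch; ``MyBlock`` and ``model`` are placeholders
assumed to be defined elsewhere)::
>>> # xdoctest: +SKIP("undefined variables")
>>> import functools
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
>>> # Bind `lambda_fn` so FSDP can call the policy with (module, recurse, nonwrapped_numel)
>>> policy = functools.partial(
>>>     lambda_auto_wrap_policy, lambda_fn=lambda m: isinstance(m, MyBlock)
>>> )
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)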
"""
if recurse:
return True # always recurse
return lambda_fn(module)
def transformer_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
"""
See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
same as ``module_classes``. Note that shared parameters must be wrapped in
the same FSDP instance, so this auto wrap policy can help wrap shared
embeddings into the same FSDP instance for transformer models.
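Example (an illustrative sketch; ``TransformerBlock`` is a placeholder for
the model's layer class and ``model`` is assumed to be defined elsewhere)::
>>> # xdoctest: +SKIP("undefined variables")
>>> import functools
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
>>> policy = functools.partial(
>>>     transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}
>>> )
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)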
"""
return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)
def _wrap_module_cls_individually(
module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
):
if recurse:
# always recurse
return True
else:
# if not recursing, decide whether we should wrap based on whether the type of module
# is in `module_classes`.
return isinstance(module, tuple(module_classes))
def _or_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
policies,
) -> bool:
"""
A policy that wraps ``module`` if any policy in the passed in iterable of
``policies`` returns ``True``.
"""
return any(
policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
for policy in policies
)
def size_based_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
# Additional custom arguments
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
"""
A size-based auto wrap policy.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
min_num_params (int): Customizable policy input that controls the size
threshold over which a module is ready to be wrapped. This is in
units of numel.
force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
as leaves, i.e. their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
excluded in wrapping.
Returns:
Whether ``module`` should be wrapped.
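Example (an illustrative sketch; ``model`` is assumed to be defined elsewhere
and ``min_num_params=int(1e5)`` is just an illustrative threshold)::
>>> # xdoctest: +SKIP("undefined variables")
>>> import functools
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
>>> policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))
>>> fsdp_model = FSDP(model, auto_wrap_policy=policy)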
"""
force_leaf_modules = (
size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
if force_leaf_modules is None
else force_leaf_modules
)
exclude_wrap_modules = (
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
if exclude_wrap_modules is None
else exclude_wrap_modules
)
# Keep the argument `min_num_params` for BC for now, but it represents the
# minimum non-wrapped *numel* before triggering a wrapping
min_nonwrapped_numel = min_num_params
is_large = nonwrapped_numel >= min_nonwrapped_numel
if recurse:
# We should recurse if the module is big enough but not in force_leaf_modules list.
return is_large and not isinstance(module, tuple(force_leaf_modules))
else:
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined]
@contextlib.contextmanager
def enable_wrap(
*, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
"""
Context manager to wrap modules using a wrapper.
Useful for when you'd like to apply the same configuration arguments to all
child modules that you wrap. A particularly important use case is wrapping
large layers so that they get sharded (in-place) during initialization, to
avoid running out of system memory. Large layers can indicate that they
should be sharded via the ``wrap`` annotation and this context manager can
provide the exact configuration for these nested instances.
Usage::
with enable_wrap(wrapper_cls, **params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
Args:
wrapper_cls:
Class that `wrap` annotation will `wrap` modules with, such as
`FullyShardedDataParallel`.
**wrapper_kwargs:
Configuration settings that will be passed to all ``wrap``
instances inside the context
"""
kwargs = {
"wrapper_cls": wrapper_cls,
**wrapper_kwargs,
}
with _ConfigAutoWrap(**kwargs):
yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
"""
Annotate that a module should be wrapped. Annotated modules will only be
wrapped if inside of an :func:`enable_wrap` context manager. This allows
a module to be initialized both with and without a wrapper without code
change.
The class that this function wraps the passed in ``nn.Module`` with is the
passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
the ``wrapper_cls`` instance. In the case of duplicate kwargs in
``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
respected.
Usage::
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
Args:
module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
**wrap_overrides: configuration overrides that will take priority over
the values provided by the :func:`enable_wrap` context
"""
if _ConfigAutoWrap.in_autowrap_context:
assert _ConfigAutoWrap.wrapper_cls is not None
wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
return _wrap(
module,
_ConfigAutoWrap.wrapper_cls,
**wrap_overrides,
)
return module
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
assert wrapper_cls is not None
if hasattr(module, "_wrap_overrides"):
# If module has a _wrap_overrides attribute, we force overriding the
# FSDP config with these attributes for this module. Currently this
# is only used to disable mixed precision for BatchNorm when
# auto_wrapping.
overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
return wrapper_cls(module, **overrides)
return wrapper_cls(module, **kwargs)
def _recursive_wrap(
module: nn.Module,
auto_wrap_policy: Callable,
wrapper_cls: Callable,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
only_wrap_children: bool = False,
**kwargs: Any,
) -> Tuple[nn.Module, int]:
"""
Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
``True`` with ``wrapper_cls``.
Args:
module (nn.Module): Module to recursively wrap.
auto_wrap_policy (Callable): A callable representing a policy that
determines which modules to recursively wrap with ``wrapper_cls``.
ignored_modules (Set[torch.nn.Module]): Modules to ignore when
wrapping.
ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
wrapping; these should be the parameters contained in the modules
in ``ignored_modules``.
Returns:
(nn.Module, int):
``module`` after wrapping and the numel recursively wrapped.
"""
assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
assert wrapper_cls is not None, "Must specify wrapper_cls"
# Make sure no child is already wrapped.
for _, child in module.named_modules():
if child in ignored_modules:
continue
try:
assert not isinstance(child, cast(type, wrapper_cls))
except TypeError:
# wrapper_cls is a function as opposed to a class type, just bypass above check.
pass
# We count all params, assuming none of them are already wrapped.
nonwrapped_numel = sum(
p.numel() for p in module.parameters() if p not in ignored_params
)
assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
total_wrapped_numel = 0
# Iterate through the children, recursively wrap if necessary
for name, child in module.named_children():
if child in ignored_modules:
continue
wrapped_child, num_wrapped_params = _recursive_wrap(
module=child,
auto_wrap_policy=auto_wrap_policy,
wrapper_cls=wrapper_cls,
ignored_modules=ignored_modules,
ignored_params=ignored_params,
**kwargs,
)
setattr(module, name, wrapped_child)
# Keep track of how many parameters have been wrapped
total_wrapped_numel += num_wrapped_params
# decide if we need to wrap the current module,
# since the left over parameters exceed the number of params to wrap
remainder = nonwrapped_numel - total_wrapped_numel
if not only_wrap_children and auto_wrap_policy(
module=module, recurse=False, nonwrapped_numel=remainder
):
# Leaf node or final wrapping of the remainder both happen here.
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
else:
return module, total_wrapped_numel
return module, 0
class _ConfigAutoWrap:
"""
Helper class to wrap modules based on default config args via a context manager.
See :func:`enable_wrap` for more information.
"""
in_autowrap_context: bool = False # Context flag
wrapper_cls: Optional[Callable] = None # The wrapper class
kwargs: Dict[str, Any] = {} # Wrapper's args
def __init__(self, **kwargs: Dict[str, Any]):
self.kwargs = kwargs
@staticmethod
def enable_autowrap_context(kwargs: Any) -> None:
if _ConfigAutoWrap.in_autowrap_context:
raise NotImplementedError(
"You are already within an autowrap context and we currently do not supported nested autowrap."
)
_ConfigAutoWrap.in_autowrap_context = True
# Get and save the wrapper cls for the context.
assert (
"wrapper_cls" in kwargs.keys()
), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
del kwargs["wrapper_cls"]
# Save the rest.
_ConfigAutoWrap.kwargs = kwargs
@staticmethod
def disable_autowrap_context() -> None:
_ConfigAutoWrap.in_autowrap_context = False
_ConfigAutoWrap.wrapper_cls = None
_ConfigAutoWrap.kwargs = {}
def __enter__(self) -> None:
self.enable_autowrap_context(self.kwargs)
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.disable_autowrap_context()