I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,14 @@
from .api import CheckpointException
from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
from .filesystem import FileSystemReader, FileSystemWriter
from .metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
TensorStorageMetadata,
)
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter
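
As a quick orientation to the public API re-exported above, here is a minimal, hedged single-process sketch; the /tmp/ckpt directory and the toy state_dict are illustrative and not part of this commit, and a real run would typically have a process group initialized.

import torch
import torch.distributed.checkpoint as dcp

# A toy state_dict; in practice this comes from model.state_dict() / optimizer.state_dict().
state_dict = {"weight": torch.randn(4, 4), "step": torch.tensor(0)}

# Write the checkpoint with the default planners (one .distcp file per rank plus .metadata).
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("/tmp/ckpt"))

# Load in place: tensors already present in `state_dict` are filled from the checkpoint.
dcp.load(state_dict, storage_reader=dcp.FileSystemReader("/tmp/ckpt"))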

View File

@ -0,0 +1,100 @@
from concurrent.futures import Future
from typing import Any, Dict, List, Optional
import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
LoadPlanner,
SavePlanner,
StorageReader,
StorageWriter,
)
__all__: List[str] = []
class _Checkpointer:
"""This base class specefies a high level API for saving and loading
distributed `state_dict` 's. It provides an abstraction over the low-level APIs
provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
:py:meth: `torch.distributed.state_dict_saver.save` and
:py:meth: `torch.distributed.state_dict_loader.load` with the provided storage
readers and writers.
.. warning::
This feature is experimental and subject to removal/change.
"""
def __init__(
self,
storage_writer: StorageWriter,
storage_reader: StorageReader,
*,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
load_planner: Optional[LoadPlanner] = None,
save_planner: Optional[SavePlanner] = None,
):
"""Initializes the Checkpointer instance.
Args:
storage_writer: Instance of StorageWriter used to perform writes.
storage_reader: StorageReader used to load data from.
process_group: ProcessGroup to be used for cross-rank synchronization.
coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
load_planner: Instance of LoadPlanner to use when loading.
save_planner: Instance of SavePlanner to use when saving.
"""
self.storage_writer = storage_writer
self.storage_reader = storage_reader
self.process_group = process_group
self.coordinator_rank = coordinator_rank
self.no_dist = no_dist
self.load_planner = load_planner
self.save_planner = save_planner
def save(
self,
state_dict: STATE_DICT_TYPE,
) -> Metadata:
"""Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization."""
return saver.save(
state_dict,
self.storage_writer,
process_group=self.process_group,
coordinator_rank=self.coordinator_rank,
no_dist=self.no_dist,
planner=self.save_planner,
)
def async_save(
self,
state_dict: STATE_DICT_TYPE,
) -> Future:
"""
Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.async_save`, utilizing the values passed during initialization.
Returns:
Future: A future holding the resultant Metadata object from `save`.
"""
return saver.async_save(
state_dict,
storage_writer=self.storage_writer,
process_group=self.process_group,
planner=self.save_planner,
)
def load(self, state_dict: Dict[str, Any]) -> None:
"""Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization."""
loader.load(
state_dict,
storage_reader=self.storage_reader,
process_group=self.process_group,
planner=self.load_planner,
)
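
A hedged usage sketch of the `_Checkpointer` wrapper above; the import path is assumed from this commit's layout, the checkpoint directory is illustrative, and `no_dist=True` is passed so the sketch runs in a single process.

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint._checkpointer import _Checkpointer  # path assumed

checkpointer = _Checkpointer(
    storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    no_dist=True,  # single-process sketch; omit when a process group is initialized
)

state_dict = {"weight": torch.randn(8, 8)}
metadata = checkpointer.save(state_dict)  # returns the checkpoint Metadata
checkpointer.load(state_dict)             # fills the same state_dict back in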

View File

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING
from torch.distributed.checkpoint.planner import SavePlan, WriteItem
if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import MetadataIndex
__all__ = ["dedup_save_plans"]
def dedup_save_plans(
all_plans: List[SavePlan],
save_to_lowest_rank: bool = False,
) -> List[SavePlan]:
"""
Remove duplicate entries that appear in multiple SavePlans. For each item duplicated across
a set of SavePlans, only the plan with the smallest planned storage size keeps the entry.
"""
write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items:
# map each write item to its plan
write_item_to_plan_indices[write_item.index].add(plan_idx)
write_item_idx_to_write_item[write_item.index] = write_item
# put item in the plan with the smallest size and remove it from the other plan_indices
to_remove: List[Set] = [set() for _ in range(len(all_plans))]
plan_to_size = [0] * len(all_plans)
for write_item_idx, plan_indices in write_item_to_plan_indices.items():
if save_to_lowest_rank:
select_plan_idx = min(plan_indices)
else:
select_plan_idx = min(
plan_indices, key=lambda plan_idx: plan_to_size[plan_idx]
)
write_item = write_item_idx_to_write_item[write_item_idx]
# essentially ignores the storage size of anything that is not a tensor, since
# we don't know how much storage it represents
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
plan_indices.remove(select_plan_idx)
for plan_idx in plan_indices:
to_remove[plan_idx].add(write_item_idx)
for plan_idx, remove_set in enumerate(to_remove):
new_items = [
write_item
for write_item in all_plans[plan_idx].items
if write_item.index not in remove_set
]
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans
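
The balancing rule above (each duplicated item stays in whichever plan currently holds the fewest planned bytes) can be illustrated with a standalone toy sketch that does not use the real SavePlan/WriteItem classes; the names and sizes below are made up.

from collections import defaultdict
from typing import Dict, List, Set

def toy_dedup(plans: Dict[int, Set[str]], sizes: Dict[str, int]) -> Dict[int, Set[str]]:
    # Map each item to the plans proposing it, mirroring write_item_to_plan_indices.
    item_to_plans: Dict[str, List[int]] = defaultdict(list)
    for plan_idx, items in plans.items():
        for item in sorted(items):
            item_to_plans[item].append(plan_idx)
    plan_to_size = dict.fromkeys(plans, 0)
    for item, owners in item_to_plans.items():
        keep = min(owners, key=lambda p: plan_to_size[p])  # smallest plan keeps the item
        plan_to_size[keep] += sizes.get(item, 1)
        for p in owners:
            if p != keep:
                plans[p].discard(item)
    return plans

# "a" is proposed by both plans but ends up kept by exactly one of them.
print(toy_dedup({0: {"a", "b"}, 1: {"a", "c"}}, {"a": 100, "b": 10, "c": 10}))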

View File

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List, TYPE_CHECKING
from torch.distributed.checkpoint.planner import SavePlan
if TYPE_CHECKING:
from torch.distributed.checkpoint.metadata import MetadataIndex
__all__ = ["dedup_tensors"]
def init_logger() -> logging.Logger:
logger = logging.getLogger(__name__)
level = logging.INFO
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
)
console.setFormatter(formatter)
console.setLevel(level)
logger.addHandler(console)
logger.propagate = False
return logger
logger = init_logger()
# TODO add docstring for dedup_tensors
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
all_plans = list(all_plans)
key_to_plan: Dict[MetadataIndex, List[int]] = {}
for plan_idx, plan in enumerate(all_plans):
for write_item in plan.items:
key_to_plan.setdefault(write_item.index, []).append(plan_idx)
replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
# Remove duplicates by always keeping the first entry.
# Compute the per-rank remove set.
plan_to_keys: Dict[int, List[MetadataIndex]] = {}
for key, plans in replicated_items.items():
for plan_idx in plans[1:]:
plan_to_keys.setdefault(plan_idx, []).append(key)
if len(plan_to_keys) > 0:
logger.info("Duplicate keys to remove: %s", plan_to_keys)
for plan_idx, keys in plan_to_keys.items():
key_set = set(keys)
# rewrite items and remove elements
new_items = [
write_item
for write_item in all_plans[plan_idx].items
if write_item.index not in key_set
]
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans

View File

@ -0,0 +1,137 @@
# Mypy will not try inferring the types of any 3rd party libraries installed.
# mypy: ignore-errors
import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union
import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
from torch.distributed.checkpoint.filesystem import (
FileSystemBase,
FileSystemReader,
FileSystemWriter,
)
__all__ = [
"FsspecWriter",
"FsspecReader",
]
class FileSystem(FileSystemBase):
def __init__(self) -> None:
self.fs: Optional[AbstractFileSystem] = None
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
assert self.fs is not None
with self.fs.transaction:
with fsspec.open(str(path), mode) as stream:
yield stream
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
return os.path.join(path, suffix)
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
self.fs, _ = url_to_fs(path)
return path
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
self.fs.rename(path, new_path)
def mkdir(self, path: Union[str, os.PathLike]) -> None:
self.fs.makedirs(path, exist_ok=True)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
if isinstance(checkpoint_id, Path):
return False
try:
url_to_fs(checkpoint_id)
except ValueError:
return False
return True
def exists(self, path: Union[str, os.PathLike]) -> bool:
return self.fs.exists(path)
def rm_file(self, path: Union[str, os.PathLike]) -> None:
self.fs.rm(path)
# TODO: add the dcp.async_save mixin
class FsspecWriter(FileSystemWriter):
"""
Basic implementation of StorageWriter using fsspec.
This implementation makes the following assumptions and simplifications:
* The checkpoint path is an empty or non-existing directory.
* File creation is atomic
The checkpoint consists of one file per write request plus
a `.metadata` file with the serialized metadata.
"""
def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
overwrite: bool = True,
) -> None:
"""
Initialize the writer pointing to `path`.
Args:
path: directory where the checkpoint will be written to.
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
sync_files: Force files to be synced to permanent storage. Defaults to True.
thread_count: Number of IO threads to use to write. Defaults to 1.
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10 MB.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
N.B. If sync_files is disabled, there is no guarantee that the checkpoint will be consistent in the case of a failure.
"""
super().__init__(
path,
single_file_per_rank,
sync_files,
thread_count,
per_thread_copy_ahead,
overwrite=overwrite,
)
self.fs = FileSystem()
self.path = self.fs.init_path(path)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return FileSystem.validate_checkpoint_id(checkpoint_id)
class FsspecReader(FileSystemReader):
def __init__(self, path: Union[str, os.PathLike]) -> None:
super().__init__(path)
self.fs = FileSystem()
self.path = self.fs.init_path(path)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return FileSystem.validate_checkpoint_id(checkpoint_id)
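
A hedged round-trip sketch with the fsspec-backed reader/writer above (requires fsspec to be installed); the in-memory memory:// URL keeps the example side-effect free, and any other fsspec URL such as an object-store path would follow the same pattern. The import path matches the one used by _storage_setup elsewhere in this commit.

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter

state_dict = {"weight": torch.randn(2, 2)}
dcp.save(state_dict, storage_writer=FsspecWriter("memory://checkpoints/run0"))
dcp.load(state_dict, storage_reader=FsspecReader("memory://checkpoints/run0"))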

View File

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Tuple
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from . import _version
from ._traverse import (
OBJ_PATH,
set_element,
STATE_DICT_ITEM,
traverse_state_dict,
traverse_state_dict_v_2_3,
)
"""
TODO:
Need to add ability to handle tuple, OrderedDict, NamedTuple.
Update mappings from dict to a class.
Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.
"""
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
# TODO: Update Docstring for nested_dict.py
def flatten_state_dict(
state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
"""
Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.
Use ``unflatten_state_dict`` to revert this process.
Returns:
A tuple with the flattened state_dict and a mapping from the new (flattened) keys back to the original object paths.
N.B. The new keys are derived from the object paths, joined by dot.
For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.
"""
flattened: STATE_DICT_TYPE = {}
mappings: FLATTEN_MAPPING = {}
def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
new_fqn = ".".join(map(str, path))
if new_fqn in flattened:
raise ValueError(f"duplicated flatten key {new_fqn}")
flattened[new_fqn] = value
mappings[new_fqn] = path
# We started to flatten dictionary since v2.4. But in order to not break
# the checkpoints that were saved before v2.4, we need to keep the old
# traversal so that we can reconstruct those checkpoints.
use_v_2_3 = (
_version._derived_version is not None and _version._derived_version == "2_3"
)
if use_v_2_3:
traverse_state_dict_v_2_3(state_dict, flat_copy)
else:
traverse_state_dict(state_dict, flat_copy)
return flattened, mappings
def unflatten_state_dict(
state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
"""Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
nested: STATE_DICT_TYPE = {}
for key, value in state_dict.items():
set_element(nested, mapping[key], value)
return nested
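
A small example of the flatten/unflatten round trip defined above; the nested dict is illustrative, and the module path matches the imports used elsewhere in this commit.

import torch
from torch.distributed.checkpoint._nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)

nested = {"optim": {"param_groups": [{"lr": 0.1}], "step": torch.tensor(3)}}
flat, mapping = flatten_state_dict(nested)
# Flat keys are dotted object paths, e.g. "optim.step" and "optim.param_groups.0.lr".
restored = unflatten_state_dict(flat, mapping)
assert restored["optim"]["step"].item() == 3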

View File

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
from typing import TYPE_CHECKING
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info
if TYPE_CHECKING:
from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
# TODO: We need to refactor this code.
def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
r"""
Transform ``state_dict`` by flattening all nested ShardedTensor instances found.
The resulting ShardedTensor instances are only correct with respect to the local shard and
MUST NOT be used for any purpose other than checkpointing, as no operator will work with them.
This function should be used in conjunction with a state_dict produced by FSDP's
StateDictType.SHARDED_STATE_DICT methods.
"""
new_state_dict: STATE_DICT_TYPE = {}
def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
if not isinstance(value, ShardedTensor):
set_element(new_state_dict, path, value)
return
shards = value.local_shards()
if len(shards) == 0:
return
if len(shards) != 1:
set_element(new_state_dict, path, value)
return
outer_shard = shards[0]
inner_st = outer_shard.tensor
if not isinstance(inner_st, ShardedTensor):
set_element(new_state_dict, path, value)
return
if len(inner_st.local_shards()) != 1:
raise ValueError("Cannot handle inner tensor with more than 1 shard")
inner_shard = inner_st.local_shards()[0]
local_shards = [
Shard(
tensor=inner_shard.tensor,
metadata=ShardMetadata(
shard_offsets=_element_wise_add(
outer_shard.metadata.shard_offsets,
inner_shard.metadata.shard_offsets,
),
shard_sizes=inner_shard.metadata.shard_sizes,
placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}",
),
)
]
st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
other_rank = 0 if dist.get_rank() > 0 else 1
device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)
# Remove the outer ST shard the inner ST covers
for i, shard_md in enumerate(st_meta.shards_metadata):
if shard_md.shard_offsets == outer_shard.metadata.shard_offsets:
st_meta.shards_metadata.pop(i)
break
# Attribute the remaining shards to the other rank
for shard_md in st_meta.shards_metadata:
shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")
# Add other inner shards from the inner tensor
for inner_md in inner_st.metadata().shards_metadata:
if inner_md.shard_offsets != inner_shard.metadata.shard_offsets:
st_meta.shards_metadata.append(
ShardMetadata(
shard_offsets=_element_wise_add(
outer_shard.metadata.shard_offsets,
inner_md.shard_offsets,
),
shard_sizes=inner_md.shard_sizes,
placement=f"rank:{other_rank}/{device_info}",
)
)
# Finally add this shard
st_meta.shards_metadata.append(local_shards[0].metadata)
st = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=st_meta,
)
set_element(new_state_dict, path, st)
traverse_state_dict(state_dict, rewrite_dict)
return new_state_dict

View File

@ -0,0 +1,49 @@
import os
from typing import List, Type, Union
from .filesystem import FileSystemReader, FileSystemWriter
from .storage import StorageReader, StorageWriter
def _storage_setup(
storage: Union[StorageReader, StorageWriter, None],
checkpoint_id: Union[str, os.PathLike, None],
reader: bool = False,
) -> Union[None, StorageReader, StorageWriter]:
if storage:
if checkpoint_id is not None:
storage.reset(checkpoint_id)
return storage
if not checkpoint_id:
raise RuntimeError(
"`checkpoint_id` must be specificed if "
"storage_reader/storage_writer is None."
)
targets: List[Type[Union[StorageReader, StorageWriter]]] = []
if reader:
targets = [
FileSystemReader,
]
else:
targets = [
FileSystemWriter,
]
try:
from ._fsspec_filesystem import FsspecReader, FsspecWriter
targets.append(FsspecReader if reader else FsspecWriter)
except Exception:
pass
for target in targets:
if target.validate_checkpoint_id(checkpoint_id):
storage = target(checkpoint_id) # type: ignore[call-arg]
storage.reset(checkpoint_id)
return storage
raise RuntimeError(
"Cannot detect which StorageReader or StorageWriter to use. "
"Please specify the storage_reader/storage_writer."
)
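
Because of this helper, save/load can be driven purely by a checkpoint_id; a hedged sketch follows (the path is illustrative, and when the id is not a usable local path the detection falls back to the fsspec classes if available).

import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.zeros(3)}

# _storage_setup resolves these to FileSystemWriter / FileSystemReader
# via validate_checkpoint_id, since no explicit storage object is passed.
dcp.save(state_dict, checkpoint_id="/tmp/ckpt")
dcp.load(state_dict, checkpoint_id="/tmp/ckpt")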

View File

@ -0,0 +1,208 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import (
Callable,
cast,
Collection,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
)
import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.tensor import DTensor
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")
STATE_DICT_ITEM = object
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
return isinstance(value, torch.Tensor)
# TODO: update docstring for traverse.py
def traverse_state_dict(
state_dict: STATE_DICT_TYPE,
visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
"""
Invoke ``visitor`` for each value recursively in ``state_dict``.
Mappings are traversed and ``visitor`` is applied to the leaf elements.
``visitor`` is only applied to elements in a list or a tuple if the
container contains tensors or mappings.
"""
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
values: Collection[STATE_DICT_ITEM]
if isinstance(value, Mapping):
return False
elif isinstance(value, list):
values = value
else:
return True
for entry in values:
if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
return False
if keep_traversing is not None and keep_traversing(entry):
return False
return True
def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
if isinstance(value, Mapping):
for k, v in value.items():
_traverse_obj(path + (str(k),), v)
elif _is_terminal(value):
visitor(path, value)
elif isinstance(value, (list, tuple)):
for i, v in enumerate(value):
_traverse_obj(path + (i,), v)
for key, value in state_dict.items():
_traverse_obj((str(key),), value)
def traverse_state_dict_v_2_3(
state_dict: STATE_DICT_TYPE,
visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
"""
Traversal is short-circuited when it finds a collection for which ``keep_traversing`` evaluates
to false for all elements.
By default, all collections with at least one ``torch.Tensor`` element are traversed.
Visitor takes a path argument that is a tuple of the keys used to reach it.
"""
# a value is terminal if it has no container values inside it
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
values: Collection[STATE_DICT_ITEM]
if isinstance(value, Mapping):
values = value.values()
elif isinstance(value, list):
values = value
else:
return True
for entry in values:
if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
return False
if keep_traversing is not None and keep_traversing(entry):
return False
return True
def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
if _is_terminal(value):
visitor(path, value)
elif isinstance(value, Mapping):
for k, v in value.items():
_traverse_obj(path + (str(k),), v)
elif isinstance(value, list):
for i, v in enumerate(value):
_traverse_obj(path + (i,), v)
for key, value in state_dict.items():
_traverse_obj((str(key),), value)
def set_element(
root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
"""Set ``value`` in ``root_dict`` along the ``path`` object path."""
cur_container = cast(CONTAINER_TYPE, root_dict)
def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
while len(lst) <= idx:
lst.append(None)
for i in range(1, len(path)):
prev_key = path[i - 1]
key = path[i]
def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
if isinstance(cur_container, Mapping):
cur_container = cast(
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
)
else:
extend_list(cur_container, prev_key)
if cur_container[prev_key] is None:
cur_container[prev_key] = def_val
cur_container = cur_container[prev_key]
key = path[-1]
if type(key) == int:
extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
cur_container[key] = value
def get_element(
root_dict: STATE_DICT_TYPE,
path: OBJ_PATH,
default_value: Optional[T] = None,
) -> Optional[T]:
"""Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found."""
cur_value = cast(CONTAINER_TYPE, root_dict)
for part in path:
if type(part) is int:
if not isinstance(cur_value, list) or len(cur_value) <= part:
return default_value
elif not isinstance(cur_value, Mapping) or part not in cur_value:
return default_value
cur_value = cast(CONTAINER_TYPE, cur_value[part])
return cast(Optional[T], cur_value)
def _print_nested(
value: STATE_DICT_ITEM,
prefix: str = "",
print_fun: Callable[[str], None] = print,
) -> None:
if type(value) is ShardedTensor:
print_fun(f"{prefix} ShardedTensor size: {value.size()}")
for shard in value.local_shards():
_print_nested(
shard.tensor,
f"{shard.metadata.shard_offsets} ",
print_fun=print_fun,
)
elif type(value) is DTensor:
print_fun(f"{prefix} DistributedTensor size: {value.size()}")
# TODO: add local offset for _local_tensor in print_nested.
_print_nested(
value._local_tensor,
print_fun=print_fun,
)
elif isinstance(value, torch.Tensor):
print_fun(f"{prefix} Tensor size: {value.size()}")
else:
print_fun(f"{prefix} Type: {type(value)}")
def print_tensor(
path: OBJ_PATH,
value: STATE_DICT_ITEM,
print_fun: Callable[[str], None] = print,
) -> None:
"""
Use this callback with traverse_state_dict to print its content.
By default the content is printed using the builtin ``print``, but this can
be changed by passing a different ``print_fun`` callable.
"""
_print_nested(value, prefix=str(path), print_fun=print_fun)
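
A quick sketch of the traversal helpers above on a plain nested state_dict; the module path matches the imports used elsewhere in this commit, and the data is illustrative.

import torch
from torch.distributed.checkpoint._traverse import (
    get_element,
    print_tensor,
    set_element,
    traverse_state_dict,
)

sd = {"model": {"layers": [torch.zeros(2), torch.ones(3)]}, "step": 7}

# Visit every leaf and print it together with its object path.
traverse_state_dict(sd, print_tensor)

# Read and write values by object path.
print(get_element(sd, ("model", "layers", 1)))           # tensor([1., 1., 1.])
set_element(sd, ("model", "layers", 2), torch.zeros(1))  # extends the inner list in place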

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Optional
_derived_version: Optional[str] = None

View File

@ -0,0 +1,43 @@
# mypy: allow-untyped-defs
import traceback as tb
from typing import Any, Dict, Tuple
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
__all__ = ["CheckpointException"]
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
return (exc, tb.extract_tb(exc.__traceback__))
def _is_wrapped_exception(obj: Any) -> bool:
if not isinstance(obj, tuple):
return False
if len(obj) != 2:
return False
return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
class CheckpointException(BaseException):
"""Exception raised if failure was detected as part of a checkpoint load or save."""
def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
super().__init__(msg, failures)
self._failures = failures
@property
def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
"""Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
return self._failures
def __str__(self):
msg = f"CheckpointException ranks:{self._failures.keys()}\n"
for rank, exc_pair in self._failures.items():
exc, trace = exc_pair
msg += f"Traceback (most recent call last): (RANK {rank})\n"
if trace is not None:
msg += "".join(tb.format_list(trace))
msg += "".join(tb.format_exception_only(type(exc), value=exc))
return msg
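
A hedged sketch of how this exception surfaces to callers: a failed step in a save or load is wrapped per rank and re-raised as a CheckpointException. The checkpoint path below is deliberately bogus so the load fails.

import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.zeros(2, 2)}
try:
    dcp.load(state_dict, storage_reader=dcp.FileSystemReader("/nonexistent/ckpt"))
except dcp.CheckpointException as e:
    # failures maps a rank to its wrapped (exception, traceback) pair.
    for rank, (exc, trace) in e.failures.items():
        print(f"rank {rank} failed with {type(exc).__name__}: {exc}")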

View File

@ -0,0 +1,546 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING,
flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
STORAGE_TYPES,
StorageMeta,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
WriteItem,
WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
_create_default_metadata_only_plan,
_create_read_items,
_create_write_items,
_init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor
from . import _version
logger: logging.Logger = logging.getLogger(__name__)
__all__ = [
"DefaultSavePlanner",
"DefaultLoadPlanner",
"create_default_local_load_plan",
"create_default_global_load_plan",
"create_default_local_save_plan",
"create_default_global_save_plan",
]
# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
mappings: FLATTEN_MAPPING
def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
dedup_save_to_lowest_rank: bool = False,
) -> None:
self.flatten_state_dict = flatten_state_dict
self.flatten_sharded_tensors = flatten_sharded_tensors
self.mappings = {}
self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
if dedup_replicated_tensors is not None:
logger.warning(
"DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
"deprecated, and no longer has any effect. Please remove this argument "
"from your call."
)
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
storage_meta: Optional[StorageMeta] = None,
is_coordinator: bool = False,
) -> None:
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
if self.flatten_sharded_tensors:
state_dict = _flatten_sharded_tensors(state_dict)
self.state_dict = state_dict
self.is_coordinator = is_coordinator
def create_local_plan(self) -> SavePlan:
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
self.plan = plan
return self.plan
def create_global_plan(
self, all_plans: List[SavePlan]
) -> Tuple[List[SavePlan], Metadata]:
all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
global_plan, metadata = create_default_global_save_plan(all_plans)
if self.flatten_state_dict:
# | does not work for Python 3.8 or older versions.
# merged_mappings = reduce(
# lambda x, y: x | y, (p.planner_data for p in global_plan)
# )
planner_data_dict = [p.planner_data for p in global_plan]
merged_mappings = dict(ChainMap(*planner_data_dict))
metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
if not _validate_global_plan(global_plan, metadata):
raise ValueError("Failed to validate global plan")
self.global_plan = global_plan
self.metadata = metadata
return self.global_plan, self.metadata
def finish_plan(self, new_plan: SavePlan) -> SavePlan:
self.plan = new_plan
return new_plan
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
object = self.lookup_object(write_item.index)
return self.transform_object(write_item, object)
def lookup_object(self, index: MetadataIndex) -> Any:
"""Extension from the planner interface to make it easy to extend the default planner."""
return find_state_dict_object(self.state_dict, index)
def transform_object(self, write_item: WriteItem, object: Any):
"""Extension from the planner interface to make it easy to extend the default planner."""
if write_item.type == WriteItemType.BYTE_IO:
bytes = io.BytesIO()
torch.save(object, bytes)
object = bytes
return object
class DefaultLoadPlanner(LoadPlanner):
"""
DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
In particular it adds the following:
flatten_state_dict: Handle state_dict with nested dicts
flatten_sharded_tensors: For FSDP in 2D parallel mode
allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
"""
original_state_dict: STATE_DICT_TYPE
mappings: FLATTEN_MAPPING
def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
allow_partial_load: bool = False,
) -> None:
self.flatten_state_dict = flatten_state_dict
self.flatten_sharded_tensors = flatten_sharded_tensors
self.original_state_dict = {}
self.mappings = {}
self.allow_partial_load = allow_partial_load
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
_init_state_dict(state_dict)
self.original_state_dict = state_dict
if self.flatten_sharded_tensors:
state_dict = _flatten_sharded_tensors(state_dict)
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
self.state_dict = state_dict
self.metadata = metadata
self.is_coordinator = is_coordinator
def create_local_plan(self) -> LoadPlan:
assert self.metadata is not None
if self.flatten_state_dict:
# To support checkpoints that are saved before v2.4, we have to
# differentiate if the missing keys are due to old checkpoints.
# The contracts are:
# 1. There are 4 cases when we find a missing key.
# 1.1 Actual missing key, and allow_partial_load is False
# 1.2 Actual missing key, and allow_partial_load is True
# 1.3 Old checkpoint, and allow_partial_load is False
# 1.4 Old checkpoint, and allow_partial_load is True
# 2. If we find a missing key, we first convert the keys back to
# the key format of v2.3.
# 3. If the previously missing keys appear among the v2.3 keys, we assume
# this is an old checkpoint.
# 4. Pass the state_dict to `create_default_local_load_plan()`,
# which has the logic to check missing keys for allow_partial_load.
# So for cases 1.2 and 1.4, we delegate the allow_partial_load check to
# `create_default_local_load_plan()`. The logic here only determines
# whether the checkpoint belongs to 2.3 (or before) or 2.4 (or after).
current_keys = set(self.state_dict.keys())
load_keys = set(self.metadata.state_dict_metadata.keys())
missing_keys = load_keys - current_keys
if missing_keys:
_version._derived_version = "2_3"
old_state_dict, old_mappings = flatten_state_dict(
self.original_state_dict
)
old_keys = set(old_state_dict.keys())
if old_keys & missing_keys:
self.state_dict, self.mappings = old_state_dict, old_mappings
# _derived_version is only used by flatten_state_dict now.
# Set it back to None so that later we can save to a new version.
_version._derived_version = None
return create_default_local_load_plan(
self.state_dict, self.metadata, not self.allow_partial_load
)
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return create_default_global_load_plan(global_plan)
def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
return new_plan
def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
if self.flatten_state_dict:
set_element(
self.original_state_dict,
self.mappings[read_item.dest_index.fqn],
torch.load(value, weights_only=False),
)
else:
self.state_dict[read_item.dest_index.fqn] = torch.load(
value, weights_only=False
)
def resolve_tensor(self, read_item: ReadItem):
tensor = self.lookup_tensor(read_item.dest_index)
return self.transform_tensor(read_item, tensor)
def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
pass
def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
"""Extension from the planner interface to make it easy to extend the default planner."""
return find_state_dict_object(self.state_dict, index)
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
"""Extension from the planner interface to make it easy to extend the default planner."""
return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
Useful for loading in state_dict without first initializing a model, such as
when converting a DCP checkpoint into a Torch save file.
N.B. ``state_dict`` must be an empty dictionary when used with this LoadPlanner
.. warning::
Because the entire state_dict is initialized, it's recommended to only
use this LoadPlanner on a single rank or process to avoid OOM.
"""
def __init__(self, keys=None, *args, **kwargs):
self.keys = keys
super().__init__(*args, **kwargs)
def _should_include_key(self, key: str, metadata: Metadata) -> bool:
if self.keys is None:
return True
if key in self.keys:
return True
unflattened_keys: List[str] = []
planner_data = metadata.planner_data.get(key)
for unflattened_key in planner_data:
if unflattened_keys:
unflattened_keys.append(
".".join([unflattened_keys[-1], str(unflattened_key)])
)
else:
unflattened_keys.append(unflattened_key)
if any(unflattened_key in self.keys for unflattened_key in unflattened_keys):
return True
return False
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
assert not state_dict
assert metadata is not None
# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if not self._should_include_key(k, metadata):
continue
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v
super().set_up_planner(state_dict, metadata, is_coordinator)
def create_default_local_load_plan(
state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
"""
Create the ``LoadPlan`` used by DefaultLoadPlanner.
It produces one read item per value in ``state_dict``, using the metadata in ``metadata``.
The default behavior is to match keys exactly between state_dict and metadata.
It handles resharding by issuing multiple read requests against storage in order to match
load requirements.
"""
requests = []
for fqn, obj in state_dict.items():
# if strict=False, ignore state_dict keys which do not exist in the checkpoint metadata
if fqn not in metadata.state_dict_metadata:
if strict:
raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
else:
continue
md = metadata.state_dict_metadata[fqn]
# Since DTensor supports submesh, adding extra check to ensure _create_read_items()
# gets called only when the current rank is part of the mesh for the corresponding DTensor.
if isinstance(obj, DTensor):
if obj.device_mesh.get_coordinate() is not None:
requests += _create_read_items(fqn, md, obj)
else:
requests += _create_read_items(fqn, md, obj)
return LoadPlan(requests)
def create_default_global_load_plan(
all_plans: List[LoadPlan],
) -> List[LoadPlan]:
"""
Create global load plan used by DefaultLoadPlanner.
The default load behavior involves no global coordination, and this function
currently doesn't change the local plans.
"""
return all_plans
def create_default_local_save_plan(
state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
"""
Create the ``SavePlan`` used by DefaultSavePlanner.
On non-coordinator ranks, this function ignores tensors and non-tensor objects,
only producing writes for ShardedTensor objects.
On the coordinator rank, produce writes for all values.
"""
requests = []
for fqn, obj in state_dict.items():
# Since DTensor supports submesh, adding extra check to ensure _create_write_items()
# gets called only when the current rank is part of the mesh for the corresponding DTensor.
if isinstance(obj, DTensor):
if obj.device_mesh.get_coordinate() is not None:
requests += _create_write_items(fqn, obj)
else:
# For plain tensor and non-tensor values, add the request on all
# ranks. The coordinator will decide whether to deduplicate the
# values based on the keys.
requests += _create_write_items(fqn, obj)
return SavePlan(requests)
def create_default_global_save_plan(
all_plans: List[SavePlan],
rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
"""
Create the global plan and metadata used by DefaultSavePlanner.
Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.
The only global planning change is to update index hints in all ``MetadataIndex`` objects if
``rewrite_index_hints`` is True.
"""
md: Dict[str, STORAGE_TYPES] = {}
new_plans = []
for plan in all_plans:
new_items = []
for item in plan.items:
if not item.type == WriteItemType.SHARD:
assert item.index.fqn not in md
if item.type == WriteItemType.BYTE_IO:
md[item.index.fqn] = BytesStorageMetadata()
new_items.append(item)
else:
assert item.tensor_data is not None
tensor_md = cast(
TensorStorageMetadata,
md.setdefault(
item.index.fqn,
TensorStorageMetadata(
properties=item.tensor_data.properties,
size=item.tensor_data.size,
chunks=[],
),
),
)
new_item = item
if rewrite_index_hints:
new_index = dataclasses.replace(
item.index, index=len(tensor_md.chunks)
)
new_item = dataclasses.replace(item, index=new_index)
new_items.append(new_item)
assert (
item.tensor_data.chunk is not None
), f"""
Cannot create MD for tensor without bounds.
FQN: {item.index.fqn}
"""
tensor_md.chunks.append(item.tensor_data.chunk)
new_plans.append(dataclasses.replace(plan, items=new_items))
return (new_plans, Metadata(md))
def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
"""Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``."""
plan = _create_default_metadata_only_plan(state_dict)
_, md = create_default_global_save_plan([plan])
return md
def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
"""Check if two boxes overlap. Tuples are (offset, lengths)."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
# shard, we would check if one shard is above or on the left of the
# other shard.
ndims = len(box0.offsets)
for i in range(ndims):
if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
return False
if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
return False
return True
def _check_box_bounds(
outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
for i in range(len(outer_box_size)):
if inner_box.offsets[i] < 0:
return False
if inner_box.sizes[i] < 0:
return False
if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
return False
return True
def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
all_good = True
for key, value in metadata.state_dict_metadata.items():
if isinstance(value, BytesStorageMetadata):
continue
if len(value.size) == 0:
continue
chunks_volume = 0
for chunk_idx, chunk0 in enumerate(value.chunks):
# Compute the volume
if not _check_box_bounds(value.size, chunk0):
logger.warning(
"""
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
""",
key,
value.size,
chunk0,
)
all_good = False
chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
# Check for overlap
for chunk1 in value.chunks[chunk_idx + 1 :]:
if _check_box_overlap(chunk0, chunk1):
logger.warning(
"key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
)
all_good = False
# Check whether combined chunk cover the whole tensor
tensor_volume = reduce(operator.mul, value.size, 1)
if chunks_volume != tensor_volume:
logger.warning(
"""
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
""",
key,
tensor_volume,
chunks_volume,
)
all_good = False
return all_good
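
A hedged sketch of two planner knobs defined above: tolerating keys that are missing from the checkpoint, and rebuilding a state_dict straight from checkpoint metadata without a model (useful on a single process, e.g. when converting a DCP checkpoint to a regular torch.save file). The paths are illustrative.

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    _EmptyStateDictLoadPlanner,
)

state_dict = {"weight": torch.zeros(4)}

# Tolerate keys that exist locally but are absent from the checkpoint.
dcp.load(
    state_dict,
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    planner=DefaultLoadPlanner(allow_partial_load=True),
)

# Rebuild a full state_dict from checkpoint metadata alone; it must start empty.
rebuilt = {}
dcp.load(
    rebuilt,
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    planner=_EmptyStateDictLoadPlanner(),
)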

View File

@ -0,0 +1,765 @@
# mypy: allow-untyped-defs
import collections
import dataclasses
import io
import operator
import os
import pickle
import queue
import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
IO,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)
import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import (
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
StorageMeta,
)
from torch.distributed.checkpoint.planner import (
LoadItemType,
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
WriteItem,
WriteItemType,
)
from torch.distributed.checkpoint.staging import BlockingAsyncStager
from torch.distributed.checkpoint.storage import (
StorageReader,
StorageWriter,
WriteResult,
)
from torch.distributed.checkpoint.utils import _create_file_view
from torch.futures import Future
__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"]
_metadata_fn: str = ".metadata"
@dataclass
class _StorageInfo:
"""This is the per entry storage info."""
relative_path: str
offset: int
length: int
@dataclass
class _StoragePrefix:
prefix: str
DEFAULT_SUFFIX = ".distcp"
def _generate_uuid() -> str:
return str(uuid.uuid4())
class _TensorLoader(ABC):
@abstractmethod
def add(self, size: int, obj: object) -> None:
pass
@abstractmethod
def start_loading(self) -> None:
pass
@abstractmethod
def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
pass
class _SerialCpuLoader(_TensorLoader):
def __init__(self, resolve_fun: Callable) -> None:
self.resolve_fun = resolve_fun
self.items: List[Tuple[int, object]] = []
def add(self, size: int, obj: object) -> None:
self.items.append((size, obj))
def start_loading(self) -> None:
pass
def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
for _, obj in self.items:
tensor = self.resolve_fun(obj).detach()
tensor = tensor.cpu()
if tensor.storage().size() != tensor.numel():
tensor = tensor.clone()
yield (
tensor,
obj,
)
class _OverlappingCpuLoader(_TensorLoader):
def __init__(
self,
resolve_fun: Callable,
stream: Optional[torch.Stream] = None,
inflight_threshhold: int = 1_000_000,
) -> None:
self.resolve_fun = resolve_fun
self.items: List[Tuple[int, object]] = []
self.inflight_threshhold = inflight_threshhold
self.in_flight_data = 0
self.current_items: collections.deque = collections.deque()
self.idx = 0
self.started = False
self.device_type = (
stream.device_type if stream else _get_available_device_type()
)
self.device_module = _get_device_module(self.device_type)
self.stream = cast(
torch.cuda.Stream, stream or self.device_module.current_stream()
)
if self.stream != self.device_module.current_stream():
self.stream.wait_stream(self.device_module.current_stream())
@property
def _done(self) -> bool:
return self.idx >= len(self.items)
def _drain(self) -> List[Tuple[torch.Tensor, object]]:
drained = []
if self.in_flight_data >= self.inflight_threshhold:
self.stream.synchronize()
while self.in_flight_data >= self.inflight_threshhold:
val = self.current_items.popleft()
self.in_flight_data -= val[0].numel() * val[0].element_size()
drained.append(val)
return drained
def _refill(self) -> None:
with self.device_module.stream(self.stream):
while not self._done and self.in_flight_data < self.inflight_threshhold:
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
if tensor.device.type == self.device_type:
tensor = tensor.to(device="cpu", non_blocking=True)
elif tensor.device == torch.device("cpu"):
if (
tensor.untyped_storage().size()
!= tensor.numel() * tensor.itemsize
):
# this forces the tensor to be both contiguous and with minimal storage
tensor = tensor.clone()
self.current_items.append(
(
tensor,
obj,
)
)
self.in_flight_data += tensor.numel() * tensor.element_size()
def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
assert self._done
if len(self.current_items) > 0:
self.stream.synchronize()
return self.current_items
def add(self, size: int, obj: object) -> None:
if self.started:
raise RuntimeError("cannot add items after loading started")
self.items.append((size, obj))
def start_loading(self) -> None:
if self.started:
return
self.started = True
self.items.sort(key=operator.itemgetter(0))
self._refill()
def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
self.start_loading()
while not self._done:
drained = self._drain()
self._refill()
yield from drained
yield from self._finish()
def _item_size(item: WriteItem) -> int:
size = 1
assert item.tensor_data is not None
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.size:
size *= s
dtype = item.tensor_data.properties.dtype
return size * torch._utils._element_size(dtype)
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
if bins == 1:
return [items]
bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_w.sort(key=_item_size, reverse=True)
for i, wi in enumerate(bytes_w):
buckets[i % bins].append(wi)
for wi in tensor_w:
# TODO: replace with heapq
idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
buckets[idx].append(wi)
bucket_sizes[idx] += _item_size(wi)
return buckets
def _write_item(
stream: io.IOBase,
data: Union[io.BytesIO, torch.Tensor],
write_item: WriteItem,
storage_key: str,
) -> WriteResult:
offset = stream.tell()
if write_item.type == WriteItemType.BYTE_IO:
assert isinstance(data, io.BytesIO)
stream.write(data.getbuffer())
else:
assert isinstance(data, torch.Tensor)
assert data.device == torch.device("cpu")
torch.save(data, cast(IO[bytes], stream))
length = stream.tell() - offset
return WriteResult(
index=write_item.index,
size_in_bytes=length,
storage_data=_StorageInfo(storage_key, offset, length),
)
def _write_files_from_queue(
create_stream: Callable,
file_queue: queue.Queue,
result_queue: queue.Queue,
planner: SavePlanner,
inflight_threshhold: int,
use_fsync: bool,
thread_count: int,
) -> None:
try:
while True:
file_name, storage_key, write_items = file_queue.get_nowait()
loader: _TensorLoader
custom_backend_name = torch._C._get_privateuse1_backend_name()
custom_device_mod = getattr(torch, custom_backend_name, None)
# TODO: Using the OverlappingCpuLoader with multiple threads creates significant
# performance degradation, observed as being related to CUDA stream syncs. We
# should try to fix this and use _OverlappingCpuLoader for all threaded cases
if (
thread_count == 1
and (
torch.cuda.is_available()
or (custom_device_mod and custom_device_mod.is_available())
)
and inflight_threshhold > 0
):
loader = _OverlappingCpuLoader(
planner.resolve_data,
inflight_threshhold=inflight_threshhold,
)
else:
loader = _SerialCpuLoader(
planner.resolve_data,
)
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
for write_item in tensor_w:
loader.add(_item_size(write_item), write_item)
loader.start_loading()
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
write_results = []
with create_stream(file_name, "wb") as stream:
for write_item in bytes_w:
data = planner.resolve_data(write_item)
write_results.append(
_write_item(stream, data, write_item, storage_key)
)
for tensor, write_item in loader.values():
assert tensor.is_cpu
write_results.append(
_write_item(stream, tensor, write_item, storage_key)
)
if use_fsync:
try:
os.fsync(stream.fileno())
except AttributeError:
os.sync()
result_queue.put(write_results)
except queue.Empty:
pass
class FileSystemBase(ABC):
@contextmanager
@abstractmethod
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
...
@abstractmethod
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
...
@abstractmethod
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
...
@abstractmethod
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
...
@abstractmethod
def mkdir(self, path: Union[str, os.PathLike]) -> None:
...
@classmethod
@abstractmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
...
@abstractmethod
def exists(self, path: Union[str, os.PathLike]) -> bool:
...
@abstractmethod
def rm_file(self, path: Union[str, os.PathLike]) -> None:
...
class FileSystem(FileSystemBase):
@contextmanager
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
with cast(Path, path).open(mode) as stream:
yield cast(io.IOBase, stream)
def concat_path(
self, path: Union[str, os.PathLike], suffix: str
) -> Union[str, os.PathLike]:
return cast(Path, path) / suffix
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
if not isinstance(path, Path):
path = Path(path)
return path
def rename(
self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
) -> None:
cast(Path, path).rename(cast(Path, new_path))
def mkdir(self, path: Union[str, os.PathLike]) -> None:
cast(Path, path).mkdir(parents=True, exist_ok=True)
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
if isinstance(checkpoint_id, Path):
return True
if "://" in str(checkpoint_id):
return False
for p in Path(checkpoint_id).parents:
if p.exists() and os.access(str(p), os.W_OK):
return True
return False
def exists(self, path: Union[str, os.PathLike]) -> bool:
return cast(Path, path).exists()
def rm_file(self, path: Union[str, os.PathLike]) -> None:
cast(Path, path).unlink()
class _FileSystemWriter(StorageWriter):
"""
Basic implementation of StorageWriter using file IO.
This implementation makes the following assumptions and simplifications:
* The checkpoint path is an empty or non-existing directory.
* File creation is atomic
The checkpoint consists of one file per write request plus
a `.metadata` file with the serialized metadata.
"""
def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
overwrite: bool = True,
*args: Any,
**kwargs: Any,
) -> None:
"""
Initialize the writer pointing to `path`.
Args:
path: directory where the checkpoint will be written to.
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
sync_files: Force files to be synced to permanent storage. Defaults to True.
thread_count: Number of IO threads to use to write. Defaults to 1.
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10 MB.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
N.B. If sync_files is disabled, there is no guarantee that the checkpoint will be consistent in the case of a failure.
"""
super().__init__()
self.fs = FileSystem()
self.path = self.fs.init_path(path)
self.single_file_per_rank = single_file_per_rank
self.sync_files = sync_files
self.thread_count = thread_count
self.per_thread_copy_ahead = per_thread_copy_ahead
self.save_id = _generate_uuid()
self.overwrite = overwrite
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
if checkpoint_id:
self.path = self.fs.init_path(checkpoint_id)
self.save_id = _generate_uuid()
def set_up_storage_writer(self, is_coordinator: bool) -> None:
pass
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self.fs.mkdir(self.path)
if self.fs.exists(self.metadata_path):
if self.overwrite:
warnings.warn(
f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}."
" Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
" maintain this functionality or False to raise when an existing checkpoint is found."
)
else:
raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")
return plan
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(plans)
]
return new_plans
def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[List[WriteResult]]:
storage_plan: _StoragePrefix = plan.storage_data
file_count = 0
def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
return file_name
file_queue: queue.Queue = queue.Queue()
if self.single_file_per_rank:
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
file_name = gen_file()
path = self.fs.concat_path(self.path, file_name)
file_queue.put((path, file_name, bucket))
else:
for item in plan.items:
file_name = gen_file()
path = self.fs.concat_path(self.path, file_name)
file_queue.put((path, file_name, [item]))
result_queue: queue.Queue = queue.Queue()
threads = []
for _ in range(1, self.thread_count):
t = threading.Thread(
target=_write_files_from_queue,
args=(
self.fs.create_stream,
file_queue,
result_queue,
planner,
self.per_thread_copy_ahead,
self.sync_files,
self.thread_count,
),
)
t.start()
threads.append(t)
_write_files_from_queue(
create_stream=self.fs.create_stream,
file_queue=file_queue,
result_queue=result_queue,
planner=planner,
inflight_threshhold=self.per_thread_copy_ahead,
use_fsync=self.sync_files,
thread_count=self.thread_count,
)
for t in threads:
t.join()
res = []
try:
while True:
res += result_queue.get_nowait()
except queue.Empty:
fut: Future[List[WriteResult]] = Future()
fut.set_result(res)
return fut
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = {}
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
metadata.storage_data = storage_md
metadata.storage_meta = self.storage_meta()
tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
if self.sync_files:
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()
# delete in case other checkpoints were present.
if self.fs.exists(self.metadata_path):
self.fs.rm_file(self.metadata_path)
self.fs.rename(tmp_path, self.metadata_path)
def storage_meta(self) -> Optional[StorageMeta]:
return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)
@property
def metadata_path(self) -> Union[str, os.PathLike]:
return cast(Path, self.fs.concat_path(self.path, _metadata_fn))
@property
def checkpoint_id(self) -> Union[str, os.PathLike]:
"""
return the checkpoint_id that will be used to save the checkpoint.
"""
return self.path
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return FileSystem.validate_checkpoint_id(checkpoint_id)
class FileSystemReader(StorageReader):
def __init__(self, path: Union[str, os.PathLike]) -> None:
super().__init__()
self.fs = FileSystem()
self.path = self.fs.init_path(path)
self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
self.load_id = _generate_uuid()
def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
return _create_file_view(file, sinfo.offset, sinfo.length)
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
self.storage_data = {}
if checkpoint_id:
self.path = self.fs.init_path(checkpoint_id)
self.load_id = _generate_uuid()
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
# group requests by file
per_file: Dict[str, List[ReadItem]] = {}
for read_item in plan.items:
item_md = self.storage_data[read_item.storage_index]
path = item_md.relative_path
per_file.setdefault(path, []).append(read_item)
for relative_path, reqs in per_file.items():
new_path = self.fs.concat_path(self.path, relative_path)
with self.fs.create_stream(new_path, "rb") as stream:
# TODO sort by offset and cache the reading
for req in reqs:
item_md = self.storage_data[req.storage_index]
file_slice = self._slice_file(stream, item_md)
if req.type == LoadItemType.BYTE_IO:
read_bytes = io.BytesIO(file_slice.read(item_md.length))
read_bytes.seek(0)
planner.load_bytes(req, read_bytes)
else:
tensor = cast(
Tensor,
torch.load(
cast(IO[bytes], file_slice),
map_location="cpu",
weights_only=True,
),
)
tensor = narrow_tensor_by_index(
tensor, req.storage_offsets, req.lengths
)
target_tensor = planner.resolve_tensor(req).detach()
assert (
target_tensor.size() == tensor.size()
), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
fut: Future = Future()
fut.set_result(None)
return fut
# Implementing the abstract function in StorageReader
def read_metadata(self) -> Metadata:
path = self.fs.concat_path(self.path, ".metadata")
with self.fs.create_stream(path, "rb") as metadata_file:
metadata = pickle.load(metadata_file)
if getattr(metadata, "storage_meta", None) is None:
metadata.storage_meta = StorageMeta()
metadata.storage_meta.load_id = self.load_id
return metadata
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
self.storage_data = metadata.storage_data
assert self.storage_data is not None
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
return plans
@property
def checkpoint_id(self) -> Union[str, os.PathLike]:
"""
Return the checkpoint_id that will be used to load the checkpoint.
"""
return self.path
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return FileSystem.validate_checkpoint_id(checkpoint_id)
class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
"""
Basic implementation of StorageWriter using file IO.
This implementation makes the following assumptions and simplifications:
* The checkpoint path is an empty or non-existing directory.
* File creation is atomic
The checkpoint consists of one file per write request plus
a `.metadata` file with the serialized metadata.
"""
def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
cache_staged_state_dict: bool = False,
overwrite: bool = True,
) -> None:
"""
Initialize the writer pointing to `path`.
Args:
path: directory where the checkpoint will be written to.
single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
sync_files: force files to be synced to permanent storage. Defaults to True.
thread_count: Number of IO threads to use to write. Defaults to 1.
per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
at the cost of increased memory usage. Additionally, if this parameter is set to True, it's the expectation
that the stager is maintained and re-used for multiple dcp.async_save calls. Defaults to False.
overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
"""
super().__init__(
path=path,
single_file_per_rank=single_file_per_rank,
sync_files=sync_files,
thread_count=thread_count,
per_thread_copy_ahead=per_thread_copy_ahead,
cache_staged_state_dict=cache_staged_state_dict,
overwrite=overwrite,
)
def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""Override of AsyncStager.stage"""
# in the async case, the state dict is already on CPU, so maintaining this
# buffer makes no sense
self.per_thread_copy_ahead = 0
return super().stage(state_dict)
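# Illustrative usage sketch (not part of this module): saving and re-loading a
# state_dict through the file-system storage layer. Assumes a single process;
# when torch.distributed is not initialized, save/load fall back to a
# non-distributed path. The checkpoint directory below is hypothetical.
def _example_filesystem_roundtrip() -> None:
    import torch
    import torch.distributed.checkpoint as dcp

    state_dict = {"weight": torch.randn(4, 4), "step": torch.tensor(100)}
    dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("/tmp/ckpt"))

    # Loading is performed in place into pre-allocated tensors.
    target = {"weight": torch.empty(4, 4), "step": torch.tensor(0)}
    dcp.load(target, storage_reader=dcp.FileSystemReader("/tmp/ckpt"))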

View File

@ -0,0 +1,280 @@
# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
_EmptyStateDictLoadPlanner,
DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
Metadata,
STATE_DICT_TYPE,
STORAGE_TYPES,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future
__all__ = [
"dcp_to_torch_save",
"torch_save_to_dcp",
"BroadcastingTorchSaveReader",
"DynamicMetaLoadPlanner",
]
class BroadcastingTorchSaveReader(StorageReader):
"""
StorageReader for reading a Torch Save file. This reader will read the entire checkpoint
on the coordinator rank, and then broadcast and shard each tensor to all ranks.
. N.B. Intended to be used with DynamicMetaLoadPlanner
.. warning::
Current implementation only supports loading Tensors.
>>> # xdoctest: +SKIP("undefined vars")
>>> sd = {"mode": model}
>>> dcp.load(
>>> sd,
>>> storage_reader=BroadcastingTorchSaveReader(),
>>> planner=DynamicMetaLoadPlanner(),
>>> checkpoint_id="path_to_model.pt"
>>> )
"""
def __init__(
self,
checkpoint_id: Optional[Union[str, os.PathLike]] = None,
coordinator_rank: int = 0,
) -> None:
self.checkpoint_id = checkpoint_id
self.coordinator_rank = coordinator_rank
def read_metadata(self) -> Metadata:
"""Extends the default StorageReader to support building the metadata file"""
# Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
# the disk
return Metadata(state_dict_metadata={})
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
"""
Reads torch save data on the coordinator rank and broadcasts it afterwards.
This incurs a communication cost but avoids having to load the entire
checkpoint on each rank, hopefully preventing OOM issues.
"""
planner = cast(DefaultLoadPlanner, planner)
# data is read in on the coordinator rank, and broadcast afterwards
# this incurs a communication cost, but it avoids having to load
# the entire checkpoint on each rank, hopefully preventing OOM issues
# TODO: read on each host, instead of only the coordinator
if self.is_coordinator:
assert self.checkpoint_id is not None
torch_state_dict = torch.load(
self.checkpoint_id, map_location="cpu", weights_only=False
)
if planner.flatten_state_dict:
torch_state_dict, _ = flatten_state_dict(torch_state_dict)
else:
torch_state_dict = None
for req in plan.items:
if req.type == LoadItemType.BYTE_IO:
raise RuntimeError(
f"Non-tensor value identified at {req.storage_index.fqn}. "
f"At this time {type(self).__name__} only supports loading Tensors."
)
# Broadcast the tensor from the coordinator rank
if self.is_coordinator:
pg_device = dist.distributed_c10d._get_pg_default_device()
tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
else:
tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)
tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes, "
f"{target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
fut: Future = Future()
fut.set_result(None)
return fut
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
"""Implementation of the StorageReader method"""
self.is_coordinator = is_coordinator
if self.is_coordinator:
assert dist.get_rank() == self.coordinator_rank
assert self.checkpoint_id is not None
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
"""Implementation of the StorageReader method"""
return plan
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
"""Implementation of the StorageReader method"""
return global_plan
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
"""Implementation of the StorageReader method"""
self.checkpoint_id = checkpoint_id
@classmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
"""Implementation of the StorageReader method"""
return os.path.isfile(checkpoint_id)
class DynamicMetaLoadPlanner(DefaultLoadPlanner):
"""
Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict,
avoiding the need to read metadata from disk. This is useful when reading formats which don't have a
metadata file, like Torch Save files.
. N.B. Intended to be used with BroadcastingTorchSaveReader
.. warning::
Current implementation only supports loading Tensors.
>>> # xdoctest: +SKIP("undefined vars")
>>> sd = {"mode": model}
>>> dcp.load(
>>> sd,
>>> storage_reader=BroadcastingTorchSaveReader(),
>>> planner=DynamicMetaLoadPlanner(),
>>> checkpoint_id="path_to_model.pt"
>>> )
"""
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
"""Setups of the planner, extnding default behavior by creating the Metadata object from the state dict"""
super().set_up_planner(state_dict, metadata, is_coordinator)
state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
for key, tensor in self.state_dict.items():
if not torch.is_tensor(tensor):
raise RuntimeError(
f"Non-tensor value identified at {key}. "
f"At this time {type(self).__name__} only supports loading Tensors."
)
state_dict_metadata[key] = TensorStorageMetadata(
TensorProperties(dtype=tensor.dtype),
tensor.size(),
_create_chunk_list(tensor),
)
self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
def dcp_to_torch_save(
dcp_checkpoint_dir: Union[str, os.PathLike],
torch_save_path: Union[str, os.PathLike],
):
"""
Given a directory containing a DCP checkpoint, this function will convert it into a
Torch save file.
Args:
dcp_checkpoint_dir: Directory containing the DCP checkpoint.
torch_save_path: Filename to store the converted Torch save file.
.. warning::
To avoid OOM, it's recommended to only run this function on a single rank.
"""
sd: STATE_DICT_TYPE = {}
_load_state_dict(
sd,
storage_reader=FileSystemReader(dcp_checkpoint_dir),
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
torch.save(sd, torch_save_path)
def torch_save_to_dcp(
torch_save_path: Union[str, os.PathLike],
dcp_checkpoint_dir: Union[str, os.PathLike],
):
"""
Given the location of a torch save file, converts it into a DCP checkpoint.
Args:
torch_save_path: Filename of the Torch save file.
dcp_checkpoint_dir: Directory to store the DCP checkpoint.
.. warning::
To avoid OOM, it's recommended to only run this function on a single rank.
"""
state_dict = torch.load(torch_save_path, weights_only=False)
# we don't need stateful behavior here because the expectation is anything loaded by
# torch.load would not contain stateful objects.
_save_state_dict(
state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
)
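# Illustrative sketch (not part of this module): round-tripping a checkpoint
# between the two formats from a single process. Paths are hypothetical; the
# same conversions are exposed on the command line via the __main__ block below.
def _example_format_roundtrip() -> None:
    torch_save_to_dcp("model.pt", "dcp_checkpoint/")
    dcp_to_torch_save("dcp_checkpoint/", "model_roundtrip.pt")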
if __name__ == "__main__":
class FormatMode(Enum):
TORCH_TO_DCP = "torch_to_dcp"
DCP_TO_TORCH = "dcp_to_torch"
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"mode",
type=str,
help="Conversion mode",
choices=[m.value for m in FormatMode],
default=FormatMode.TORCH_TO_DCP,
)
parser.add_argument("src", type=str, help="Path to the source model")
parser.add_argument("dst", type=str, help="Path to the destination model")
args = parser.parse_args()
print(
f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
)
checkpoint_missing_warning = (
f"No checkpoint found at {args.src}. Skipping conversion."
)
if args.mode == FormatMode.TORCH_TO_DCP.value:
if os.path.isfile(args.src):
torch_save_to_dcp(args.src, args.dst)
else:
print(checkpoint_missing_warning)
elif args.mode == FormatMode.DCP_TO_TORCH.value:
if os.path.isdir(args.src):
dcp_to_torch_save(args.src, args.dst)
else:
print(checkpoint_missing_warning)
else:
raise ValueError(f"Unknown conversion mode: {args.mode}")

View File

@ -0,0 +1,103 @@
# mypy: allow-untyped-defs
import functools
import time
from typing import Any, Callable, Dict, List, TypeVar
from typing_extensions import ParamSpec
from uuid import uuid4
import torch.distributed.c10d_logger as c10d_logger
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
__all__: List[str] = []
global _dcp_logger
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
_T = TypeVar("_T")
_P = ParamSpec("_P")
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
"""
Extracts log data from dcp method args
"""
msg_dict = {}
# checkpoint ID can be passed in through the serializer or through the checkpoint id directly
storage_writer = kwargs.get("storage_writer", None)
storage_reader = kwargs.get("storage_reader", None)
planner = kwargs.get("planner", None)
checkpoint_id = kwargs.get("checkpoint_id", None)
if not checkpoint_id and (serializer := storage_writer or storage_reader):
checkpoint_id = getattr(serializer, "checkpoint_id", None)
msg_dict["checkpoint_id"] = (
str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
)
# Uniquely identify a _dcp_method_logger wrapped function call.
msg_dict["uuid"] = str(uuid4().int)
if storage_writer:
msg_dict["storage_writer"] = storage_writer.__class__.__name__
if storage_reader:
msg_dict["storage_reader"] = storage_reader.__class__.__name__
if planner:
msg_dict["planner"] = planner.__class__.__name__
return msg_dict
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
msg_dict.update(c10d_logger._get_msg_dict(func_name, **msg_dict))
return msg_dict
def _dcp_method_logger(
log_exceptions: bool = False, **wrapper_kwargs: Any
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore
"""This method decorator logs the start, end, and exception of wrapped events."""
def decorator(func: Callable[_P, _T]):
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
msg_dict = _get_msg_dict(
func.__name__, *args, **{**wrapper_kwargs, **kwargs}
)
# log start event
msg_dict["event"] = "start"
t0 = time.time_ns()
msg_dict["time"] = t0
msg_dict["log_exceptions"] = log_exceptions
_dcp_logger.debug(msg_dict)
# exceptions
try:
result = func(*args, **kwargs)
except BaseException as error:
if log_exceptions:
msg_dict["event"] = "exception"
msg_dict["error"] = f"{error}"
msg_dict["time"] = time.time_ns()
_dcp_logger.error(msg_dict)
raise
# end event
msg_dict["event"] = "end"
t1 = time.time_ns()
msg_dict["time"] = time.time_ns()
msg_dict["times_spent"] = t1 - t0
_dcp_logger.debug(msg_dict)
return result
return wrapper
return decorator
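# Illustrative sketch (not part of this module): decorating a hypothetical
# helper so that start/end (and, when enabled, exception) events are emitted to
# the DCP logger, with checkpoint_id picked up from the keyword arguments.
@_dcp_method_logger(log_exceptions=True)
def _example_checkpoint_step(checkpoint_id: str = "ckpt-0") -> str:
    return checkpoint_id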

View File

@ -0,0 +1,15 @@
import logging
from typing import List
from torch.distributed.logging_handlers import _log_handlers
__all__: List[str] = []
DCP_LOGGER_NAME = "dcp_logger"
_log_handlers.update(
{
DCP_LOGGER_NAME: logging.NullHandler(),
}
)

View File

@ -0,0 +1,182 @@
# mypy: allow-untyped-defs
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union
import torch
from torch.distributed.checkpoint.stateful import StatefulT
__all__ = [
"ChunkStorageMetadata",
"TensorStorageMetadata",
"BytesStorageMetadata",
"Metadata",
"MetadataIndex",
"TensorProperties",
"StorageMeta",
]
@dataclass
class ChunkStorageMetadata:
"""
Each chunk is expected to have the same properties as the TensorStorageMetadata
that includes it.
"""
offsets: torch.Size
sizes: torch.Size
class _MEM_FORMAT_ENCODING(Enum):
"""Describe the memory format of a tensor."""
TORCH_CONTIGUOUS_FORMAT = 0
TORCH_CHANNELS_LAST = 1
TORCH_PRESERVE_FORMAT = 2
@dataclass
class TensorProperties:
"""Properties used to create :class:`Tensor`"""
# Regular tensor fields
dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
# This field is deprecated.
layout: torch.layout = field(default=torch.strided)
# This field is deprecated.
requires_grad: bool = False
# This field is deprecated.
memory_format: torch.memory_format = field(default=torch.contiguous_format)
# This field is deprecated.
pin_memory: bool = False
def __getstate__(self):
# torch.memory_format cannot be pickled, so encode it as an enum value.
memory_format = self.memory_format
if memory_format == torch.contiguous_format:
mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
elif memory_format == torch.channels_last:
mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
elif memory_format == torch.preserve_format:
mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
else:
raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")
return (
self.dtype,
self.layout,
self.requires_grad,
mem_format_encoding,
self.pin_memory,
)
def __setstate__(
self,
state,
):
(
self.dtype,
self.layout,
self.requires_grad,
mem_format_encoding,
self.pin_memory,
) = state
if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
memory_format = torch.contiguous_format
elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
memory_format = torch.channels_last
elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
memory_format = torch.preserve_format
else:
raise RuntimeError(
f"Invalid torch.memory_format encoding: {mem_format_encoding}"
)
self.memory_format = memory_format
@staticmethod
def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
return TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
)
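# Illustrative sketch (not part of this module): TensorProperties stays
# picklable even though torch.memory_format itself cannot be pickled, because
# __getstate__/__setstate__ encode it through _MEM_FORMAT_ENCODING.
def _example_tensor_properties_roundtrip() -> None:
    import pickle

    props = TensorProperties.create_from_tensor(torch.randn(2, 3))
    restored = pickle.loads(pickle.dumps(props))
    assert restored.dtype == props.dtype
    assert restored.memory_format == torch.contiguous_format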
@dataclass
class TensorStorageMetadata:
properties: TensorProperties
size: torch.Size
chunks: List[ChunkStorageMetadata]
@dataclass
class BytesStorageMetadata:
pass
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
@dataclass
class StorageMeta:
checkpoint_id: Union[str, os.PathLike, None] = None
save_id: Optional[str] = None
load_id: Optional[str] = None
@dataclass
class Metadata:
"""This class represents the metadata of the checkpoint."""
# Keys are the same from the `state_dict` used.
state_dict_metadata: Dict[str, STORAGE_TYPES]
# It is the responsibility of the planner and storage plugins to ensure
# backward compatibility of the planner_data and storage_data. DCP will
# also ensure the backward compatibility of the metadata in this file and
# the metadata of the built-in planner and storage plugins.
planner_data: Any = None
storage_data: Any = None
storage_meta: Optional[StorageMeta] = None
@dataclass(frozen=True)
class MetadataIndex:
"""This class represents a lookup key for items in a state dict or Metadata."""
fqn: str
"""Fully Qualified Name of the object"""
offset: Optional[torch.Size] = None
"""If the object is a tensor, offset into the tensor we're looking for"""
index: Optional[int] = field(hash=False, compare=False, default=None)
"""
Index hint used when searching for a tensor chunk to speed up lookups (optional).
A common representation of a sharded tensor is a list of chunks, so finding
an entry in such a list requires a linear search.
When constructing an instance of MetadataIndex that points to that list,
one can provide the index as a hint; it will be probed before the linear
search, making lookups significantly faster.
"""
def __init__(
self,
fqn: str,
offset: Optional[Sequence[int]] = None,
index: Optional[int] = None,
):
# We must use object.__setattr__ due to frozen=True
object.__setattr__(self, "fqn", fqn)
object.__setattr__(self, "index", index)
if offset is not None:
object.__setattr__(self, "offset", torch.Size(offset))

View File

@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import (
TensorProperties as ShardTensorProperties,
)
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import (
_create_read_items,
create_read_items_for_chunk_list,
)
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
_element_wise_add,
_element_wise_sub,
_normalize_device_info,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DTensor
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
# TODO: Update docstrings for optimizer.py
__all__ = [
"load_sharded_optimizer_state_dict",
]
def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
if device_type == "cpu":
return "cpu"
device_module = _get_device_module(device_type)
if device_module.is_available():
return _normalize_device_info(
device_type, global_rank % device_module.device_count()
)
return "cpu"
def _create_colwise_spec(
pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
if pg is None:
placements = [
f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
for idx in range(dist.get_world_size())
]
else:
placements = [
f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
for idx in range(pg.size())
]
return ChunkShardingSpec(
dim=0,
placements=cast(List[Union[_remote_device, str]], placements),
)
def _is_nested_tensor(val: torch.Tensor) -> bool:
if type(val) is ShardedTensor:
if len(val.local_shards()) == 0:
return False
if type(val.local_shards()[0].tensor) is ShardedTensor:
return True
if type(val.local_shards()[0].tensor) is DTensor:
raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
elif type(val) is DTensor and (
type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
):
raise ValueError("Cannot handle nested DTensor")
return False
def _alloc_tensor(
props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
) -> torch.Tensor:
if device_type == "cpu":
device = cast(torch.device, _get_device_module(device_type).current_device())
else:
device = torch.device(
device_type, _get_device_module(device_type).current_device()
)
return torch.empty(
size=size,
dtype=props.dtype,
layout=props.layout,
requires_grad=props.requires_grad,
pin_memory=props.pin_memory,
device=device,
)
def _get_state_dict_2d_layout(
state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
"""
Load the right TP slice of the optimizer state.
This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata.
We take advantage of the model state_dict producing a sliced ST to figure out what we need to load.
This is pretty fragile and it might be easier for FSDP to compute this info for us.
Returns a dictionary where the keys are the same as those of the state_dict and the value is a tuple of
(offset, size) for the current rank TP slice.
N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
"""
specs: STATE_DICT_2D_LAYOUT = {}
dp_pg: Optional[dist.ProcessGroup] = None
for key, value in state_dict.items():
specs[key] = (None, value.size())
if _is_nested_tensor(value):
assert (
len(value.local_shards()) == 1
), "Cannot handle ST with multiple shards"
assert isinstance(
value, ShardedTensor
), "Can only handle nested ShardedTensor"
shard = value.local_shards()[0]
specs[key] = (
shard.metadata.shard_offsets,
shard.metadata.shard_sizes,
)
dp_pg = shard.tensor._process_group # type: ignore[attr-defined]
return (
specs,
dp_pg,
)
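# Illustrative sketch (not part of this module): for a state_dict that contains
# no nested ShardedTensor, every key maps to (None, full_size) and no
# data-parallel process group is discovered.
def _example_2d_layout() -> None:
    specs, dp_pg = _get_state_dict_2d_layout({"w": torch.randn(4, 2)})
    assert specs == {"w": (None, torch.Size([4, 2]))}
    assert dp_pg is None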
class _ReaderWithOffset(DefaultLoadPlanner):
translation: Dict[MetadataIndex, MetadataIndex]
state_dict: STATE_DICT_TYPE
metadata: Metadata
def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
super().__init__()
self.fqn_to_offset = fqn_to_offset
self.metadata = Metadata({})
self.state_dict = {}
self.translation = {}
def create_local_plan(self) -> LoadPlan:
requests = []
self.translation = {}
for fqn, obj in self.state_dict.items():
md = self.metadata.state_dict_metadata[fqn]
if not isinstance(obj, ShardedTensor):
requests += _create_read_items(fqn, md, obj)
continue
if fqn not in self.fqn_to_offset:
requests += _create_read_items(fqn, md, obj)
continue
offset = self.fqn_to_offset[fqn]
assert len(obj.local_shards()) == 1
original_shard = obj.local_shards()[0]
local_chunks = [
ChunkStorageMetadata(
offsets=torch.Size(
_element_wise_add(original_shard.metadata.shard_offsets, offset)
),
sizes=torch.Size(original_shard.metadata.shard_sizes),
)
]
reqs = create_read_items_for_chunk_list(
fqn, cast(TensorStorageMetadata, md), local_chunks
)
# TODO: The ReadItems will have a displaced MetadataIndex, fix it.
# TODO: we should change _create_sharded_read_items to have more ergonomic API
for ri in reqs:
assert ri.dest_index.offset is not None
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
original_index = dataclasses.replace(
ri.dest_index, offset=torch.Size(original_offset)
)
self.translation[ri.dest_index] = original_index
requests += reqs
return LoadPlan(requests)
def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
return super().lookup_tensor(self.translation.get(index, index))
def load_sharded_optimizer_state_dict(
model_state_dict: STATE_DICT_TYPE,
optimizer_key: str,
storage_reader: StorageReader,
planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
"""
Load a state_dict in conjunction with FSDP sharded optimizer state.
This is the current recommended way to checkpoint FSDP.
>>> # xdoctest: +SKIP
>>> import torch.distributed.checkpoint as dist_cp
>>> # Save
>>> model: torch.nn.Module
>>> optim_params = model.parameters()
>>> optim = torch.optim.SGD(optim_params, lr=0.01)
>>> # Save
>>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
>>> state_dict = {
>>> "optimizer": FSDP.optim_state_dict(model, optim),
>>> "model": model.state_dict()
>>> }
>>> dist_cp.save_state_dict(
>>> state_dict=state_dict,
>>> storage_writer=dist_cp.FileSystemWriter("checkpoint"),
>>> planner=dist_cp.DefaultSavePlanner(),
>>> )
>>>
>>> # Load
>>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
>>> model_state_dict = model_tp.state_dict()
>>> checkpoint = {
>>> "model": model_state_dict
>>> }
>>> dist_cp.load_state_dict(
>>> state_dict=checkpoint,
>>> storage_reader=dist_cp.FileSystemReader("checkpoint"),
>>> planner=dist_cp.DefaultLoadPlanner(),
>>> )
>>> model.load_state_dict(checkpoint["model"])
>>>
>>> optim_state = dist_cp.load_sharded_optimizer_state_dict(
>>> model_state_dict,
>>> optimizer_key="optimizer",
>>> storage_reader=dist_cp.FileSystemReader("checkpoint"),
>>> )
>>>
>>> flattened_osd = FSDP.optim_state_dict_to_load(
>>> model, optim, optim_state["optimizer"]
>>> )
>>>
>>> optim.load_state_dict(flattened_osd)
"""
metadata = storage_reader.read_metadata()
layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
device_module = _get_device_module(dp_pg_device_type)
if dp_pg is None:
placements = []
for i in range(dist.get_world_size()):
device_info = _normalize_device_info(
dp_pg_device_type, i % device_module.device_count()
)
placements.append(f"rank:{i}/{device_info}")
sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type]
else:
sharding_spec = _create_colwise_spec(dp_pg)
# Create a state_dict for optimizer state
state_dict: STATE_DICT_TYPE = {}
fqn_to_offset: Dict[str, Sequence[int]] = {}
for key, value in metadata.state_dict_metadata.items():
key_path = metadata.planner_data[key]
if key_path[0] != optimizer_key:
continue
if isinstance(value, BytesStorageMetadata):
state_dict[key] = "<bytes_io>"
continue
# value: TensorStorageMetadata
if value.size.numel() == 1:
state_dict[key] = _alloc_tensor(
value.properties, value.size, dp_pg_device_type
)
elif dp_pg is None:
state_dict[key] = _create_chunk_sharded_tensor(
_alloc_tensor(value.properties, value.size, dp_pg_device_type),
rank=dist.get_rank(),
world_size=dist.get_world_size(),
num_devices_per_node=device_module.device_count(),
pg=_get_default_group(),
)
else:
spec_key = key_path[2]
alloc_size = layout_specs.get(spec_key, (None, value.size))[1]
properties = ShardTensorProperties(
dtype=value.properties.dtype,
layout=value.properties.layout,
requires_grad=value.properties.requires_grad,
memory_format=value.properties.memory_format,
pin_memory=value.properties.pin_memory,
)
st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
local_shards = []
current_rank = dist.get_rank(dp_pg)
for shard_md in st_md.shards_metadata:
if cast(_remote_device, shard_md.placement).rank() != current_rank:
continue
local_shards.append(
Shard(
tensor=_alloc_tensor(
value.properties, shard_md.shard_sizes, dp_pg_device_type
),
metadata=shard_md,
)
)
st = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, st_md, process_group=dp_pg
)
if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
state_dict[key] = st
# Whether we unflatten before or after doesn't matter
load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
# FIXME the type of planner is wrong in load_state_dict
planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
)
state_dict = unflatten_state_dict(state_dict, metadata.planner_data)
return state_dict

View File

@ -0,0 +1,417 @@
import abc
import io
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union
import torch
from torch.distributed.checkpoint.metadata import (
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
StorageMeta,
TensorProperties,
)
__all__ = [
"WriteItemType",
"LoadItemType",
"TensorWriteData",
"WriteItem",
"ReadItem",
"SavePlan",
"LoadPlan",
"SavePlanner",
"LoadPlanner",
]
class WriteItemType(Enum):
TENSOR = auto()
SHARD = auto()
BYTE_IO = auto()
class LoadItemType(Enum):
TENSOR = auto()
BYTE_IO = auto()
@dataclass(frozen=True)
class TensorWriteData:
chunk: ChunkStorageMetadata
properties: TensorProperties
size: torch.Size
@dataclass(frozen=True)
class WriteItem:
"""Dataclass which holds information about what needs to be written to storage."""
index: MetadataIndex
type: WriteItemType
# Value present if it's a tensor write
tensor_data: Optional[TensorWriteData] = None
def tensor_storage_size(self) -> Optional[int]:
"""
Calculates the storage size of the underlying tensor, or None if this is not a tensor write.
Returns:
Optional[int] storage size, in bytes of underlying tensor if any.
"""
if self.tensor_data is None:
return None
numels = reduce(operator.mul, self.tensor_data.size, 1)
dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
return numels * dtype_size
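# Illustrative sketch (not part of this module): for a 1024x1024 float32 write,
# tensor_storage_size() is 1024 * 1024 * 4 = 4_194_304 bytes.
def _example_tensor_storage_size() -> None:
    wi = WriteItem(
        index=MetadataIndex("w"),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=torch.Size([0, 0]), sizes=torch.Size([1024, 1024])
            ),
            properties=TensorProperties(dtype=torch.float32),
            size=torch.Size([1024, 1024]),
        ),
    )
    assert wi.tensor_storage_size() == 4_194_304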
@dataclass(frozen=True)
class ReadItem:
# Read Item
type: LoadItemType
# Index into the state_dict
dest_index: MetadataIndex
# Offsets into destination tensor
dest_offsets: torch.Size
# Index into the checkpoint
storage_index: MetadataIndex
# Offset into the checkpoint data
storage_offsets: torch.Size
# Size of the hypercube to copy
lengths: torch.Size
@dataclass(frozen=True)
class SavePlan:
items: List[WriteItem]
storage_data: Any = None
planner_data: Any = None
@dataclass
class LoadPlan:
items: List[ReadItem]
storage_data: Any = None
planner_data: Any = None
class SavePlanner(abc.ABC):
"""
Abstract class defining the protocol used by save_state_dict to plan the save process.
SavePlanners are stateful objects that can be used to customize the whole save process.
SavePlanner acts as an access proxy to the state_dict, so any transformation done to it
will be visible to the whole process.
A planner subclass can expect the following sequence of calls during save_state_dict:
1) set_up_planner - called on all ranks.
Signals the start of a checkpoint save.
2) create_local_plan - called on all ranks.
Process the state_dict and produces a `SavePlan` that will be sent for global planning.
3) create_global_plan - called on the coordinator rank only.
Takes the SavePlan from all ranks and makes any global decisions.
4) finish_plan - called on all ranks.
This gives each rank a chance to adjust to global planning decisions.
5) resolve_data - called multiple times on each rank
Looks up a value in the ``state_dict`` for the storage layer to write.
Users are recommended to extend DefaultSavePlanner instead of this interface directly as
most changes can be expressed by changes in a single method.
There are 3 usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the save process as it
doesn't require understanding the intricacies of how SavePlan works:
>>> # xdoctest: +SKIP("undefined vars")
>>> class RenamePlanner(DefaultSavePlanner):
>>> def set_up_planner(
>>> self,
>>> state_dict: STATE_DICT_TYPE,
>>> storage_meta: Optional[StorageMeta],
>>> is_coordinator: bool,
>>> ) -> None:
>>> # prefix all keys with `foo_``
>>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying the local plan and lookup in tandem. This is useful when fine-grained control over how data is persisted is needed:
>>> # xdoctest: +SKIP("undefined vars")
>>> class FP16Planner(DefaultSavePlanner):
>>> def create_local_plan(self):
>>> plan = super().create_local_plan()
>>> for p in plan:
>>> if p.tensor_data is not None:
>>> p.tensor_data.properties.dtype = torch.float16
>>> return plan
>>>
>>> def resolve_data(self, write_item):
>>> item = super().resolve_data(write_item)
>>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
Using the global planning step to make central decisions that can't be made individually by each rank
>>> # xdoctest: +SKIP("undefined vars")
>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>> def create_global_plan(self, all_plans):
>>> iters = [iter(all_plans[0].items)] * len(all_plans)
>>> items_per_rank = [
>>> [item for item in items if item is not None]
>>> for items in zip(*zip_longest(*iters), strict=True)
>>> ]
>>> all_plans = [
>>> replace(plan, items=items)
>>> for plan, items in zip(all_plans, items_per_rank, strict=True)
>>> ]
>>> return super().create_global_plan(all_plans)
Finally, some planners need to save additional metadata in the checkpoint, this is
accomplished by having each rank contribute their data items in the local plan and
the global planner aggregate them:
>>> # xdoctest: +SKIP("undefined vars")
>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>> def create_local_plan(self) -> SavePlan:
>>> plan = super().create_local_plan()
>>> return replace(plan, planner_data="per-rank-data")
>>>
>>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>> global_plan, metadata = super().create_global_plan(all_plans)
>>> merged_data = [p.planner_data for p in global_plan]
>>> metadata = replace(metadata, planner_data=merged_data)
>>> return global_plan, metadata
"""
@abc.abstractmethod
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
storage_meta: Optional[StorageMeta] = None,
is_coordinator: bool = False,
) -> None:
"""
Initialize this planner to save ``state_dict``.
Implementations should save those values as they won't be provided later in the save process.
This is called on all ranks.
"""
@abc.abstractmethod
def create_local_plan(self) -> SavePlan:
"""
Compute the save plan for the current rank.
This will be aggregated and passed to create_global_plan.
Planner specific data can be passed through SavePlan::planner_data.
This is called on all ranks.
"""
@abc.abstractmethod
def create_global_plan(
self, all_plans: List[SavePlan]
) -> Tuple[List[SavePlan], Metadata]:
"""
Compute the global checkpoint plan and return the local plan of each rank.
This is called on the coordinator rank only.
"""
@abc.abstractmethod
def finish_plan(self, new_plan: SavePlan) -> SavePlan:
"""
Merge the plan created by `create_local_plan` and the result of `create_global_plan`.
This is called on all ranks.
"""
@abc.abstractmethod
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
"""
Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
Lookup the object associated with ``write_item`` in ``state_dict`` and apply any
transformation (such as serialization) prior to the storage layer consuming it.
Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
This method should be idempotent and thread-safe. StorageWriter implementations
are free to call it as frequently as they need.
Any transformation that allocates memory should be done lazily when this method
is called in order to reduce peak memory required by checkpointing.
Returned tensors can be on any device and in any format; they can also be views.
It is the storage layer's responsibility to figure out how to save them.
"""
class LoadPlanner:
"""
Abstract class defining the protocol used by load_state_dict to plan the load process.
LoadPlanners are stateful objects that can be used to customize the whole load process.
LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it
will be visible to the whole process.
A planner subclass can expect the following sequence of calls during load_state_dict:
1) set_up_planner - called on all ranks.
Signals the start of loading a checkpoint.
2) create_local_plan - called on all ranks.
Process the state_dict and produces a `LoadPlan` that will be sent for global planning.
3) create_global_plan - called on the coordinator rank only.
Takes the LoadPlan from all ranks and makes any global decisions.
4) load_bytes - called multiple times on each rank
This is called once per non-tensor value in state_dict.
5) resolve_tensor and commit_tensor - called multiple times on each rank
They are called in pair for each Tensor value in state_dict.
Users are recommended to extend DefaultLoadPlanner instead of this interface directly as
most changes can be expressed by changes in a single method.
There are two usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the load process as it
doesn't require understanding the intricacies of how LoadPlan works. Because
loading happens in place, we need to keep a reference to the original
state_dict so we can still write into it:
>>> # xdoctest: +SKIP("undefined vars")
>>> class RenamePlanner(DefaultLoadPlanner):
>>> def set_up_planner(
>>> self,
>>> state_dict: STATE_DICT_TYPE,
>>> metadata: Metadata,
>>> is_coordinator: bool,
>>> ) -> None:
>>> self.original_state_dict = state_dict
>>> state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>> if self.flatten_sharded_tensors:
>>> state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>> if self.flatten_state_dict:
>>> state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>> self.state_dict = state_dict
>>> self.metadata = metadata
>>> self.is_coordinator = is_coordinator
>>>
>>> def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix
>>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
Modifying resolve_tensor and commit_tensor to handle load time transformation.
>>> # xdoctest: +SKIP("undefined vars")
>>> class MetaModelMaterialize(DefaultLoadPlanner):
>>> def resolve_tensor(self, read_item):
>>> tensor = super().resolve_tensor(read_item)
>>> return torch.empty_like(tensor, device="cpu")
>>>
>>> def commit_tensor(self, read_item, tensor):
>>> self.state_dict[read_item.dest_index.fqn] = tensor
"""
@abc.abstractmethod
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
"""
Initialize this instance to load data into ``state_dict``.
. N.B. This is called on every rank.
"""
@abc.abstractmethod
def create_local_plan(self) -> LoadPlan:
"""
Create a LoadPlan based on state_dict and metadata provided by set_up_planner.
. N.B. This is called on every rank.
"""
@abc.abstractmethod
def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
"""
Compute the global load plan and return plans for each rank.
. N.B. This is called on the coordinator rank only
"""
@abc.abstractmethod
def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
"""Accept the plan from coordinator and return final LoadPlan."""
@abc.abstractmethod
def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
"""
Load the item described by ``read_item`` and ``value``.
This method is expected to modify in-place the underlying state_dict.
The contents of ``value`` are defined by the SavePlanner used to produce
the checkpoint being loaded.
"""
def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO:
"""
Return the BytesIO to be used by the StorageReader to load `read_item`.
The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents.
"""
raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented")
@abc.abstractmethod
def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
"""
Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`.
The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents.
If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
back to the one in state_dict.
"""
@abc.abstractmethod
def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
"""
Call once the StorageReader finished loading data into ``tensor``.
The provided tensor is the same one returned by the call to ``resolve_tensor``.
This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to
copying it back to the one in the state_dict.
The contents of tensor will follow its device synchronization model.
"""

View File

@ -0,0 +1,386 @@
# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast, Dict, List
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from .metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
MetadataIndex,
STATE_DICT_TYPE,
STORAGE_TYPES,
TensorProperties,
TensorStorageMetadata,
)
from .planner import (
LoadItemType,
ReadItem,
SavePlan,
TensorWriteData,
WriteItem,
WriteItemType,
)
from .resharding import (
_check_shard_metadata_pair_overlap,
_shards_get_overlap_region_wrt_saved_tensor,
)
__all__: List[str] = ["create_read_items_for_chunk_list"]
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
return ChunkStorageMetadata(
offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
)
def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
return ChunkStorageMetadata(
offsets=torch.Size(shard_md.shard_offsets),
sizes=torch.Size(shard_md.shard_sizes),
)
def _sharded_tensor_metadata(
sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
shard_properties = sharded_tensor.metadata().tensor_properties
properties = TensorProperties(
dtype=shard_properties.dtype,
layout=shard_properties.layout,
requires_grad=shard_properties.requires_grad,
memory_format=shard_properties.memory_format,
pin_memory=shard_properties.pin_memory,
)
return TensorWriteData(
chunk=_chunk_for_shard(shard_md),
properties=properties,
size=sharded_tensor.metadata().size,
)
def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
sizes, offsets = compute_local_shape_and_global_offset(
tensor.shape, tensor.device_mesh, tensor.placements
)
sizes, offsets = torch.Size(sizes), torch.Size(offsets)
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.SHARD,
tensor_data=TensorWriteData(
chunk=ChunkStorageMetadata(
offsets=offsets,
sizes=sizes,
),
properties=TensorProperties.create_from_tensor(tensor.to_local()),
size=tensor.size(),
),
)
def _create_write_item_for_shard(
fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
offsets = torch.Size(shard_md.shard_offsets)
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.SHARD,
tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
)
def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
offsets = torch.Size([0] * len(tensor.size()))
return WriteItem(
index=MetadataIndex(fqn, offsets),
type=WriteItemType.TENSOR,
tensor_data=TensorWriteData(
chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
properties=TensorProperties.create_from_tensor(tensor),
size=tensor.size(),
),
)
def _create_write_item_for_bytesio(fqn: str, bytes: Any):
return WriteItem(
index=MetadataIndex(fqn),
type=WriteItemType.BYTE_IO,
)
def _create_read_item_for_byteio(
dest_index, dest_offset, storage_index, storage_offset, length
):
return ReadItem(
type=LoadItemType.BYTE_IO,
dest_index=dest_index,
dest_offsets=torch.Size((dest_offset,)),
storage_index=storage_index,
storage_offsets=torch.Size((storage_offset,)),
lengths=torch.Size((length,)),
)
def _create_read_item_for_tensor(
dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
return ReadItem(
type=LoadItemType.TENSOR,
dest_index=dest_index,
dest_offsets=torch.Size(dest_offsets),
storage_index=storage_index,
storage_offsets=torch.Size(storage_offsets),
lengths=torch.Size(lengths),
)
def create_read_items_for_chunk_list(
fqn: str,
checkpoint_md: TensorStorageMetadata,
local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
"""
Create a list of ``ReadItem`` based on the checkpoint and local chunks.
This applies the resharding algorithm and computes the reads needed
to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.
Args:
fqn (str) : The state_dict FQN to pass to ``ReadItem``.
checkpoint_md (TensorStorageMetadata): metadata for a given tensor
from a checkpoint.
local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
loaded.
Returns:
A list of ``ReadItem`` that will satisfy all input chunks.
"""
read_items = []
# this is a naive quadratic algo that can be optimized later
for idx, shard in enumerate(local_chunks):
for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
if not _check_shard_metadata_pair_overlap(shard, storage_md):
continue
storage_offsets = []
dest_offsets = []
lengths = []
for (
dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=storage_md, current_shard=shard
):
storage_offsets.append(offset_for_saved_tensor)
dest_offsets.append(offset_for_current_tensor)
lengths.append(length)
read_items.append(
_create_read_item_for_tensor(
dest_index=MetadataIndex(fqn, shard.offsets, idx),
dest_offsets=dest_offsets,
storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
storage_offsets=storage_offsets,
lengths=lengths,
)
)
return read_items
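# Illustrative sketch (not part of this module): resharding a 1-D tensor of 8
# elements that was saved as two chunks of 4 into a single local chunk of 8
# produces one read per overlapping saved chunk. Names are for illustration only.
def _example_resharding_reads() -> None:
    checkpoint_md = TensorStorageMetadata(
        properties=TensorProperties(dtype=torch.float32),
        size=torch.Size([8]),
        chunks=[
            ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
            ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
        ],
    )
    local_chunks = [
        ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8]))
    ]
    reqs = create_read_items_for_chunk_list("w", checkpoint_md, local_chunks)
    assert len(reqs) == 2
    assert reqs[0].lengths == torch.Size([4])
    assert reqs[1].dest_offsets == torch.Size([4])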
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
requests = []
for fqn, obj in state_dict.items():
if isinstance(obj, DTensor):
requests.append(_create_write_items_for_dtensor(fqn, obj))
elif isinstance(obj, ShardedTensor):
for shard_md in obj.metadata().shards_metadata:
requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
elif isinstance(obj, torch.Tensor):
requests.append(_create_write_item_for_tensor(fqn, obj))
else:
requests.append(_create_write_item_for_bytesio(fqn, obj))
return SavePlan(requests)
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
if hasattr(object, "__create_write_items__"):
# DTensor implements _Checkpointable
return object.__create_write_items__(fqn, object)
elif isinstance(object, ShardedTensor):
return [
_create_write_item_for_shard(fqn, object, shard.metadata)
for shard in object.local_shards()
]
elif isinstance(object, torch.Tensor):
return [_create_write_item_for_tensor(fqn, object)]
else:
return [_create_write_item_for_bytesio(fqn, object)]
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
sizes, offsets = compute_local_shape_and_global_offset(
tensor.shape, tensor.device_mesh, tensor.placements
)
sizes, offsets = torch.Size(sizes), torch.Size(offsets)
return ChunkStorageMetadata(
offsets=offsets,
sizes=sizes,
)
def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
if hasattr(tensor, "__create_chunk_list__"):
# DTensor implements _Checkpointable
local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
elif isinstance(tensor, ShardedTensor):
local_chunks = [
_chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
]
elif isinstance(tensor, torch.Tensor):
local_chunks = [_create_chunk_from_tensor(tensor)]
else:
raise ValueError(
"Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
f",but got {type(tensor)}"
)
return local_chunks
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
if not isinstance(md, BytesStorageMetadata):
try:
local_chunks = _create_chunk_list(obj)
except ValueError as ex:
raise ValueError(
f"Invalid checkpoint metadata for {fqn}, "
+ f"expected BytesStorageMetadata but found {type(md)}",
) from ex
return create_read_items_for_chunk_list(fqn, md, local_chunks)
else:
return [
_create_read_item_for_byteio(
dest_index=MetadataIndex(fqn),
dest_offset=0,
storage_index=MetadataIndex(fqn),
storage_offset=0,
length=0,
)
]
def _init_state_dict(state_dict: Dict[str, Any]) -> Any:
"""
Initializes meta tensors in the state_dict if they are DTensor or torch.Tensor.
"""
def dtensor_func(value: DTensor):
device = getattr(value, "device", None)
if device == torch.device("meta"):
device_type = dist.distributed_c10d._get_pg_default_device().type
device = cast(
torch.device, _get_device_module(device_type).current_device()
)
new_local_tensor = torch.empty_like(value.to_local(), device=device)
# We need to pass shape and stride explicitly, since DTensor might be
# sharded unevenly.
dtensor = DTensor.from_local(
new_local_tensor,
device_mesh=value.device_mesh,
placements=value.placements,
shape=value.size(),
stride=value.stride(),
)
return dtensor
else:
return value
def sharded_tensor_func(value: Any):
device = getattr(value, "device", None)
if device == torch.device("meta"):
raise RuntimeError(
f"Found unsupported type {type(value)} for meta device loading."
)
else:
return value
def tensor_func(value: torch.Tensor):
device = getattr(value, "device", None)
if device == torch.device("meta"):
device_type = dist.distributed_c10d._get_pg_default_device().type
device = cast(
torch.device, _get_device_module(device_type).current_device()
)
tensor = torch.empty_like(value, device=device)
return tensor
else:
return value
_iterate_state_dict(
state_dict,
dtensor_func,
sharded_tensor_func,
tensor_func,
)
def _iterate_state_dict(
iter_object: Any,
dtensor_func: Callable,
sharded_tensor_func: Callable,
tensor_func: Callable,
):
"""
Iterate through the state dict, applying the given functions to each tensor type
and updating the state dict in place.
Args:
iter_object (Any): the target state_dict.
sharded_tensor_func (Callable): the function to apply to ShardedTensor
dtensor_func (Callable): the function to apply to DTensor
tensor_func (Callable): the function to apply to Tensor
# TODO: let state_dict_util._iterate_state_dict() support an in-place option
so we don't need to have two versions of _iterate_state_dict.
"""
if isinstance(iter_object, DTensor):
return dtensor_func(iter_object)
elif isinstance(iter_object, ShardedTensor):
return sharded_tensor_func(iter_object)
elif isinstance(iter_object, torch.Tensor):
return tensor_func(iter_object)
elif (
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
or iter_object is None
):
return iter_object
elif isinstance(iter_object, dict):
for key, value in iter_object.items():
iter_object[key] = _iterate_state_dict(
value, dtensor_func, sharded_tensor_func, tensor_func
)
return iter_object
elif isinstance(iter_object, (list, tuple)):
ret = [
_iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
for v in iter_object
]
if isinstance(iter_object, tuple):
ret = tuple(ret) # type: ignore[assignment]
return ret

View File

@ -0,0 +1,72 @@
# mypy: allow-untyped-defs
from typing import List, Tuple
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
__all__: List[str] = []
def _check_shard_metadata_pair_overlap(
shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
):
"""Check if two shards overlap."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
# shard, we would check if one shard is above or on the left of the
# other shard.
ndims = len(shard1.offsets)
for i in range(ndims):
if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]:
return False
if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]:
return False
return True
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
"""
Return the overlapping region between saved_shard and current_shard.
The returned list has the same number of elements as the tensor's dimensionality.
For each element, we produce a tuple with the following contents:
(dimension, `saved_shard` offset, `current_shard` offset, length)
Offsets are relative to each shard.
"""
narrows = []
for dim, (
saved_shard_offset,
current_shard_offset,
saved_shard_size,
current_shard_size,
) in enumerate(
zip(
saved_shard.offsets,
current_shard.offsets,
saved_shard.sizes,
current_shard.sizes,
)
):
min_range_end = min(
saved_shard_offset + saved_shard_size,
current_shard_offset + current_shard_size,
)
length = min_range_end - max(current_shard_offset, saved_shard_offset)
if saved_shard_offset > current_shard_offset:
offset_for_saved_tensor = 0
offset_for_current_tensor = saved_shard_offset - current_shard_offset
else:
offset_for_saved_tensor = current_shard_offset - saved_shard_offset
offset_for_current_tensor = 0
narrows.append(
(dim, offset_for_saved_tensor, offset_for_current_tensor, length)
)
return narrows
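# Illustrative sketch (not part of this module): a saved shard covering
# elements [4, 8) and a current shard covering [0, 6) overlap on [4, 6), so the
# read copies length 2 starting at offset 0 of the saved shard and offset 4 of
# the current shard.
def _example_overlap_region() -> None:
    import torch

    saved = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4]))
    current = ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([6]))
    assert _check_shard_metadata_pair_overlap(saved, current)
    overlap = _shards_get_overlap_region_wrt_saved_tensor(
        saved_shard=saved, current_shard=current
    )
    assert overlap == [(0, 0, 4, 2)]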

View File

@ -0,0 +1,117 @@
from typing import Optional, runtime_checkable
from typing_extensions import Protocol
from torch.distributed._state_dict_utils import (
_copy_state_dict,
_create_cpu_state_dict,
_offload_state_dict_to_cpu,
)
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
__all__ = ["AsyncStager", "BlockingAsyncStager"]
@runtime_checkable
class AsyncStager(Protocol):
"""
    This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users
    to customize how data is staged prior to executing the usual dcp.save path in parallel.
    The expected order of operations (concretely defined in `torch.distributed.state_dict_saver.async_save`)
    is the following:
    1. AsyncStager.stage(state_dict):
        This call gives the AsyncStager the opportunity to 'stage'
        the state_dict. The expectation and purpose of staging in this context is to create a "training-safe"
        representation of the state dict, meaning that any updates to module data after staging is complete
        should not be reflected in the state dict returned from this method. For example, in the default
        case a copy of the entire state dict is created on CPU RAM and returned here, allowing users
        to continue training without risking changes to data which is being serialized.
    2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible
        for serializing the state_dict and writing it to storage.
    3. If AsyncStager.should_synchronize_after_execute is True, `synchronize_staging` will be called immediately
        after the serialization thread starts and before returning from dcp.async_save. If this is set to False,
        the assumption is the user has defined a custom synchronization point for the purpose of further
        optimizing save latency in the training loop (for example, by overlapping staging with the
        forward/backward pass), and it is the responsibility of the user to call `AsyncStager.synchronize_staging`
        at the appropriate time.
"""
# default to True since the common case is to stage synchronously
_synchronize_after_execute: bool = True
@property
def should_synchronize_after_execute(self) -> bool:
"""
Whether to synchronize after executing the stage.
"""
return self._synchronize_after_execute
def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""
Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is
        inoculated from any updates incurred after the stage call is complete.
"""
raise NotImplementedError(
f"{self.__class__.__name__} must implement stage method"
)
def synchronize_staging(self) -> None:
"""
In the case `stage` is async in some way, this method should be called to ensure staging
is complete and it is safe to begin modifying the original `state_dict`
"""
class BlockingAsyncStager(AsyncStager):
"""
An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete.
This implementation also provides an option to optimize stage latency using pinned memory.
N.B. synchronize_staging is a no-op in this case.
"""
    # staging is blocking in this implementation, so no post-execute synchronization is required
_synchronize_after_execute: bool = False
def __init__(
self,
cache_staged_state_dict: bool = False,
type_check: bool = False,
):
"""
Initializes the BlockingAsyncStager.
Args:
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increased memory usage. Additionally, if this parameter is set to True, it's the
                expectation that the stager is maintained and re-used for multiple dcp.async_save calls. Defaults to False.
type_check: Whether to perform a type check during cpu_offload. Defaults to False.
"""
self.cache_staged_state_dict = cache_staged_state_dict
self.type_check = type_check
self.state_dict_cache: Optional[STATE_DICT_TYPE] = None
def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""
Returns a copy of `state_dict` on the CPU.
"""
if not self.cache_staged_state_dict:
return _offload_state_dict_to_cpu(state_dict, type_check=self.type_check)
if self.state_dict_cache is None:
self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True)
return _copy_state_dict(state_dict, self.state_dict_cache)
def synchronize_staging(self) -> None:
"""
No-op function, since staging is blocking.
"""

File diff suppressed because it is too large

View File

@ -0,0 +1,316 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import os
import warnings
from typing import Any, cast, Dict, Optional, Set, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.stateful import Stateful
from ._storage_utils import _storage_setup
from .default_planner import DefaultLoadPlanner
from .planner import LoadPlan, LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile
__all__ = ["load_state_dict", "load"]
@deprecated(
"`load_state_dict` is deprecated and will be removed in future versions. "
"Please use `load` instead.",
category=FutureWarning,
)
def load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[LoadPlanner] = None,
) -> None:
"""This method is deprecated. Please switch to 'load'."""
storage_reader.reset()
with _profile():
# TODO: test returning `load` here instead.
return _load_state_dict(
state_dict,
storage_reader,
process_group,
coordinator_rank,
no_dist,
planner,
)
@_dcp_method_logger(log_exceptions=True)
@_api_bc_check
def load(
state_dict: Dict[str, Any],
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
planner: Optional[LoadPlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> None:
"""
Load a distributed ``state_dict`` in SPMD style.
Each rank will try to read the least amount of data necessary
    to fulfill the requested `state_dict`. When loading :class:`ShardedTensor`
or :class:`DTensor` instances, each rank only reads data for their local shards.
For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
load will first call ``state_dict`` before attempting deserialization, followed by
``load_state_dict`` once the deserialization is complete.
.. warning::
All tensors in ``state_dict`` must be allocated on their
destination device *prior to* calling this function.
All non-tensor data is loaded using `torch.load()` and modified in place
on state_dict.
.. warning::
Users must call `load_state_dict` on the root module to ensure load
        post-processing and non-tensor data properly propagate.
.. note:
If no process group is initialized, this function will assume the intent
is to load a checkpoint into the local process. This can be useful in the
case of local inference, and when using regular Tensors (as opposed to DTensor
or ShardedTensor)
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
specified, DCP will automatically infer the reader based on the
checkpoint_id. If checkpoint_id is also None, an exception will
be raised. (Default: ``None``)
planner (Optional[LoadPlanner]):
            Instance of LoadPlanner. If this is not specified, the default
planner will be used. (Default: ``None``)
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
Returns:
None.
    Example:
>>> # xdoctest: +SKIP
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
        >>> torch.distributed.checkpoint.load(
>>> state_dict=model_state_dict,
>>> storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)
.. note::
load_state_dict uses collectives to coordinate reads across ranks.
For NCCL-based process groups, internal tensor representations of
objects must be moved to the GPU device before communication takes place.
In this case, the device used is given by ``torch.cuda.current_device()``
and it is the user's responsibility to ensure that this is set so that each
rank has an individual GPU, via ``torch.cuda.set_device()``.
"""
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
)
with _profile():
storage_reader = cast(
StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
)
if no_dist:
keys = list(state_dict.keys())
else:
keys = _all_gather_keys(state_dict, process_group)
if keys != sorted(state_dict.keys()):
warnings.warn(
"Detected mismatched keys in state dict after all gather!"
" This behavior is unsupported and may cause errors may cause errors."
)
        stateful_sd = {}
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            stateful_sd[key] = (
                elem.state_dict() if isinstance(elem, Stateful) else elem
            )
        _load_state_dict(
            state_dict=stateful_sd,
            storage_reader=storage_reader,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            if isinstance(elem, Stateful):
                elem.load_state_dict(stateful_sd[key])
            state_dict[key] = stateful_sd[key]
def _load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[LoadPlanner] = None,
) -> None:
torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultLoadPlanner()
ckpt_kwargs = {}
if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def local_step():
assert planner is not None
metadata = storage_reader.read_metadata()
planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_reader.prepare_local_plan(local_plan)
return local_plan
@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
assert planner is not None
all_local_plans = planner.create_global_plan(all_local_plans)
all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
return all_local_plans
central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
@_dcp_method_logger(**ckpt_kwargs)
def read_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_reads = storage_reader.read_data(final_local_plan, planner)
all_reads.wait()
return None
_ = distW.all_gather("read", read_data)
def _load_state_dict_from_keys(
keys: Optional[Union[Set[str], str]] = None,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
    Load only the specified keys from the checkpoint; if no keys are specified, the entire
    checkpoint will be loaded. Note that this method completely loads the checkpoint into the
    current process and is not distributed.
    .. warning::
        All non-tensor data is loaded using `torch.load()`
.. note:
As opposed to the usual pattern, this function does not take a state dict as input
        and does not load in place. Instead, a new state dict is directly initialized and read
from file.
.. note:
If no process group is initialized, this function will assume the intent
is to load a checkpoint into the local process. This can be useful in the
case of local inference, and when using regular Tensors (as opposed to DTensor
or ShardedTensor)
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
keys (Optional[Union[Set[str], str]]):
Loads any key specified in this set. If no keys are specified, the entire checkpoint
is loaded.
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
specified, DCP will automatically infer the reader based on the
checkpoint_id. If checkpoint_id is also None, an exception will
be raised. (Default: ``None``)
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
Returns:
State dict from specified keys
"""
torch._C._log_api_usage_once(
"torch.distributed.checkpoint._load_state_dict_from_keys"
)
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
)
storage_reader = cast(
StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
)
if isinstance(keys, str):
keys = {keys}
sd: Dict[str, Any] = {}
_load_state_dict(
state_dict=sd,
storage_reader=storage_reader,
process_group=process_group,
no_dist=no_dist,
planner=_EmptyStateDictLoadPlanner(keys=keys or set()),
)
return sd
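# Hedged usage sketch (added for illustration, not part of the original file):
# pull only the optimizer state out of a checkpoint without building a model
# first. The checkpoint path and key name are illustrative assumptions.
#
#     optim_sd = _load_state_dict_from_keys(
#         keys={"optimizer"},
#         checkpoint_id="/checkpoint/1",
#     )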

View File

@ -0,0 +1,333 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import os
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from typing import cast, Optional, Union
from typing_extensions import deprecated
import torch
import torch.distributed as dist
from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu
from torch.distributed.checkpoint._storage_utils import _storage_setup
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.staging import AsyncStager
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.distributed_c10d import _get_default_group
from .utils import _api_bc_check, _DistWrapper, _profile
__all__ = ["save_state_dict", "save", "async_save"]
@deprecated(
"`save_state_dict` is deprecated and will be removed in future versions."
"Please use `save` instead.",
category=FutureWarning,
)
def save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
"""This method is deprecated. Please switch to 'save'."""
storage_writer.reset()
# TODO: test returning `save` here instead.
with _profile():
return _save_state_dict(
state_dict,
storage_writer,
process_group,
coordinator_rank,
no_dist,
planner,
)
@_dcp_method_logger(log_exceptions=True) # type: ignore[arg-type]
@_api_bc_check
def save(
state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Metadata:
"""
Save a distributed model in SPMD style.
This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` and ``DTensor`` by having each rank only save their local shards.
For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
save will call ``state_dict`` before serialization.
.. warning::
        There are no guarantees of backwards compatibility across PyTorch versions
for saved state_dicts.
.. warning::
If using the `process_group` argument, make sure that only its ranks
        call `save` and that all data in state_dict belongs to it.
.. note::
        When saving a checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of
        the shard groups should call `save` and the corresponding process
        group needs to be passed in.
.. note::
If no process group is available, this function assumes the intention is to save the
state_dict in the local process.
.. note:
Rank 0 is assumed to be the coordinator rank.
Args:
state_dict (Dict[str, Any]): The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
storage_writer (Optional[StorageWriter]):
Instance of StorageWriter used to perform writes. If this is not
specified, DCP will automatically infer the writer based on the
checkpoint_id. If checkpoint_id is also None, an exception will
be raised. (Default: ``None``)
planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
planner will be used. (Default: ``None``)
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
Returns:
Metadata: Metadata object for the saved checkpoint.
Example:
>>> # xdoctest: +SKIP
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer,
>>> )
.. note::
save_state_dict uses collectives to coordinate writes across ranks.
For NCCL-based process groups, internal tensor representations of
objects must be moved to the GPU device before communication takes place.
In this case, the device used is given by ``torch.cuda.current_device()``
and it is the user's responsibility to ensure that this is set so that
each rank has an individual GPU, via ``torch.cuda.set_device()``.
"""
torch._C._log_api_usage_once("torch.distributed.checkpoint.save")
no_dist = not (dist.is_available() and dist.is_initialized())
if no_dist:
warnings.warn(
"torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
)
with _profile():
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
)
return _save_state_dict(
state_dict=_stateful_to_state_dict(state_dict),
storage_writer=storage_writer,
process_group=process_group,
no_dist=no_dist,
planner=planner,
)
@_dcp_method_logger(log_exceptions=True)
def async_save(
state_dict: STATE_DICT_TYPE,
*,
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_writer: Optional[StorageWriter] = None,
planner: Optional[SavePlanner] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
"""Asynchronous version of ``save``. This code first de-stages the state_dict on to the
staging storage (defaults to CPU memory), and then calls the `save` in a separate thread.
.. warning::
This feature is experimental and subject to change.
Args:
state_dict (Dict[str, Any]): The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
storage_writer (Optional[StorageWriter]):
Instance of StorageWriter used to perform 'stage' and 'save'. If
this is not specified, DCP will automatically infer the writer based on the
checkpoint_id. If checkpoint_id is also None, an exception will
be raised. (Default: ``None``)
planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
planner will be used. (Default: ``None``)
process_group (Optional[ProcessGroup]):
ProcessGroup to be used for cross-rank synchronization.
(Default: ``None``)
Returns:
Future: A future holding the resultant Metadata object from `save`.
Example:
>>> # xdoctest: +SKIP
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>> state_dict=state_dict,
>>> storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
"""
torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")
if dist.is_available() and dist.is_initialized():
pg = process_group or _get_default_group()
assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
storage_writer = cast(
StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
)
state_dict = _stateful_to_state_dict(state_dict)
if isinstance(storage_writer, AsyncStager):
staged_state_dict = storage_writer.stage(state_dict)
else: # provides bwc for storage_writers not implementing AsyncStager
staged_state_dict = _offload_state_dict_to_cpu(state_dict, type_check=False)
executor = ThreadPoolExecutor(max_workers=1)
f: Future = executor.submit(
save,
staged_state_dict,
checkpoint_id=checkpoint_id,
storage_writer=storage_writer,
planner=planner,
process_group=process_group,
)
f.add_done_callback(lambda f: executor.shutdown(wait=False))
if (
isinstance(storage_writer, AsyncStager)
and storage_writer.should_synchronize_after_execute
):
storage_writer.synchronize_staging()
return f
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
"""Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
stateful_state_dict = {}
for key, elem in state_dict.items():
stateful_state_dict[key] = (
elem.state_dict() if isinstance(elem, Stateful) else elem
)
return stateful_state_dict
def _save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
global_metadata = None
ckpt_kwargs = {}
if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def local_step():
assert planner is not None
storage_meta = storage_writer.storage_meta()
if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
warnings.warn(
"The function definition for SavePlanner.set_up_planner has been updated"
" to include the storage_meta argument. Please update your implementation"
" to include this parameter."
)
planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type]
else:
planner.set_up_planner(
state_dict=state_dict,
storage_meta=storage_meta,
is_coordinator=distW.is_coordinator,
)
storage_writer.set_up_storage_writer(distW.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
@_dcp_method_logger(**ckpt_kwargs)
def write_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)
all_writes.wait()
return all_writes.value()
@_dcp_method_logger(**ckpt_kwargs)
def finish_checkpoint(all_results):
assert global_metadata is not None
storage_writer.finish(metadata=global_metadata, results=all_results)
return global_metadata
return distW.all_reduce("write", write_data, finish_checkpoint)

View File

@ -0,0 +1,42 @@
from typing import Any, Dict, runtime_checkable, TypeVar
from typing_extensions import Protocol
__all__ = ["Stateful", "StatefulT"]
@runtime_checkable
class Stateful(Protocol):
"""
Stateful protocol for objects that can be checkpointed and restored.
"""
def state_dict(self) -> Dict[str, Any]:
"""
Objects should return their state_dict representation as a dictionary.
The output of this function will be checkpointed, and later restored in
`load_state_dict()`.
.. warning::
Because of the inplace nature of restoring a checkpoint, this function
is also called during `torch.distributed.checkpoint.load`.
Returns:
            Dict: The object's state dict
"""
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the object's state from the provided state_dict.
Args:
state_dict: The state dict to restore from
"""
...
StatefulT = TypeVar("StatefulT", bound=Stateful)

View File

@ -0,0 +1,284 @@
import abc
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.planner import (
LoadPlan,
LoadPlanner,
SavePlan,
SavePlanner,
)
from torch.futures import Future
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
@dataclass(frozen=True)
class WriteResult:
index: MetadataIndex
size_in_bytes: int
storage_data: Any
class StorageWriter(abc.ABC):
"""
Interface used by ``save_state_dict`` to write to storage.
One StorageWriter instance acts as both the coordinator and the follower
in a distributed checkpoint. As part of initialization, each instance
is told its role.
A subclass should expect the following sequence of calls.
0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
1) (all ranks) set_up_storage_writer()
2) (all ranks) prepare_local_plan()
3) (coordinator) prepare_global_plan()
4) (all ranks) write_data()
5) (coordinator) finish()
"""
@abc.abstractmethod
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
"""
        Called to indicate that a brand new checkpoint write is going to happen.
        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint write. The meaning of the checkpoint_id is
storage-dependent. It can be a path to a folder/file or a key for
a key-value storage.
Args:
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is a key-value store.
(Default: ``None``)
"""
...
@abc.abstractmethod
def set_up_storage_writer(self, is_coordinator: bool) -> None:
"""
Initialize this instance.
Args:
is_coordinator (bool): Whether this instance is responsible for coordinating
the checkpoint.
"""
@abc.abstractmethod
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
"""
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended
way is to store storage specific data in SavePlan::storage_data.
Args:
plan (SavePlan): The local plan from the ``SavePlanner`` in use.
Returns:
A transformed ``SavePlan`` after storage local planning
"""
@abc.abstractmethod
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
"""
Perform centralized planning of storage.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred
way is to store storage specific data in SavePlan::storage_data.
Args:
plans: A list of ``SavePlan`` instances, one for each rank.
Returns:
A list of transformed ``SavePlan`` after storage global planning
"""
@abc.abstractmethod
def write_data(
self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
"""
Write all items from ``plan`` using ``planner`` to resolve the data.
A subclass should call ``SavePlanner::resolve_data`` on each item
from the plan to get access to the underlying object to write.
Subclasses should lazily call `resolve_data` as it can allocate memory.
        In the case of tensors, make the following assumptions:
- They might be on any device, including not matching the one on ``WriteItem::tensor_data``
- They might be views or not contiguous. Only the projection needs to be saved.
Args:
plan (SavePlan): The save plan to execute.
planner (SavePlanner): Planner object to be used to resolve items to data.
Returns:
A future that completes to a list of WriteResult
"""
@abc.abstractmethod
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
"""
        Writes the metadata and marks the current checkpoint as successful.
        The actual format/schema used for serializing `metadata` is an
        implementation detail. The only requirement is that it's recoverable
        into the same object graph.
Args:
metadata (Metadata): metadata for the new checkpoint
results: A list of WriteResults from all ranks.
Returns:
None
"""
@classmethod
@abc.abstractmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
"""
        Check if the given checkpoint_id is supported by the storage. This allows
        us to enable automatic storage selection.
"""
...
def storage_meta(self) -> Optional[StorageMeta]:
"""
Return the storage-specific metadata. This is used to store additional information
in a checkpoint that can be useful for providing request-level observability. StorageMeta
is passed to the ``SavePlanner`` during save calls. Returns None by default.
TODO: provide an example
"""
return None
class StorageReader(abc.ABC):
"""
Interface used by ``load_state_dict`` to read from storage.
One StorageReader instance acts as both the coordinator and the follower
in a distributed checkpoint. As part of initialization, each instance
is told its role.
    A subclass should expect the following sequence of calls by ``load_state_dict``:
0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
1) (all ranks) read_metadata()
2) (all ranks) set_up_storage_reader()
3) (all ranks) prepare_local_plan()
4) (coordinator) prepare_global_plan()
5) (all ranks) read_data()
"""
@abc.abstractmethod
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
"""
        Called to indicate that a brand new checkpoint read is going to happen.
        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint read. The meaning of the checkpoint_id is
storage-dependent. It can be a path to a folder/file or a key for
a key-value storage.
Args:
checkpoint_id (Union[str, os.PathLike, None]):
The ID of this checkpoint instance. The meaning of the checkpoint_id
depends on the storage. It can be a path to a folder or to a file.
It can also be a key if the storage is more like a key-value store.
(Default: ``None``)
"""
...
@abc.abstractmethod
def read_metadata(self) -> Metadata:
"""
Read the checkpoint metadata.
Returns:
The metadata object associated with the checkpoint being loaded.
"""
@abc.abstractmethod
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
"""
Initialize this instance.
Args:
metadata (Metadata): The metadata schema to use.
is_coordinator (bool): Whether this instance is responsible for coordinating
the checkpoint.
"""
@abc.abstractmethod
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
"""
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended
way is to store storage specific data in LoadPlan::storage_data.
Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.
Returns:
A transformed ``LoadPlan`` after storage local planning
"""
@abc.abstractmethod
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
"""
Perform centralized planning of storage loading.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred
way is to store storage specific data in LoadPlan::storage_data.
Args:
plans: A list of ``LoadPlan`` instances, one for each rank.
Returns:
A list of transformed ``LoadPlan`` after storage global planning
"""
@abc.abstractmethod
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
"""
Read all items from ``plan`` using ``planner`` to resolve the data.
A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
object into the right place.
A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
        tensors that it should load data into.
        It is the StorageLayer's responsibility to properly schedule any cross-device copies
        required.
Args:
plan (LoadPlan): The local plan to execute on
planner (LoadPlanner): The planner object to use to resolve items.
Returns:
A future that completes once all reads are finished.
"""
@classmethod
@abc.abstractmethod
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
"""
        Check if the given checkpoint_id is supported by the storage. This allows
        us to enable automatic storage selection.
"""
...
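# Hedged sketch (added for illustration, not part of the original file): the
# validate_checkpoint_id classmethods above are what let DCP infer a reader or
# writer from `checkpoint_id` alone. A path-based implementation might accept
# anything path-like (illustrative only, not the actual filesystem logic):
#
#     @classmethod
#     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
#         return isinstance(checkpoint_id, (str, os.PathLike))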

View File

@ -0,0 +1,431 @@
# mypy: allow-untyped-defs
import cProfile
import inspect
import io
import itertools
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from .api import (
_is_wrapped_exception,
_wrap_exception,
CheckpointException,
WRAPPED_EXCEPTION,
)
from .metadata import MetadataIndex, STATE_DICT_TYPE
__all__ = ["find_tensor_shard", "find_state_dict_object"]
T = TypeVar("T")
R = TypeVar("R")
def _get_failure_dict(
results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
return cast(
Dict[int, WRAPPED_EXCEPTION],
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
)
def _all_gather_keys(
local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]:
"""Gathers all keys, and returns them sorted."""
keys = list(local_dict.keys())
gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
dist.all_gather_object(gathered_keys, keys, group=group)
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
class _DistWrapper:
"""
This is a wrapper around PG that provides a series of features around object collectives.
    It works even when torch.distributed is not initialized, in which case most collectives turn into no-ops.
All variants that take functions are exception robust, meaning that if one or more
ranks raise errors, all ranks will observe those.
"""
def __init__(
self,
group: Optional[dist.ProcessGroup],
use_dist: bool,
coordinator_rank: int,
):
self.group = group
self.use_dist = use_dist
self.coordinator_rank = coordinator_rank
if self.use_dist:
self.rank = dist.get_rank(group)
self.is_coordinator = self.rank == coordinator_rank
else:
self.rank = 0
self.is_coordinator = True
def get_rank(self) -> int:
return self.rank
def get_world_size(self) -> int:
if self.use_dist:
return dist.get_world_size(self.group)
return 1
def broadcast_object(self, object: Optional[T]) -> T:
"""Implement functionality similar to c10d::broadcast_object_list but without distributed enabled."""
object_list = [object]
if self.use_dist:
dist.broadcast_object_list(
object_list=object_list,
group=self.group,
src=self.coordinator_rank,
)
return cast(T, object_list[0])
def gather_object(self, object: T) -> Optional[List[T]]:
"""Implement functionality similar to c10d::gather_object but without distributed enabled."""
if self.use_dist:
gather_objs = (
cast(List[T], [None] * dist.get_world_size(self.group))
if self.is_coordinator
else None
)
dist.gather_object(
obj=object,
object_gather_list=gather_objs if self.is_coordinator else None,
dst=self.coordinator_rank,
group=self.group,
)
result = gather_objs
else:
result = [object]
return result
def all_gather_object(self, object: T) -> List[T]:
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
if self.use_dist:
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
dist.all_gather_object(
object_list=gather_objs, obj=object, group=self.group
)
else:
gather_objs = [object]
return gather_objs
def scatter_object(self, object_list: Optional[List[T]]) -> T:
"""Implement functionality similar to c10d::scatter_object but without distributed enabled."""
if self.use_dist:
gather_result = cast(List[T], [None])
dist.scatter_object_list(
scatter_object_output_list=gather_result,
scatter_object_input_list=object_list if self.is_coordinator else None,
src=self.coordinator_rank,
group=self.group,
)
local_reply = gather_result[0]
else:
assert object_list is not None
local_reply = object_list[0]
return local_reply
def reduce_scatter(
self,
step: str,
map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], List[R]],
) -> R:
"""
Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
This method operates in the following way:
Run ``map_fun`` on all ranks
Gather results on rank 0
Call ``reduce_fun`` on all those values
Scatter to each rank part of the result.
"""
local_data: Union[WRAPPED_EXCEPTION, T]
try:
local_data = map_fun()
except BaseException as e:
local_data = _wrap_exception(e)
all_data = self.gather_object(local_data)
all_results: Optional[List[Union[R, CheckpointException]]] = None
if self.is_coordinator:
assert all_data is not None
node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0:
try:
# N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
all_results = cast(
List[Union[R, CheckpointException]],
reduce_fun(cast(List[T], all_data)),
)
except BaseException as e:
node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0:
all_results = [
CheckpointException(step, node_failures)
] * self.get_world_size()
result = self.scatter_object(all_results)
if isinstance(result, CheckpointException):
raise result
return result
def all_reduce(
self,
step: str,
map_fun: Callable[[], T],
reduce_fun: Callable[[List[T]], R],
) -> R:
"""
Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
This method operates in the following way:
Run ``map_fun`` on all ranks
Gather results on rank 0
Call ``reduce_fun`` on all those values
Broadcast the reduced value to all ranks.
"""
local_data: Union[T, WRAPPED_EXCEPTION]
try:
local_data = map_fun()
except BaseException as e:
local_data = _wrap_exception(e)
all_data = self.gather_object(local_data)
result: Optional[Union[R, CheckpointException]] = None
if self.is_coordinator:
assert all_data is not None
node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0:
try:
result = reduce_fun(cast(List[T], all_data))
except BaseException as e:
node_failures[self.rank] = _wrap_exception(e)
if len(node_failures) > 0:
result = CheckpointException(step, node_failures)
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
return cast(R, final_result)
def all_gather(
self,
step: str,
map_fun: Callable[[], T],
) -> List[T]:
"""
Compute a value on each rank, then all_gather them.
This method operates in the following way:
            Run ``map_fun`` on all ranks
all_gather the values to all ranks
"""
result: Union[T, WRAPPED_EXCEPTION]
try:
result = map_fun()
except BaseException as e:
result = _wrap_exception(e)
all_results = self.all_gather_object(result)
node_failures = _get_failure_dict(all_results)
if len(node_failures) > 0:
raise CheckpointException(step, node_failures)
return cast(List[T], all_results)
def broadcast(
self,
step: str,
map_fun: Callable[[], T],
) -> T:
"""
Compute a value on rank 0 and broadcast it.
This method operates in the following way:
            Run ``map_fun`` on rank 0
broadcast the value
"""
result: Optional[Union[T, CheckpointException]] = None
if self.is_coordinator:
try:
result = map_fun()
except BaseException as e:
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
return cast(T, final_result)
def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
if index.offset is None:
raise ValueError(
f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided"
)
shards = tensor.local_shards()
# index fast path
if index.index is not None:
if (
len(shards) > index.index
and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
):
return shards[index.index]
for shard in shards:
if torch.Size(shard.metadata.shard_offsets) == index.offset:
return shard
raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
if hasattr(tensor, "__get_tensor_shard__"):
# DTensor implements _Checkpointable
return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
if isinstance(tensor, ShardedTensor):
return _find_shard(tensor, index).tensor
if index.offset is not None:
# special case looking up a tensor by origin
if index.offset == torch.Size([0] * len(tensor.size())):
return tensor
raise ValueError(
f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
)
return tensor
def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
if index.fqn not in state_dict:
raise ValueError(f"Could not find FQN: '{index.fqn}'")
obj = state_dict[index.fqn]
if isinstance(obj, torch.Tensor):
return find_tensor_shard(obj, index)
elif index.offset is not None:
raise ValueError(
f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
)
return obj
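# Hedged example (added for illustration, not part of the original file):
# resolving a plain (non-sharded) tensor by fully qualified name only.
#
#     sd = {"weight": torch.zeros(4)}
#     find_state_dict_object(sd, MetadataIndex(fqn="weight")) is sd["weight"]
#     # -> True; passing an offset for a non-ShardedTensor raises ValueError
#     #    unless the offset is the all-zeros origin.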
def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
return [i_a + i_b for i_a, i_b in zip(a, b)]
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
return [i_a - i_b for i_a, i_b in zip(a, b)]
class _ReaderView(io.IOBase):
def __init__(self, base_stream: io.IOBase, offset: int, len: int):
super().__init__()
self.offset = offset
self.len = len
self.base_stream = base_stream
self.seek(0)
def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
if __whence == os.SEEK_SET:
__offset = self.offset + __offset
elif __whence == os.SEEK_END:
__whence = os.SEEK_SET
__offset = (self.offset + self.len) - __offset
return self.base_stream.seek(__offset, __whence)
def tell(self) -> int:
return self.base_stream.tell() - self.offset
def readable(self) -> bool:
return self.base_stream.readable()
def seekable(self) -> bool:
return self.base_stream.seekable()
def readinto(self, b):
return self.base_stream.readinto(b) # type: ignore[attr-defined]
def read(self, size=-1):
return self.base_stream.read(size)
def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
# FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
return _ReaderView(file, offset, length)
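# Hedged usage sketch (added for illustration, not part of the original file):
# expose a byte window of an open checkpoint file as an independent stream so a
# single serialized entry can be handed to torch.load. The path, offset and
# length below are illustrative assumptions.
#
#     with open("/checkpoint/1/__0_0.distcp", "rb") as f:
#         view = _create_file_view(f, offset=128, length=4096)
#         obj = torch.load(view, map_location="cpu")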
def _normalize_device_info(device_type: str, device_id: int) -> str:
"""Device info normalization."""
if device_type == "cpu":
return "cpu"
return f"{device_type}:{device_id}"
# TODO: integrate with distributed logging flag
ENABLE_PROFILE = False
@contextmanager
def _profile():
    # Only log the profiling when it is enabled and is on rank0 or dist is not
    # available.
if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
profiler = cProfile.Profile()
profiler.enable()
try:
yield
finally:
profiler.disable()
stats = Stats(profiler)
stats.sort_stats("time").print_stats(10)
else:
yield
def _api_bc_check(func):
@wraps(func)
def inner_func(*args, **kwargs) -> Any:
if len(args) == 2:
warnings.warn(
f"The argument order of {func.__name__} has been changed. "
"Please check the document to avoid future breakages."
)
sig = inspect.signature(func)
kwonlyargs = [
p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
]
if "storage_writer" in kwonlyargs:
assert "storage_writer" not in kwargs, (args, kwargs)
kwargs["storage_writer"] = args[1]
elif "storage_reader" in kwonlyargs:
assert "storage_reader" not in kwargs, (args, kwargs)
kwargs["storage_reader"] = args[1]
else:
raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
return func(args[0], **kwargs)
else:
return func(*args, **kwargs)
return inner_func
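# Hedged usage sketch (added for illustration, not part of the original file):
# _api_bc_check lets the old positional call style keep working after
# storage_writer/storage_reader became keyword-only. The `dcp` alias and
# `fs_storage_writer` name are illustrative assumptions.
#
#     dcp.save(state_dict, fs_storage_writer)                 # old style: warns, then remaps
#     dcp.save(state_dict, storage_writer=fs_storage_writer)  # current keyword-only style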