I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401

View File

@@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import socket
from string import Template
from typing import Any, List
def get_env_variable_or_raise(env_name: str) -> str:
r"""
    Tries to retrieve the environment variable. Raises ``ValueError``
    if the environment variable is not set.
Args:
env_name (str): Name of the env variable
"""
value = os.environ.get(env_name, None)
if value is None:
msg = f"Environment variable {env_name} expected, but not set"
raise ValueError(msg)
return value
def get_socket_with_port() -> socket.socket:
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
raise RuntimeError("Failed to create a socket")
class macros:
"""
Defines simple macros for caffe2.distributed.launch cmd args substitution
"""
local_rank = "${local_rank}"
@staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]:
args_sub = []
for arg in args:
if isinstance(arg, str):
sub = Template(arg).safe_substitute(local_rank=local_rank)
args_sub.append(sub)
else:
args_sub.append(arg)
return args_sub
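
A quick usage sketch for the helpers above (not part of the file): the environment variable name and the command-line args are made up for illustration, and the module is assumed to be importable as torch.distributed.elastic.utils.api, as the package imports elsewhere in this commit suggest.

import os

from torch.distributed.elastic.utils.api import get_env_variable_or_raise, macros

os.environ["TRAINER_ID"] = "3"  # hypothetical variable, set here only for the demo
trainer_id = get_env_variable_or_raise("TRAINER_ID")  # "3"; raises ValueError if unset

# macros.substitute fills the ${local_rank} placeholder in string args and
# passes non-string args through unchanged.
cmd = ["trainer.py", "--local-rank", macros.local_rank, 42]
print(macros.substitute(cmd, local_rank="0"))  # ['trainer.py', '--local-rank', '0', 42]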

View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .cycling_iterator import CyclingIterator # noqa: F401
from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
class CyclingIterator:
"""
An iterator decorator that cycles through the
underlying iterator "n" times. Useful to "unroll"
the dataset across multiple training epochs.
The generator function is called as ``generator_fn(epoch)``
to obtain the underlying iterator, where ``epoch`` is a
    number in the range ``[0, n)`` identifying the current cycle.
    For example, if ``generator_fn`` always returns ``[1,2,3]``,
    then ``CyclingIterator(n=2, generator_fn)`` will iterate through
    ``[1,2,3,1,2,3]``.
"""
def __init__(self, n: int, generator_fn, start_epoch=0):
self._n = n
self._epoch = start_epoch
self._generator_fn = generator_fn
self._iter = generator_fn(self._epoch)
def __iter__(self):
return self
def __next__(self):
try:
return next(self._iter)
except StopIteration as eod: # eod == end of data
if self._epoch < self._n - 1:
self._epoch += 1
self._iter = self._generator_fn(self._epoch)
return self.__next__()
else:
raise eod
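
A minimal sketch of the cycling behavior described in the class docstring (not part of the file); the generator function below stands in for whatever per-epoch iterator a real training loop would build, and the class is assumed to be in scope.

def make_epoch_iter(epoch):
    # a real generator_fn would typically rebuild a fresh iterator for each epoch
    return iter([1, 2, 3])

it = CyclingIterator(n=2, generator_fn=make_epoch_iter)
print(list(it))  # [1, 2, 3, 1, 2, 3]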

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.utils.data.distributed import DistributedSampler
class ElasticDistributedSampler(DistributedSampler):
"""
Sampler that restricts data loading to a subset of
the dataset for elastic training.
It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
start_index (optional): Which index of the dataset to start sampling from
"""
def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
if start_index >= len(dataset):
raise ValueError(
f"Start index {start_index} should be less than dataset size {len(dataset)}"
)
self.start_index = start_index
self.num_samples = int(
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = (
torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
.add(self.start_index)
.tolist()
)
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
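
A small single-process sketch of how the sampler plugs into a DataLoader (not part of the file); the dataset, replica count, and start index are made up, and num_replicas/rank are passed explicitly so no process group is needed.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
# explicit num_replicas/rank so the sketch runs without torch.distributed being initialized
sampler = ElasticDistributedSampler(dataset, num_replicas=2, rank=0, start_index=4)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle deterministically for each epoch
    for batch in loader:
        pass  # batches only contain indices >= start_index, split across the 2 replicas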

View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import os
import socket
from contextlib import closing
from typing import Optional
import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.elastic.utils.store import barrier
__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"]
logger = get_logger(__name__)
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"
_TCP_STORE_INIT = "_tcp_store/num_members"
def create_c10d_store(
is_server: bool,
server_addr: str,
server_port: int = -1,
world_size: int = 1,
timeout: float = (60 * 10), # 10 min
wait_for_workers: bool = True,
retries=3,
use_libuv: Optional[bool] = None,
):
if use_libuv is not None:
logger.warning(
"argument use_libuv is deprecated and ignored. Set USE_LIBUV environment "
'variable to "0" to disable libuv, or "1" to enable it. If the env var '
"is not set, libuv will be used by default."
)
# check os.environ for use_libuv
use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option
if server_port == -1 and world_size > 1:
raise ValueError(
f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
)
if server_port != -1:
logger.info("sever_port: %s, specified, ignoring retries", server_port)
# only retry when server_port is NOT static
attempt = retries if server_port == -1 else 1
while True:
if server_port != -1:
port = server_port
else:
port = get_free_port()
logger.info(
"Creating c10d store on %s:%s\n"
" world_size : %s\n"
" is_server : %s\n"
" timeout(sec): %s\n"
" use_libuv : %s\n",
server_addr,
port,
world_size,
is_server,
timeout,
use_libuv,
)
try:
store = dist.TCPStore(
host_name=server_addr,
port=port,
world_size=world_size,
is_master=is_server,
timeout=datetime.timedelta(seconds=timeout),
wait_for_workers=wait_for_workers,
use_libuv=use_libuv,
)
# skips full rank check when we don't have to wait for all workers
if wait_for_workers:
_check_full_rank(store, world_size, timeout=timeout)
logger.info("Successfully created c10d store")
return store
except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly exposed
            # through pybind, so we parse the error msg for now; interestingly this is how
            # torch itself detects timeouts and port conflicts in its own unittests
# see - caffe2/torch/testing/_internal/common_utils.py
# TODO properly map the exceptions in pybind (c10d/init.cpp)
if str(e) == _ADDRESS_IN_USE: # this will only happen on the server
if attempt < retries:
logger.warning(
"port: %s already in use, attempt: [%s/%s]",
port,
attempt,
retries,
)
attempt += 1
else:
raise RuntimeError(
f"on {server_addr}, port: {port} already in use"
) from e
else:
raise
def _check_full_rank(store, world_size, timeout):
try:
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
except RuntimeError as e:
if str(e) == _SOCKET_TIMEOUT:
raise TimeoutError(
f"timed out waiting for all {world_size} members to join"
) from e
else:
raise
def get_free_port():
"""
Returns an unused port on localhost.
    This function finds an unused port on localhost by opening a socket,
    binding it to an ephemeral port, and then closing it.
Returns:
int: an unused port on localhost
Example:
>>> # xdoctest: +SKIP("Nondeterministic")
>>> get_free_port()
63976
    .. note::
The port returned by :func:`get_free_port` is not reserved and may be
taken by another process after this function returns.
"""
sock = get_socket_with_port()
with closing(sock):
return sock.getsockname()[1]
def get_socket_with_port() -> socket.socket:
"""
    Returns a socket bound to a free port on localhost. The port is "reserved"
    only while the socket remains open, so close the socket before passing the
    port to the entity that requires it. Usage example
::
        sock = get_socket_with_port()
with closing(sock):
port = sock.getsockname()[1]
sock.close()
# there is still a race-condition that some other process
# may grab this port before func() runs
func(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
logger.warning("Socket creation attempt failed.", exc_info=e)
raise RuntimeError("Failed to create a socket")
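
A hedged single-host sketch of create_c10d_store (not part of the file); the address, timeout, and world size are illustrative, and wait_for_workers=False keeps it runnable with a single process.

port = get_free_port()  # note: the port is not reserved and could still be grabbed

store = create_c10d_store(
    is_server=True,
    server_addr="localhost",
    server_port=port,
    world_size=1,            # a real job would use the number of participating processes
    timeout=60,
    wait_for_workers=False,  # skip the full-rank barrier so one process suffices
)
store.set("status", "ready")
print(store.get("status"))  # b'ready'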

View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
def get_log_level() -> str:
"""
    Return the default log level for PyTorch.
"""
return "WARNING"

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import warnings
from typing import Optional
from torch.distributed.elastic.utils.log_level import get_log_level
def get_logger(name: Optional[str] = None):
"""
Util function to set up a simple logger that writes
    to stderr. The log level is fetched from the LOGLEVEL
    environment variable, defaulting to WARNING. The function will use the
module name of the caller if no name is provided.
Args:
name: Name of the logger. If no name provided, the name will
be derived from the call stack.
"""
# Derive the name of the caller, if none provided
# Use depth=2 since this function takes up one level in the call stack
return _setup_logger(name or _derive_module_name(depth=2))
def _setup_logger(name: Optional[str] = None):
logger = logging.getLogger(name)
logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
return logger
def _derive_module_name(depth: int = 1) -> Optional[str]:
"""
Derives the name of the caller module from the stack frames.
Args:
depth: The position of the frame in the stack.
"""
try:
stack = inspect.stack()
assert depth < len(stack)
# FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
frame_info = stack[depth]
module = inspect.getmodule(frame_info[0])
if module:
module_name = module.__name__
else:
# inspect.getmodule(frame_info[0]) does NOT work (returns None) in
# binaries built with @mode/opt
# return the filename (minus the .py extension) as modulename
filename = frame_info[1]
module_name = os.path.splitext(os.path.basename(filename))[0]
return module_name
except Exception as e:
warnings.warn(
f"Error deriving logger module name, using <None>. Exception: {e}",
RuntimeWarning,
)
return None
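
A short sketch of how get_logger is typically used (not part of the file); logging.basicConfig() is added only so the sketch has a handler, and the explicit logger name is made up.

import logging

logging.basicConfig()  # give the sketch a stderr handler; real programs configure their own

logger = get_logger()  # name derived from the calling module via _derive_module_name
logger.warning("training is starting")

named = get_logger("my_app.trainer")  # hypothetical explicit name
named.warning("this logger is named %s", named.name)
# The level threshold comes from the LOGLEVEL environment variable (default WARNING),
# e.g. LOGLEVEL=INFO python train.py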

View File

@@ -0,0 +1,225 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional
import torch
DistStoreError = torch._C._DistStoreError
_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"
_TRACE = "/TRACE"
_TRACING_GATE = "/TRACING_GATE"
_MAX_TRACE_MISSING_RANKS = 16
__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]
@contextmanager
def store_timeout(store, timeout: float):
"""
This sets the timeout and then restores the old timeout when the context
manager exits.
Args:
store: the store to set the timeout on
timeout: the timeout to set
"""
old_timeout = store.timeout
store.set_timeout(timedelta(seconds=timeout))
yield
store.set_timeout(old_timeout)
def get_all(store, rank: int, prefix: str, world_size: int):
r"""
Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where ``idx`` ranges
    from 0 to ``world_size - 1``, and tries to retrieve the data.
The Rank0 process waits at the end to make sure all other processes
finished the procedure before exiting.
Usage
::
        values = get_all(store, rank, 'torchelastic/data', 3)
value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2
"""
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
barrier_key = _barrier_nonblocking(
store=store,
world_size=world_size,
key_prefix=f"{prefix}/finished",
)
if rank == 0:
        # Rank0 runs the TCPStore daemon; as a result it needs to exit last.
        # Otherwise, the barrier may time out if the rank0 process finishes its work
        # before the other processes finish the `get_all` method
store.wait([barrier_key])
return data_arr
def synchronize(
store,
data: bytes,
rank: int,
world_size: int,
key_prefix: str,
timeout: float = 300,
) -> List[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents.
    Note: The data on the path is not deleted; as a result, there can be stale data if
you use the same key_prefix twice.
Time complexity: O(N) per worker, O(N^2) globally.
"""
with store_timeout(store, timeout):
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data
def _try_detecting_missing_ranks(
store,
world_size: int,
key_prefix: str,
rank: int,
rank_decoder: Callable[[int], str],
trace_timeout: float,
) -> Optional[Iterable[str]]:
store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")
def _find_missing_ranks():
missing_rank_info = set()
ranks_missing = 0
for i in range(1, world_size):
# reduce noise, assuming in general 8 ranks per node
# It is valuable to know that 1 or >1 nodes have timed-out.
if ranks_missing >= _MAX_TRACE_MISSING_RANKS:
break
try:
if ranks_missing == 0:
store.wait(
[f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout)
)
else:
                    # use the shortest timeout, since some ranks have already failed to check in
store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1))
except DistStoreError:
ranks_missing += 1
missing_rank_info.add(rank_decoder(i))
return missing_rank_info
def _checkin():
try:
store.wait([f"{key_prefix}{_TRACING_GATE}"])
return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"]
except DistStoreError:
# in case rank0 is the source of the timeout, original exception will be raised
return None
if rank == 0:
missing_rank_info = _find_missing_ranks()
store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>")
return missing_rank_info
else:
return _checkin()
def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
"""
Does all the non-blocking operations for a barrier and returns the final key
that can be waited on.
"""
num_members_key = key_prefix + _NUM_MEMBERS
last_member_key = key_prefix + _LAST_MEMBER_CHECKIN
idx = store.add(num_members_key, 1)
if idx == world_size:
store.set(last_member_key, "<val_ignored>")
return last_member_key
def barrier(
store,
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
rank: Optional[int] = None,
rank_tracing_decoder: Optional[Callable[[int], str]] = None,
trace_timeout: float = 10,
) -> None:
"""
A global lock between agents. This will pause all workers until at least
``world_size`` workers respond.
This uses a fast incrementing index to assign waiting ranks and a success
flag set by the last worker.
Time complexity: O(1) per worker, O(N) globally.
Optionally, passing rank will enable tracing of missing ranks on timeouts.
    The ``rank_tracing_decoder`` callable can be used to convert a rank into
    more meaningful information at the application level (e.g. a hostname).
    Note: Since the data is not removed from the store, the barrier can be used
    only once per unique ``key_prefix``.
"""
if rank is None:
assert rank_tracing_decoder is None, "Tracing requires rank information"
with store_timeout(store, barrier_timeout):
last_member_key = _barrier_nonblocking(
store=store, world_size=world_size, key_prefix=key_prefix
)
try:
store.wait([last_member_key])
except DistStoreError as e:
if rank is None:
raise e
else:
missing_ranks = _try_detecting_missing_ranks(
store,
world_size,
key_prefix,
rank,
rank_tracing_decoder or (lambda x: str(x)),
trace_timeout,
)
if missing_ranks is not None:
raise DistStoreError(
"Timed out waiting on barrier on "
"rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format(
rank,
key_prefix,
world_size,
f"[{', '.join(missing_ranks)}]",
barrier_timeout,
)
) from None
else:
raise e
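
A hedged single-process sketch (world_size=1) of synchronize and barrier on top of a TCPStore, assuming this module is importable as torch.distributed.elastic.utils.store as the imports elsewhere in this commit show; the port and key prefixes are arbitrary, and with more ranks every rank would make the same calls with its own rank value.

import torch.distributed as dist

store = dist.TCPStore("localhost", 29500, world_size=1, is_master=True)

# Each rank publishes its payload and gets back the payloads of all ranks.
payloads = synchronize(
    store, data=b"rank0-metadata", rank=0, world_size=1, key_prefix="demo/sync/"
)
print(payloads)  # [b'rank0-metadata']

# One-shot rendezvous: returns once all world_size ranks have checked in.
# Per the note above, a given key_prefix can only be used for one barrier.
barrier(store, world_size=1, key_prefix="demo/barrier/", barrier_timeout=30)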