@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Torchelastic agent and user worker failover contract:

**TL;DR:**

* TE (torchelastic) expects user workers to finish within a 5 minute drift of each other.
* It is better to design a DDP app to fail on all workers rather than on a single one.
* TE does not synchronize the number of restarts between agents.
* A TE re-rendezvous does not trigger a restart decrease.
* When a single agent finishes its job (successfully or not), it will close the rendezvous.
  If other agents still have workers in progress, they will be terminated.
* Based on the above, scale down does not work if at least a single agent finishes the job.
* When a scale up is detected by the agents, ``max_restarts`` will not decrease.

In general TE (torchelastic) can launch arbitrary user code, but some
clarification is needed around what failover mechanism torchelastic
provides and what failover mechanism it expects from user workers.

Torchelastic currently supports DDP-style applications. That means that
TE expects *ALL* workers to finish at approximately the same time. In practice,
it is nearly impossible to guarantee that all workers in an arbitrary
DDP application finish at the same time, so TE provides a finalization barrier
that waits TIMEOUT (5 minutes) for worker finalization.

**Worker Failure**

When a worker fails, TE will check the number of restarts
available; if more than 0 restarts remain, TE will start a new rendezvous
round and restart the worker process. The new rendezvous round will cause other
TE agents to terminate their workers.

.. note:: The TE agents do not synchronize restarts between themselves.
   When a single agent performs a restart, it will trigger a local ``max_restarts``
   decrease; other agents will not decrease their ``max_restarts``.

A single worker failure can cause the whole cluster to fail:
if a single worker is constantly failing, it will drive the TE agent's
``max_restarts`` to zero. This will cause the agent to finish its
work and close the rendezvous. If there are any other workers on different
agents, they will be terminated.

**Re-Rendezvous**

Re-rendezvous occurs when TE agents detect a new node
trying to join the cluster. TE will not decrease ``max_restarts``. The TE agents
will terminate their workers and start a new rendezvous round.

Note about DynamicRendezvous (etcd-v2, c10d-experimental): if the rendezvous
already has max_nodes, the new node won't be added to the wait list right
away, since there is no need to tear down a rendezvous that is already fully
utilized. The new node will wait until its timeout (600 secs by default)
and periodically check the number of participants. If the number becomes
less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs.

*Scale up event*. When a scale up event happens, the torchelastic rendezvous
will detect that there are new nodes trying to join. The torchelastic agent
will stop all workers and perform a re-rendezvous. Note: when a scale up event
happens, ``max_restarts`` will *not* decrease.

*Scale down event*. When a scale down event happens, the rendezvous will not
notify the torchelastic agent about it. If the TE agent was launched with ``max_restarts=0``,
it relies on the underlying scheduler to handle the job restart. If ``max_restarts>0``,
the TE agent will terminate the workers and start a new rdzv round, which is a *Scale up event*.
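As an illustrative sketch only (this is not part of the contract above; it
assumes the ``torch.distributed.launcher.api`` entry points), a job that
tolerates up to three restarts per agent could be launched as::

    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

    def trainer():
        # a regular DDP training loop; ideally fails fast on any rank error
        ...

    config = LaunchConfig(
        min_nodes=1,
        max_nodes=4,
        nproc_per_node=8,
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29400",
        max_restarts=3,  # local to each agent; not synchronized across agents
        monitor_interval=0.1,
    )
    elastic_launch(config, trainer)()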
"""

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
The elastic agent is the control plane of torchelastic.
It is a process that launches and manages underlying worker processes.
The agent is responsible for:
1. Working with distributed torch: the workers are started with all the
necessary information to successfully and trivially call
``torch.distributed.init_process_group()``.
2. Fault tolerance: monitors workers and upon detecting worker failures
or unhealthiness, tears down all workers and restarts everyone.
3. Elasticity: Reacts to membership changes and restarts workers with the new
members.
The simplest agent is deployed per node and works with local processes.
A more advanced agent can launch and manage workers remotely. Agents can
be completely decentralized, making decisions based on the workers they manage,
or coordinated, communicating with other agents (that manage workers
in the same job) to make a collective decision.
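For example, a worker launched by a local agent can trivially initialize its
process group from the environment the agent populates (illustrative sketch)::

    import torch.distributed as dist

    # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set by the agent
    dist.init_process_group(backend="gloo")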
"""
from .api import ( # noqa: F401
ElasticAgent,
RunResult,
SimpleElasticAgent,
Worker,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE

@@ -0,0 +1,942 @@
# mypy: ignore-errors
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import json
import os
import signal
import socket
import time
import traceback
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.events import Event, EventSource, record
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException
from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
"WorkerSpec",
"Worker",
"WorkerState",
"WorkerGroup",
"RunResult",
"ElasticAgent",
"SimpleElasticAgent",
]
_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"
DEFAULT_ROLE = "default"
logger = get_logger(__name__)
@dataclass
class WorkerSpec:
"""Blueprint information about a particular type of worker.
For a given role, there must only exist a single worker spec.
A worker spec is expected to be homogeneous across all nodes (machines),
that is, each node runs the same number of workers for a particular spec.

Args:
role: user-defined role for the workers with this spec
local_world_size: number of local workers to run
fn: (deprecated, use ``entrypoint`` instead)
entrypoint: worker function or command
args: arguments to pass to ``entrypoint``
rdzv_handler: handles rdzv for this set of workers
max_restarts: max number of retries for the workers
monitor_interval: monitor the status of workers every ``n`` seconds
master_port: fixed port to run the c10d store on rank 0;
if not specified, a random free port will be chosen
master_addr: fixed master_addr to run the c10d store on rank 0;
if not specified, the hostname of the rank 0 agent will be used
redirects: redirect std streams to a file,
selectively redirect for a particular
local rank by passing a map
tee: tees the specified std stream(s) to console + file;
selectively tee for a particular local rank by passing a map;
takes precedence over ``redirects`` settings.
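Example (illustrative sketch; ``my_trainer`` and ``rdzv_handler`` are assumed
to be defined elsewhere)::

    spec = WorkerSpec(
        role="trainer",
        local_world_size=4,
        entrypoint=my_trainer,
        args=("foobar",),
        rdzv_handler=rdzv_handler,
        max_restarts=3,
        monitor_interval=0.1,
    )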
"""
role: str
local_world_size: int
rdzv_handler: rdzv.RendezvousHandler
fn: Optional[Callable] = None
# TODO @kiuk - make entrypoint a required field
entrypoint: Union[Callable, str, None] = None
args: Tuple = ()
max_restarts: int = 3
monitor_interval: float = 0.1
master_port: Optional[int] = None
master_addr: Optional[str] = None
local_addr: Optional[str] = None
def __post_init__(self):
assert self.local_world_size > 0
assert self.monitor_interval > 0
if self.fn:
warnings.warn(
"WorkerSpec.fn will be deprecated,"
" please use WorkerSpec.entrypoint instead",
category=DeprecationWarning,
)
self.entrypoint = self.fn
assert self.entrypoint
def get_entrypoint_name(self):
"""Get the entry point name.
If the entrypoint is a function (i.e. a ``Callable``), returns its ``__qualname__``;
if the entrypoint is a binary (i.e. a ``str``), returns the binary name.
"""
if isinstance(self.entrypoint, str):
return os.path.basename(self.entrypoint)
else:
assert self.entrypoint is not None
return self.entrypoint.__qualname__
class Worker:
"""A worker instance.
Contrast this with ``WorkerSpec`` that represents the specifications of a
worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to
a ``WorkerSpec`` as an object is to a class.
The ``id`` of the worker is interpreted
by the specific implementation of ``ElasticAgent``. For a local
agent, it could be the ``pid (int)`` of the worker, for a remote
agent it could be encoded as ``host:port (string)``.
Args:
id (Any): uniquely identifies a worker (interpreted by the agent)
local_rank (int): local rank of the worker
global_rank (int): global rank of the worker
role_rank (int): rank of the worker across all workers that have the same role
world_size (int): number of workers (globally)
role_world_size (int): number of workers that have the same role
"""
__slots__ = [
"id",
"local_rank",
"global_rank",
"role_rank",
"world_size",
"role_world_size",
]
def __init__(
self,
local_rank: int,
global_rank: int = -1,
role_rank: int = -1,
world_size: int = -1,
role_world_size: int = -1,
):
# unique identifier for this worker
self.id: Any = None
# rank of the worker among workers with the same role being monitored
# by the same ``agent`` instance.
self.local_rank: int = local_rank
# rank of the worker among all the workers across all roles
# across all ``agent`` instances.
# Global rank is not stable between re-rendezvous.
self.global_rank: int = global_rank
# rank of the worker among all the workers with the same role
# across all ``agent`` instances.
# Role rank is not stable between re-rendezvous.
self.role_rank: int = role_rank
# total number of workers (globally). Due to elasticity
# the world size may change between re-rendezvous.
self.world_size: int = world_size
# total number of workers that share the same role. Due to elasticity
# the role world size may change between re-rendezvous.
self.role_world_size: int = role_world_size
def __str__(self):
return (
f"local_rank={self.local_rank},global_rank={self.global_rank}"
f",role_rank={self.role_rank},world_size={self.world_size}"
f",role_world_size={self.role_world_size}"
)
def __repr__(self):
return str(self)
class WorkerState(str, Enum):
"""A state of the ``WorkerGroup``.
Workers in a worker group change state as a unit. If a single worker
in a worker group fails, the entire set is considered failed::
UNKNOWN - agent lost track of worker group state, unrecoverable
INIT - worker group object created not yet started
HEALTHY - workers running and healthy
UNHEALTHY - workers running and unhealthy
STOPPED - workers stopped (interrupted) by the agent
SUCCEEDED - workers finished running (exit 0)
FAILED - workers failed to successfully finish (exit !0)
A worker group starts from an initial ``INIT`` state,
then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.
Worker groups can be interrupted and temporarily put into ``STOPPED`` state
by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
in the near future by the agent. Some examples of workers being put into
``STOPPED`` state are:
1. Worker group failure|unhealthy observed
2. Membership change detected
When an action (start, stop, rdzv, retry, etc.) on a worker group fails
and results in the action being partially applied to the worker group,
the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
exceptions during state-change events on the agent. The agent is not
expected to recover worker groups in ``UNKNOWN`` state and is better off
self-terminating and allowing the job manager to retry the node.
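For example (sketch)::

    if WorkerState.is_running(worker_group.state):
        # the worker processes exist, but may be healthy or unhealthy;
        # keep monitoring
        ...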
"""
UNKNOWN = "UNKNOWN"
INIT = "INIT"
HEALTHY = "HEALTHY"
UNHEALTHY = "UNHEALTHY"
STOPPED = "STOPPED"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
@staticmethod
def is_running(state: "WorkerState") -> bool:
"""Return the state of the Worker.
Returns:
True if the worker state represents workers still running
(e.g. that the process exists but not necessarily healthy).
"""
return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
class WorkerGroup:
"""A set of ``Worker`` instances.
The class defines a set of ``Worker`` instances for the given ``WorkerSpec``,
managed by an ``ElasticAgent``. Whether the worker group contains
cross-instance workers depends on the implementation of the agent.
"""
__slots__ = [
"spec",
"workers",
"store",
"group_rank",
"group_world_size",
"state",
"master_addr",
"master_port",
]
def __init__(self, spec: WorkerSpec):
self.spec = spec
self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
# assigned after rdzv
self.store = None
self.group_rank = None
self.group_world_size = None
self.master_addr = None
self.master_port = None
self.state = WorkerState.INIT
class _RoleInstanceInfo:
"""The class is used by the agent to exchange the information with other agents.
The information is used to determine the rank of the workers that agent
manages in heterogeneous environments, where different agents can have
different number of workers.
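Example (sketch)::

    info = _RoleInstanceInfo("trainer", rank=1, local_world_size=8)
    data = info.serialize()                     # JSON-encoded bytes for the store
    same = _RoleInstanceInfo.deserialize(data)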
"""
__slots__ = ["role", "rank", "local_world_size"]
def __init__(self, role: str, rank: int, local_world_size: int):
r"""Initialize the agent class instance.
Args:
role (str): user-defined role for the workers with this spec
rank (int): the rank of the agent
local_world_size (int): number of local workers to run
"""
self.role = role
self.rank = rank
self.local_world_size = local_world_size
def serialize(self) -> bytes:
dict_data = {
"role": self.role,
"rank": self.rank,
"local_world_size": self.local_world_size,
}
return json.dumps(dict_data).encode(encoding="UTF-8")
@staticmethod
def deserialize(data: bytes):
dict_data = json.loads(data.decode(encoding="UTF-8"))
return _RoleInstanceInfo(
dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
)
@staticmethod
def compare(obj1, obj2) -> int:
if obj1.role == obj2.role:
return obj1.rank - obj2.rank
elif obj1.role > obj2.role:
return 1
else:
return -1
@staticmethod
def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
start_idx, end_idx = -1, -1
for idx, role_info in enumerate(roles_infos):
if role_info.role == role:
if start_idx == -1:
start_idx = idx
end_idx = idx
return (start_idx, end_idx)
@dataclass
class RunResult:
"""Return results of the worker executions.
Run results follow an "all-or-nothing" policy where the run is successful if and
only if ALL local workers managed by this agent complete successfully.
If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
field contains the outputs (return values) of the workers managed by THIS agent mapped
by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
global rank 0.
.. note:: ``return_values`` are only meaningful when the worker entrypoint
is a function. Workers specified as a binary entrypoint do not canonically
have a return value and the ``return_values`` field is meaningless and
may be empty.
If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
failure information, again, mapped by the GLOBAL rank of the worker that failed.
The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
a worker's final state can only be one of: succeeded, failed. Workers intentionally
terminated by the agent according to the agent's restart policy are not represented
in either ``return_values`` or ``failures``.
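Example (sketch)::

    result = agent.run()
    if result.is_failed():
        for global_rank, failure in result.failures.items():
            print(f"rank {global_rank} failed with exit code {failure.exit_code}")
    else:
        print(result.return_values)  # keyed by GLOBAL rank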
"""
state: WorkerState
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
def is_failed(self) -> bool:
return self.state == WorkerState.FAILED
def _get_fq_hostname() -> str:
return socket.getfqdn(socket.gethostname())
class ElasticAgent(abc.ABC):
"""An agent process responsible for managing one or more worker processes.
The worker processes are assumed to be regular distributed PyTorch scripts.
When the worker process is created by the agent, the agent provides the
necessary information for the worker processes to properly initialize
a torch process group.
The exact deployment topology and ratio of agent-to-worker is dependent
on the specific implementation of the agent and the user's job placement
preferences. For instance, to run a distributed training job on GPU with
8 trainers (one per GPU) one can:
1. Use 8 x single GPU instances, place an agent per instance, managing
1 worker per agent.
2. Use 4 x double GPU instances, place an agent per instance, managing
2 workers per agent.
3. Use 2 x quad GPU instances, place an agent per instance, managing
4 workers per agent.
4. Use 1 x 8 GPU instance, place an agent per instance, managing
8 workers per agent.
Usage
::
group_result = agent.run()
if group_result.is_failed():
# workers failed
failure = group_result.failures[0]
logger.exception("worker 0 failed with exit code : %s", failure.exit_code)
else:
return group_result.return_values[0] # return rank 0's results
"""
@abc.abstractmethod
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
"""Run the agent.
Supports retrying the worker group on failures up to ``max_restarts``.
Returns:
The result of the execution, containing the return values or
failure details for each worker mapped by the worker's global rank.
Raises:
Exception - any other failure NOT related to the worker process
"""
raise NotImplementedError
@abc.abstractmethod
def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
"""Return the ``WorkerGroup`` for the given ``role``.
Note that the worker group is a mutable object and hence in a
multi-threaded/process environment it may change state.
Implementors are encouraged (but not required) to return
a defensive read-only copy.
"""
raise NotImplementedError
class SimpleElasticAgent(ElasticAgent):
"""An ``ElasticAgent`` that manages one particular type of worker role.
An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec``
such as one particular type of worker role.
"""
def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
self._worker_group = WorkerGroup(spec)
self._remaining_restarts = self._worker_group.spec.max_restarts
self._store = None
self._exit_barrier_timeout = exit_barrier_timeout
self._total_execution_time = 0
def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
return self._worker_group
@abc.abstractmethod
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
r"""Start ``worker_group.spec.local_world_size`` number of workers.
This is done according to the worker spec for the worker group.
Returns a map of ``local_rank`` to worker ``id``.
"""
raise NotImplementedError
@abc.abstractmethod
def _stop_workers(
self, worker_group: WorkerGroup, is_restart: bool = False
) -> None:
r"""Stop all workers in the given worker group.
Implementors must deal with workers in all states defined by
``WorkerState``. That is, they must gracefully handle stopping
non-existent workers, unhealthy (stuck) workers, etc.
"""
raise NotImplementedError
@abc.abstractmethod
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
r"""Check on the workers for the ``worker_group``.
This function also returns the new state of the worker group.
"""
raise NotImplementedError
@abc.abstractmethod
def _shutdown(
self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
) -> None:
"""Clean up any resources that were allocated during the agent's work.
Args:
death_sig: Signal to send to the child process, SIGTERM is default
"""
raise NotImplementedError
@prof
def _rendezvous(self, worker_group: WorkerGroup) -> None:
r"""Run rendezvous for the workers specified by the worker spec.
Assigns workers a new global rank and world size.
Updates the rendezvous store for the worker group.
"""
spec = worker_group.spec
with self.record_duration("RENDEZVOUS"):
rdzv_info = spec.rdzv_handler.next_rendezvous()
store = rdzv_info.store
group_rank = rdzv_info.rank
group_world_size = rdzv_info.world_size
# master_addr/master_port could be explicitly overridden
# TODO: BC - specific to static rdzv and can be simplified further
master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr
master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port
self._store = store
with self.record_duration("ASSIGN_WORKER_RANKS"):
workers = self._assign_worker_ranks(
store, group_rank, group_world_size, spec
)
worker_group.workers = workers
worker_group.store = store
worker_group.group_rank = group_rank
worker_group.group_world_size = group_world_size
worker_group.master_addr = master_addr
worker_group.master_port = master_port
restart_count = spec.max_restarts - self._remaining_restarts
logger.info(
"[%(role)s] Rendezvous complete for workers. Result:\n"
" restart_count=%(restart_count)s\n"
" master_addr=%(master_addr)s\n"
" master_port=%(master_port)s\n"
" group_rank=%(group_rank)s\n"
" group_world_size=%(group_world_size)s\n"
" local_ranks=%(local_ranks)s\n"
" role_ranks=%(role_ranks)s\n"
" global_ranks=%(global_ranks)s\n"
" role_world_sizes=%(role_world_sizes)s\n"
" global_world_sizes=%(global_world_sizes)s\n",
{
"role": spec.role,
"restart_count": restart_count,
"master_addr": master_addr,
"master_port": master_port,
"group_rank": group_rank,
"group_world_size": group_world_size,
"local_ranks": [worker.local_rank for worker in workers],
"role_ranks": [worker.role_rank for worker in workers],
"global_ranks": [worker.global_rank for worker in workers],
"role_world_sizes": [worker.role_world_size for worker in workers],
"global_world_sizes": [worker.world_size for worker in workers],
},
)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _assign_worker_ranks(
self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]:
"""Determine proper ranks for worker processes.
The rank assignment is done according to the following algorithm:

1. Each agent writes its configuration (group_rank, group_world_size,
num_workers) to the common store.
2. The rank 0 agent reads all the role_info from the store and
determines each agent's worker ranks.
3. Determine the global rank: the global rank of a worker is computed
as the cumulative sum of the local_world_size of all agents in front of it.
For efficiency reasons, each agent is assigned a base global rank
such that its workers are in the range [base_global_rank,
base_global_rank + local_world_size).
4. Determine the role rank: the role rank is determined using the algorithm
in point 3, with the exception that the ranks are calculated with
respect to the role name.
5. The rank 0 agent writes the assigned ranks to the store.
6. Each agent reads the assigned ranks from the store.

Time complexity: each worker O(1), rank0 O(n), overall O(n)
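For example (illustrative): with two "trainer" agents registered as group ranks
0 and 1, each with ``local_world_size=2``, agent 0 is assigned base global rank 0
(its workers get global ranks 0 and 1) and agent 1 is assigned base global rank 2
(its workers get global ranks 2 and 3).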
"""
ROLE_INFO_PREFIX = "torchelastic/role_info/"
ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"
agent_role_info = _RoleInstanceInfo(
spec.role, group_rank, spec.local_world_size
)
store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
# tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations.
if group_rank == 0:
role_infos_bytes = store.multi_get(
[f"torchelastic/role_info/{i}" for i in range(group_world_size)]
)
role_infos = [
_RoleInstanceInfo.deserialize(info_bytes)
for info_bytes in role_infos_bytes
]
role_sizes = defaultdict(lambda: 0)
global_size = 0
for role_info in role_infos:
role_sizes[role_info.role] += role_info.local_world_size
global_size += role_info.local_world_size
base_global_rank = 0
role_ranks = defaultdict(lambda: 0)
keys = []
values = []
for i, role_info in enumerate(role_infos):
keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
values.append(
json.dumps(
[
base_global_rank,
global_size,
role_ranks[role_info.role],
role_sizes[role_info.role],
]
)
)
base_global_rank += role_info.local_world_size
role_ranks[role_info.role] += role_info.local_world_size
store.multi_set(keys, values)
# get will block until the data is available in the store.
(
base_global_rank,
global_world_size,
base_role_rank,
role_world_size,
) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}"))
workers = []
for local_rank in range(spec.local_world_size):
worker = Worker(
local_rank=local_rank,
global_rank=base_global_rank + local_rank,
role_rank=base_role_rank + local_rank,
world_size=global_world_size,
role_world_size=role_world_size,
)
workers.append(worker)
return workers
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
r"""Start a fresh set of workers for the worker_group.
Essentially, a rendezvous followed by a ``start_workers``.
The caller should first call ``_stop_workers()`` to stop running workers
prior to calling this method.
Optimistically sets the state of the worker group that
just started as ``HEALTHY`` and delegates the actual monitoring
of the state to the ``_monitor_workers()`` method.
"""
role = worker_group.spec.role
logger.info("[%s] Rendezvous'ing worker group", role)
# TODO after stopping workers, wait at least monitor_interval*2 for
# workers on different nodes to fail on a collective op before waiting
# on the rdzv barrier, this way we ensure that nodes enter rdzv
# at around the same time and reduce false positive rdzv timeout errors
self._rendezvous(worker_group)
logger.info("[%s] Starting worker group", role)
worker_ids = self._start_workers(worker_group)
for local_rank, w_id in worker_ids.items():
worker = worker_group.workers[local_rank]
worker.id = w_id
worker_group.state = WorkerState.HEALTHY
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
"""Restart (stops, rendezvous, starts) all local workers in the group."""
role = worker_group.spec.role
logger.info("[%s] Stopping worker group", role)
self._stop_workers(worker_group, is_restart=True)
worker_group.state = WorkerState.STOPPED
self._initialize_workers(worker_group)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = time.monotonic()
shutdown_called: bool = False
try:
result = self._invoke_run(role)
self._total_execution_time = int(time.monotonic() - start_time)
self._record_metrics(result)
self._record_worker_events(result)
return result
except RendezvousGracefulExitError as e:
logger.info("Rendezvous gracefully exited: %s", e)
except SignalException as e:
logger.warning("Received %s death signal, shutting down workers", e.sigval)
self._shutdown(e.sigval)
shutdown_called = True
raise
finally:
if not shutdown_called:
self._shutdown()
# record the execution time in case there were any exceptions during run.
self._total_execution_time = int(time.monotonic() - start_time)
def get_event_failed(self) -> Event:
return self._construct_event(
state="FAILED",
source=EventSource.AGENT,
raw_error=traceback.format_exc(),
)
def get_event_succeeded(self) -> Event:
return self._construct_event(
state="SUCCEEDED",
source=EventSource.AGENT,
)
def _record_worker_events(self, result: RunResult) -> None:
for worker in self._worker_group.workers:
failure = result.failures.get(worker.global_rank)
state: str = self._get_worker_state(worker, result)
raw_error = json.dumps(failure.error_file_data) if failure else None
record(self._construct_event(state, EventSource.WORKER, worker, raw_error))
def _get_worker_state(self, worker: Worker, result: RunResult) -> str:
failure = result.failures.get(worker.global_rank)
if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure:
# The worker got terminated by the torchelastic agent via SIGTERM signal
return "TERMINATED"
elif failure or worker.global_rank in result.return_values:
return result.state.value
else:
raise ValueError(f"Unknown worker: {worker.global_rank}")
@contextmanager
def record_duration(self, state: str):
start_time = time.perf_counter()
try:
yield
finally:
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000
record(
self._construct_event(
state=state, source=EventSource.AGENT, duration_ms=duration_ms
)
)
def _construct_event(
self,
state: str,
source: EventSource,
worker: Optional[Worker] = None,
raw_error: Optional[str] = None,
duration_ms: Optional[float] = None,
) -> Event:
wg = self._worker_group
spec = wg.spec
md = {
"group_world_size": wg.group_world_size,
"entry_point": spec.get_entrypoint_name(),
}
if worker:
md["local_rank"] = (worker.local_rank,)
md["role_rank"] = (worker.role_rank,)
md["role_world_size"] = (worker.role_world_size,)
global_rank = worker.global_rank
worker_id = str(worker.id)
else:
global_rank = None
worker_id = None
md_str = json.dumps(md)
metadata = {
"run_id": spec.rdzv_handler.get_run_id(),
"global_rank": global_rank,
"group_rank": wg.group_rank,
"worker_id": worker_id,
"role": spec.role,
"hostname": _get_fq_hostname(),
"state": state,
"total_run_time": self._total_execution_time,
"rdzv_backend": spec.rdzv_handler.get_backend(),
"raw_error": raw_error,
"metadata": md_str,
"agent_restarts": spec.max_restarts - self._remaining_restarts,
"duration_ms": duration_ms,
}
return Event(
f"torchelastic.worker.status.{state}", source=source, metadata=metadata
)
def _record_metrics(self, group_results: RunResult):
is_failed = group_results.is_failed()
self._record_flakiness_metric(is_failed)
spec = self._worker_group.spec
restarts_happened = self._remaining_restarts != spec.max_restarts
put_metric(f"workers.{spec.role}.run_total", 1)
self._record_metric_with_condition(
"run_success_with_retries", not is_failed and restarts_happened
)
self._record_metric_with_condition(
"run_success_no_retries", not is_failed and not restarts_happened
)
self._record_metric_with_condition(
"run_failed_with_retries", is_failed and restarts_happened
)
self._record_metric_with_condition(
"run_failed_no_retries", is_failed and not restarts_happened
)
def _record_metric_with_condition(self, metric_name, condition):
spec = self._worker_group.spec
if condition:
put_metric(f"workers.{spec.role}.{metric_name}", 1)
else:
put_metric(f"workers.{spec.role}.{metric_name}", 0)
def _record_flakiness_metric(self, is_failed: bool = False):
if is_failed:
flakiness = 100.0
else:
spec = self._worker_group.spec
flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
spec.max_restarts + 1
)
spec = self._worker_group.spec
put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
# NOTE: currently only works for a single role
spec = self._worker_group.spec
role = spec.role
logger.info(
"[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
)
self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state
put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
logger.info(
"[%s] worker group successfully finished."
" Waiting %s seconds for other agents to finish.",
role,
self._exit_barrier_timeout,
)
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
if self._remaining_restarts > 0:
logger.info(
"[%s] Worker group %s. "
"%s/%s attempts left;"
" will restart worker group",
role,
state.name,
self._remaining_restarts,
spec.max_restarts,
)
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
logger.info(
"[%s] Detected %s "
"new nodes from group_rank=%s; "
"will restart worker group",
role,
num_nodes_waiting,
group_rank,
)
self._restart_workers(self._worker_group)
else:
raise Exception( # noqa: TRY002
f"[{role}] Worker group in {state.name} state"
)
def _exit_barrier(self):
"""
Define a barrier that keeps the agent process alive until all workers finish.
Wait for ``exit_barrier_timeout`` seconds for all agents to finish
executing their local workers (either successfully or not). This
acts as a safety guard against user scripts that terminate at different
times.
"""
logger.info(
"Local worker group finished (%s). "
"Waiting %s seconds for other agents to finish",
self._worker_group.state,
self._exit_barrier_timeout,
)
start = time.time()
try:
store_util.barrier(
store=self._store,
world_size=self._worker_group.group_world_size,
key_prefix=_TERMINAL_STATE_SYNC_ID,
barrier_timeout=self._exit_barrier_timeout,
)
logger.info(
"Done waiting for other agents. Elapsed: %s seconds",
time.time() - start,
)
except SignalException as e:
logger.warning("Got termination signal: %s", e.sigval)
raise
except Exception:
logger.exception(
"Error waiting on exit barrier. Elapsed: %s seconds",
time.time() - start,
)

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
from torch.distributed.elastic.utils.logging import get_logger
log = get_logger(__name__)
__all__ = ["HealthCheckServer", "create_healthcheck_server"]
class HealthCheckServer:
"""
Interface for a health check monitoring server, which can be extended
by starting a tcp/http server on the specified port.

Args:
alive_callback: Callable[[], int], callback returning the agent's last progress time
port: int, port number on which to start the tcp/http server
timeout: int, timeout in seconds used to decide whether the agent is alive/dead
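A tcp/http implementation can be plugged in by subclassing (illustrative sketch;
the actual server implementation is left to the extension)::

    class MyHealthCheckServer(HealthCheckServer):
        def start(self) -> None:
            # serve liveness on self._port: report alive while
            # now - self._alive_callback() < self._timeout
            ...

        def stop(self) -> None:
            # tear down the server started in start()
            ...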
"""
_alive_callback: Callable[[], int]
_port: int
_timeout: int
def __init__(
self, alive_callback: Callable[[], int], port: int, timeout: int
) -> None:
self._alive_callback = alive_callback
self._port = port
self._timeout = timeout
def start(self) -> None:
"""
Unsupported functionality in PyTorch; does not start any health check server.
"""
log.warning("No health check server started")
def stop(self) -> None:
"""
Function to stop health check server
"""
log.info("Stopping noop health check server.")
def create_healthcheck_server(
alive_callback: Callable[[], int],
port: int,
timeout: int,
) -> HealthCheckServer:
"""
Create a (noop) health check server object.
"""
return HealthCheckServer(alive_callback, port, timeout)

@@ -0,0 +1,410 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import signal
import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
RunResult,
SimpleElasticAgent,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed.elastic.agent.server.health_check_server import (
create_healthcheck_server,
HealthCheckServer,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import (
LogsSpecs,
PContext,
start_processes,
)
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
if TYPE_CHECKING:
from torch.distributed.elastic.events.api import EventMetadataValue
logger = get_logger(__name__)
__all__ = [
"LocalElasticAgent",
"TORCHELASTIC_ENABLE_FILE_TIMER",
"TORCHELASTIC_TIMER_FILE",
"TORCHELASTIC_HEALTH_CHECK_PORT",
]
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
class LocalElasticAgent(SimpleElasticAgent):
"""An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
This agent is deployed per host and is configured to spawn ``n`` workers.
When using GPUs, ``n`` maps to the number of GPUs available on the host.
The local agent does not communicate to other local agents deployed on
other hosts, even if the workers may communicate inter-host. The worker id
is interpreted to be a local process. The agent starts and stops all worker
processes as a single unit.
The worker function and argument passed to the worker function must be
python multiprocessing compatible. To pass multiprocessing data structures
to the workers you may create the data structure in the same multiprocessing
context as the specified ``start_method`` and pass it as a function argument.
The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
for other agents to finish. This acts as a safety net to handle cases where
workers finish at different times, to prevent agents from viewing workers
that finished early as a scale-down event. It is strongly advised that the
user code ensure that workers are terminated in a synchronous manner
rather than relying on the ``exit_barrier_timeout``.
A named pipe based watchdog can be enabled in ``LocalElasticAgent`` if an
environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has
been defined in the ``LocalElasticAgent`` process.
Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE``
can be set with a unique file name for the named pipe. If the environment
variable ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent``
will internally create a unique file name and set it to the environment
variable ``TORCHELASTIC_TIMER_FILE``, and this environment variable will
be propagated to the worker processes to allow them to connect to the same
named pipe that ``LocalElasticAgent`` uses.
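For example, the watchdog can be enabled before constructing the agent
(illustrative sketch; the pipe path is arbitrary)::

    os.environ[TORCHELASTIC_ENABLE_FILE_TIMER] = "1"
    # optional; a unique name is generated when unset
    os.environ[TORCHELASTIC_TIMER_FILE] = "/tmp/my_job_watchdog"
    agent = LocalElasticAgent(spec, logs_specs=logs_specs)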
Logs are written to the specified log directory. Each log line will be by default
prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
Log prefixes can be customized by passing a `template string
<https://docs.python.org/3/library/string.html#template-strings>`_ as the
``log_line_prefix_template`` argument.
The following macros (identifiers) are substituted at runtime:
``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with
the global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:"``.
Example launching function
::
def trainer(args) -> str:
return "do train"
def main():
start_method="spawn"
shared_queue= multiprocessing.get_context(start_method).Queue()
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint=trainer,
args=("foobar",),
...<OTHER_PARAMS...>)
agent = LocalElasticAgent(spec, start_method)
results = agent.run()
if results.is_failed():
print("trainer failed")
else:
print(f"rank 0 return value: {results.return_values[0]}")
# prints -> rank 0 return value: do train
Example launching binary
::
def main():
spec = WorkerSpec(
role="trainer",
local_world_size=nproc_per_process,
entrypoint="/usr/local/bin/trainer",
args=("--trainer-args", "foobar"),
...<OTHER_PARAMS...>)
agent = LocalElasticAgent(spec)
results = agent.run()
if not results.is_failed():
print("binary launches do not have return values")
"""
def __init__(
self,
spec: WorkerSpec,
logs_specs: LogsSpecs,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_line_prefix_template: Optional[str] = None,
):
super().__init__(spec, exit_barrier_timeout)
self._start_method = start_method
self._pcontext: Optional[PContext] = None
self._rdzv_handler = spec.rdzv_handler
self._log_line_prefix_template = log_line_prefix_template
self._worker_watchdog: Optional[timer.FileTimerServer] = None
self._logs_specs = logs_specs
self._health_check_server: Optional[HealthCheckServer] = None
def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
watchdog_enabled = os.getenv(enable_watchdog_env_name)
watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
watchdog_file_path = os.getenv(watchdog_file_env_name)
if watchdog_enabled is not None and str(watchdog_enabled) == "1":
if watchdog_file_path is None:
watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
if not envs:
logger.warning(
"Empty envs variables, using empty run_id for FileTimerServer"
)
run_id = ""
else:
run_id = envs[0]["TORCHELASTIC_RUN_ID"]
self._worker_watchdog = timer.FileTimerServer(
file_path=watchdog_file_path,
run_id=run_id,
max_interval=0.1,
daemon=True,
log_event=self._log_watchdog_event,
)
self._worker_watchdog.start()
logger.info("FileTimerServer started")
else:
logger.info(
"Environment variable '%s' not found. Do not start FileTimerServer.",
enable_watchdog_env_name,
)
# Propagate the watchdog file env to worker processes
if watchdog_file_path is not None:
for worker_env in envs.values():
worker_env[watchdog_file_env_name] = watchdog_file_path
@staticmethod
def _get_current_time_secs() -> int:
return int(time.time())
def _setup_healthcheck(self) -> None:
healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
healthcheck_port = os.getenv(healthcheck_port_env_name)
if healthcheck_port is not None:
logger.info(
"Found healthcheck port %s: %s",
healthcheck_port_env_name,
healthcheck_port,
)
if self._worker_watchdog is None:
logger.info(
"FileTimerServer doesn't exist, using current time as dummy callback"
)
alive_callback = LocalElasticAgent._get_current_time_secs
else:
alive_callback = self._worker_watchdog.get_last_progress_time
self._health_check_server = create_healthcheck_server(
alive_callback=alive_callback,
port=int(healthcheck_port),
timeout=60,
)
self._health_check_server.start()
else:
logger.info(
"Environment variable '%s' not found. Do not start health check.",
healthcheck_port_env_name,
)
def _get_fq_hostname(self) -> str:
return socket.getfqdn(socket.gethostname())
def _log_watchdog_event(
self,
name: str,
request: Optional[timer.FileTimerRequest],
) -> None:
wg = self._worker_group
spec = wg.spec
md = {"watchdog_event": name}
if request is not None:
md["worker_pid"] = str(request.worker_pid)
md["scope_id"] = request.scope_id
md["expiration_time"] = str(request.expiration_time)
md["signal"] = str(request.signal)
md_str = json.dumps(md)
state = "RUNNING"
metadata: Dict[str, EventMetadataValue] = {
"run_id": spec.rdzv_handler.get_run_id(),
"global_rank": None,
"group_rank": wg.group_rank,
"worker_id": None,
"role": spec.role,
"hostname": self._get_fq_hostname(),
"state": state,
"total_run_time": self._total_execution_time,
"rdzv_backend": spec.rdzv_handler.get_backend(),
"raw_error": None,
"metadata": md_str,
"agent_restarts": spec.max_restarts - self._remaining_restarts,
}
# Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
# The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
event = events.Event(
name=name, source=events.EventSource.AGENT, metadata=metadata
)
events.record(event)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _stop_workers(
self, worker_group: WorkerGroup, is_restart: bool = False
) -> None:
self._shutdown(is_restart=is_restart)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
restart_count = spec.max_restarts - self._remaining_restarts
use_agent_store: bool = spec.rdzv_handler.use_agent_store
logger.info("use_agent_store: %s", use_agent_store)
args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
log_line_prefixes: Optional[Dict[int, str]] = (
{} if self._log_line_prefix_template else None
)
for worker in worker_group.workers:
local_rank = worker.local_rank
worker_env = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": worker_group.master_addr,
"MASTER_PORT": str(worker_group.master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
),
}
if "OMP_NUM_THREADS" in os.environ:
worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
if self._log_line_prefix_template:
log_line_prefix = Template(
self._log_line_prefix_template
).safe_substitute(
role_name=spec.role,
rank=worker.global_rank,
local_rank=local_rank,
)
log_line_prefixes[local_rank] = log_line_prefix
envs[local_rank] = worker_env
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)
self._setup_local_watchdog(envs=envs)
self._setup_healthcheck()
assert spec.entrypoint is not None
assert self._logs_specs is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
logs_specs=self._logs_specs,
log_line_prefixes=log_line_prefixes,
start_method=self._start_method,
)
return self._pcontext.pids()
def _shutdown(
self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
) -> None:
if self._worker_watchdog is not None:
self._worker_watchdog.stop()
self._worker_watchdog = None
if self._health_check_server is not None:
self._health_check_server.stop()
self._health_check_server = None
if self._pcontext:
self._pcontext.close(death_sig)
if not is_restart and self._rdzv_handler:
self._rdzv_handler.shutdown()
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
role = worker_group.spec.role
worker_pids = {w.id for w in worker_group.workers}
assert self._pcontext is not None
pc_pids = set(self._pcontext.pids().values())
if worker_pids != pc_pids:
logger.error(
"[%s] worker pids do not match process_context pids."
" Expected: %s, actual: %s",
role,
worker_pids,
pc_pids,
)
return RunResult(state=WorkerState.UNKNOWN)
result = self._pcontext.wait(0)
if result:
if result.is_failed():
# map local rank failure to global rank
worker_failures = {}
for local_rank, failure in result.failures.items():
worker = worker_group.workers[local_rank]
worker_failures[worker.global_rank] = failure
return RunResult(
state=WorkerState.FAILED,
failures=worker_failures,
)
else:
# copy ret_val_queue into a map keyed by global rank
workers_ret_vals = {}
for local_rank, ret_val in result.return_values.items():
worker = worker_group.workers[local_rank]
workers_ret_vals[worker.global_rank] = ret_val
return RunResult(
state=WorkerState.SUCCEEDED,
return_values=workers_ret_vals,
)
else:
return RunResult(state=WorkerState.HEALTHY)

@@ -0,0 +1,52 @@
import os
from contextlib import contextmanager, ExitStack
from typing import Generator
from torch.distributed.elastic.multiprocessing.errors import record
__all__ = [
"worker_main",
]
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
from torch._C._distributed_c10d import _WorkerServer
server = _WorkerServer(socket_path)
try:
yield
finally:
server.shutdown()
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
"""
This is a context manager that wraps your main entry function. This combines
the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
exposes handlers via a unix socket specified by
``TORCH_WORKER_SERVER_SOCKET``.
Example
::
@worker_main()
def main():
pass
if __name__=="__main__":
main()
"""
with ExitStack() as stack:
socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
if socket_path is not None:
stack.enter_context(_worker_server(socket_path))
yield

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Module contains events processing mechanisms that are integrated with the standard python logging.
Example of usage:
::
from torch.distributed.elastic import events
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
events.get_logging_handler(destination="console").info(event)
"""
import inspect
import logging
import os
import socket
import traceback
from typing import Dict, Optional
from torch.distributed.elastic.events.handlers import get_logging_handler
from .api import ( # noqa: F401
Event,
EventMetadataValue,
EventSource,
NodeState,
RdzvEvent,
)
_events_loggers: Dict[str, logging.Logger] = {}
def _get_or_create_logger(destination: str = "null") -> logging.Logger:
"""
Construct a python logger based on the destination type, or return the existing one.

Available destinations can be found in the ``handlers.py`` file.
The constructed logger does not propagate messages to upper-level loggers,
e.g. the root logger. This makes sure that a single event is processed once.
Args:
destination: The string representation of the event handler.
Available handlers found in ``handlers`` module
"""
global _events_loggers
if destination not in _events_loggers:
_events_logger = logging.getLogger(f"torchelastic-events-{destination}")
_events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
# Do not propagate message to the root logger
_events_logger.propagate = False
logging_handler = get_logging_handler(destination)
_events_logger.addHandler(logging_handler)
# Add the logger to the global dictionary
_events_loggers[destination] = _events_logger
return _events_loggers[destination]
def record(event: Event, destination: str = "null") -> None:
_get_or_create_logger(destination).info(event.serialize())
def record_rdzv_event(event: RdzvEvent) -> None:
_get_or_create_logger("dynamic_rendezvous").info(event.serialize())
def construct_and_record_rdzv_event(
run_id: str,
message: str,
node_state: NodeState,
name: str = "",
hostname: str = "",
pid: Optional[int] = None,
master_endpoint: str = "",
local_id: Optional[int] = None,
rank: Optional[int] = None,
) -> None:
"""
Initialize rendezvous event object and record its operations.
Args:
run_id (str): The run id of the rendezvous.
message (str): The message describing the event.
node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED).
name (str): Event name. (E.g. Current action being performed).
hostname (str): Hostname of the node.
pid (Optional[int]): The process id of the node.
master_endpoint (str): The master endpoint for the rendezvous store, if known.
local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py
rank (Optional[int]): The rank of the node, if known.
Returns:
None
Example:
>>> # See DynamicRendezvousHandler class
>>> def _record(
... self,
... message: str,
... node_state: NodeState = NodeState.RUNNING,
... rank: Optional[int] = None,
... ) -> None:
... construct_and_record_rdzv_event(
... name=f"{self.__class__.__name__}.{get_method_name()}",
... run_id=self._settings.run_id,
... message=message,
... node_state=node_state,
... hostname=self._this_node.addr,
... pid=self._this_node.pid,
... local_id=self._this_node.local_id,
... rank=rank,
... )
"""
# We don't want to perform an extra computation if not needed.
if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
return
# Set up parameters.
if not hostname:
hostname = socket.getfqdn()
if not pid:
pid = os.getpid()
# Determine which file called this function.
callstack = inspect.stack()
filename = "no_file"
if len(callstack) > 1:
stack_depth_1 = callstack[1]
filename = os.path.basename(stack_depth_1.filename)
if not name:
name = stack_depth_1.function
# Delete the callstack variable. If kept, this can mess with python's
# garbage collector as we are holding on to stack frame information in
# the inspect module.
del callstack
# Set up error trace if this is an exception
if node_state == NodeState.FAILED:
error_trace = traceback.format_exc()
else:
error_trace = ""
# Initialize event object
event = RdzvEvent(
name=f"{filename}:{name}",
run_id=run_id,
message=message,
hostname=hostname,
pid=pid,
node_state=node_state,
master_endpoint=master_endpoint,
rank=rank,
local_id=local_id,
error_trace=error_trace,
)
# Finally, record the event.
record_rdzv_event(event)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
EventMetadataValue = Union[str, int, float, bool, None]
class EventSource(str, Enum):
"""Known identifiers of the event producers."""
AGENT = "AGENT"
WORKER = "WORKER"
@dataclass
class Event:
"""
The class represents the generic event that occurs during the torchelastic job execution.
The event can be any kind of meaningful action.
Args:
name: event name.
source: the event producer, e.g. agent or worker
timestamp: timestamp in milliseconds when event occurred.
metadata: additional data that is associated with the event.
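Example (sketch)::

    e = Event(name="test_event", source=EventSource.AGENT, metadata={"k": "v"})
    s = e.serialize()          # JSON string
    e2 = Event.deserialize(s)  # round-trips back to an Event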
"""
name: str
source: EventSource
timestamp: int = 0
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
def __str__(self):
return self.serialize()
@staticmethod
def deserialize(data: Union[str, "Event"]) -> "Event":
if isinstance(data, Event):
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
return Event(**data_dict)
def serialize(self) -> str:
return json.dumps(asdict(self))
class NodeState(str, Enum):
"""The states that a node can be in rendezvous."""
INIT = "INIT"
RUNNING = "RUNNING"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
@dataclass
class RdzvEvent:
"""
Dataclass to represent any rendezvous event.
Args:
name: Event name. (E.g. Current action being performed)
run_id: The run id of the rendezvous
message: The message describing the event
hostname: Hostname of the node
pid: The process id of the node
node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
master_endpoint: The master endpoint for the rendezvous store, if known
rank: The rank of the node, if known
local_id: The local_id of the node, if defined in dynamic_rendezvous.py
error_trace: Error stack trace, if this is an error event.
"""
name: str
run_id: str
message: str
hostname: str
pid: int
node_state: NodeState
master_endpoint: str = ""
rank: Optional[int] = None
local_id: Optional[int] = None
error_trace: str = ""
def __str__(self):
return self.serialize()
@staticmethod
def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
if isinstance(data, RdzvEvent):
return data
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
return RdzvEvent(**data_dict)
def serialize(self) -> str:
return json.dumps(asdict(self))

View File

@ -0,0 +1,22 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict
_log_handlers: Dict[str, logging.Handler] = {
"console": logging.StreamHandler(),
"dynamic_rendezvous": logging.NullHandler(),
"null": logging.NullHandler(),
}
def get_logging_handler(destination: str = "null") -> logging.Handler:
global _log_handlers
return _log_handlers[destination]
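# Illustrative usage (not part of the original file): attach the configured
# handler to a logger. The logger name below is arbitrary.
if __name__ == "__main__":
    example_logger = logging.getLogger("rendezvous.example")
    example_logger.addHandler(get_logging_handler("console"))
    example_logger.warning("this goes to stderr via the console handler")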

View File

@ -0,0 +1,164 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Metrics API.
**Overview**:
The metrics API in torchelastic is used to publish telemetry metrics.
It is designed to be used by torchelastic's internal modules to
publish metrics for the end user with the goal of increasing visibility
and helping with debugging. However, you may use the same API in your
jobs to publish metrics to the same metrics ``sink``.
A ``metric`` can be thought of as timeseries data
and is uniquely identified by the string-valued tuple
``(metric_group, metric_name)``.
torchelastic makes no assumptions about what a ``metric_group`` is
and what relationship it has with ``metric_name``. It is totally up
to the user to use these two fields to uniquely identify a metric.
.. note:: The metric group ``torchelastic`` is reserved by torchelastic for
platform level metrics that it produces.
For instance torchelastic may output the latency (in milliseconds)
of a re-rendezvous operation from the agent as
``(torchelastic, agent.rendezvous.duration.ms)``
A sensible way to use metric groups is to map them to a stage or module
in your job. You may also encode certain high level properties of
the job such as the region or stage (dev vs prod).
**Publish Metrics**:
Using torchelastic's metrics API is similar to using python's logging
framework. You first have to configure a metrics handler before
trying to add metric data.
The example below measures the latency for the ``calculate()`` function.
::
import time
import torch.distributed.elastic.metrics as metrics
# sends all metrics other than those from "my_module" to /dev/null
metrics.configure(metrics.NullMetricsHandler())
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
def my_method():
start = time.time()
calculate()
end = time.time()
metrics.put_metric("calculate_latency", int(end-start), "my_module")
You may also use the ``torch.distributed.elastic.metrics.prof`` decorator
to conveniently and succinctly profile functions
::
# -- in module examples.foobar --
import torch.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
@metrics.prof
def foo():
pass
class Bar():
@metrics.prof
def baz(self):
pass
``@metrics.prof`` will publish the following metrics
::
<leaf_module or classname>.success - 1 if the function finished successfully
<leaf_module or classname>.failure - 1 if the function threw an exception
<leaf_module or classname>.duration.ms - function duration in milliseconds
**Configuring Metrics Handler**:
`torch.distributed.elastic.metrics.MetricHandler` is responsible for emitting
the added metric values to a particular destination. Metric groups can be
configured with different metric handlers.
By default torchelastic emits all metrics to ``/dev/null``.
With the following configuration, metrics in the
``torchelastic`` and ``my_app`` metric groups will be printed to the
console.
::
import torch.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
**Writing a Custom Metric Handler**:
If you want your metrics to be emitted to a custom location, implement
the `torch.distributed.elastic.metrics.MetricHandler` interface
and configure your job to use your custom metric handler.
Below is a toy example that prints the metrics to ``stdout``
::
import torch.distributed.elastic.metrics as metrics
class StdoutMetricHandler(metrics.MetricHandler):
def emit(self, metric_data):
ts = metric_data.timestamp
group = metric_data.group_name
name = metric_data.name
value = metric_data.value
print(f"[{ts}][{group}]: {name}={value}")
metrics.configure(StdoutMetricHandler(), group="my_app")
Now all metrics in the group ``my_app`` will be printed to stdout as:
::
[1574213883.4182858][my_app]: my_metric=<value>
[1574213940.5237644][my_app]: my_metric=<value>
"""
from typing import Optional
from .api import ( # noqa: F401
configure,
ConsoleMetricHandler,
get_elapsed_time_ms,
getStream,
MetricData,
MetricHandler,
MetricsConfig,
NullMetricHandler,
prof,
profile,
publish_metric,
put_metric,
)
def initialize_metrics(cfg: Optional[MetricsConfig] = None):
pass
try:
from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403
except ModuleNotFoundError:
pass

View File

@ -0,0 +1,216 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import time
from collections import namedtuple
from functools import wraps
from typing import Dict, Optional
from typing_extensions import deprecated
__all__ = [
"MetricsConfig",
"MetricHandler",
"ConsoleMetricHandler",
"NullMetricHandler",
"MetricStream",
"configure",
"getStream",
"prof",
"profile",
"put_metric",
"publish_metric",
"get_elapsed_time_ms",
"MetricData",
]
MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
class MetricsConfig:
__slots__ = ["params"]
def __init__(self, params: Optional[Dict[str, str]] = None):
self.params = params
if self.params is None:
self.params = {}
class MetricHandler(abc.ABC):
@abc.abstractmethod
def emit(self, metric_data: MetricData):
pass
class ConsoleMetricHandler(MetricHandler):
def emit(self, metric_data: MetricData):
print(
f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
)
class NullMetricHandler(MetricHandler):
def emit(self, metric_data: MetricData):
pass
class MetricStream:
def __init__(self, group_name: str, handler: MetricHandler):
self.group_name = group_name
self.handler = handler
def add_value(self, metric_name: str, metric_value: int):
self.handler.emit(
MetricData(time.time(), self.group_name, metric_name, metric_value)
)
_metrics_map: Dict[str, MetricHandler] = {}
_default_metrics_handler: MetricHandler = NullMetricHandler()
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: Optional[str] = None):
if group is None:
global _default_metrics_handler
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
# as `MetricHandler`.
_default_metrics_handler = handler
else:
_metrics_map[group] = handler
def getStream(group: str):
if group in _metrics_map:
handler = _metrics_map[group]
else:
handler = _default_metrics_handler
return MetricStream(group, handler)
def _get_metric_name(fn):
qualname = fn.__qualname__
split = qualname.split(".")
if len(split) == 1:
module = fn.__module__
if module:
return module.split(".")[-1] + "." + split[0]
else:
return split[0]
else:
return qualname
def prof(fn=None, group: str = "torchelastic"):
r"""
@prof decorator publishes ``duration.ms``, ``success``, and ``failure`` metrics for the function that it decorates.
The metric name defaults to the qualified name (``class_name.def_name``) of the function.
If the function does not belong to a class, it uses the leaf module name instead.
Usage
::
@metrics.prof
def x():
pass
@metrics.prof(group="agent")
def y():
pass
"""
def wrap(f):
@wraps(f)
def wrapper(*args, **kwargs):
key = _get_metric_name(f)
try:
start = time.time()
result = f(*args, **kwargs)
put_metric(f"{key}.success", 1, group)
except Exception:
put_metric(f"{key}.failure", 1, group)
raise
finally:
put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined]
return result
return wrapper
if fn:
return wrap(fn)
else:
return wrap
@deprecated("Deprecated, use `@prof` instead", category=FutureWarning)
def profile(group=None):
"""
@profile decorator adds latency and success/failure metrics to any given function.
Usage
::
@metrics.profile("my_metric_group")
def some_function(<arguments>):
"""
def wrap(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
start_time = time.time()
result = func(*args, **kwargs)
publish_metric(group, f"{func.__name__}.success", 1)
except Exception:
publish_metric(group, f"{func.__name__}.failure", 1)
raise
finally:
publish_metric(
group,
f"{func.__name__}.duration.ms",
get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
)
return result
return wrapper
return wrap
def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
"""
Publish a metric data point.
Usage
::
put_metric("metric_name", 1)
put_metric("metric_name", 1, "metric_group_name")
"""
getStream(metric_group).add_value(metric_name, metric_value)
@deprecated(
"Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
category=FutureWarning,
)
def publish_metric(metric_group: str, metric_name: str, metric_value: int):
metric_stream = getStream(metric_group)
metric_stream.add_value(metric_name, metric_value)
def get_elapsed_time_ms(start_time_in_seconds: float):
"""Return the elapsed time in millis from the given start time."""
end_time = time.time()
return int((end_time - start_time_in_seconds) * 1000)
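# Illustrative usage (not part of the original file): route the hypothetical
# "example" metric group to the console and emit a few data points.
if __name__ == "__main__":
    configure(ConsoleMetricHandler(), group="example")

    @prof(group="example")
    def do_work():
        time.sleep(0.01)

    do_work()  # emits do_work.success and do_work.duration.ms
    put_metric("queue_depth", 7, "example")  # ad-hoc data point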

View File

@ -0,0 +1,233 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Library that launches and manages ``n`` copies of worker subprocesses specified by either a function or a binary.
For functions, it uses ``torch.multiprocessing`` (and therefore python
``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
``subprocess.Popen`` to create worker processes.
Usage 1: Launching two trainers as a function
::
from torch.distributed.elastic.multiprocessing import Std, start_processes
def trainer(a, b, c):
pass # train
# runs two trainers
# LOCAL_RANK=0 trainer(1,2,3)
# LOCAL_RANK=1 trainer(4,5,6)
ctx = start_processes(
name="trainer",
entrypoint=trainer,
args={0: (1,2,3), 1: (4,5,6)},
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
)
# waits for all copies of trainer to finish
ctx.wait()
Usage 2: Launching 2 echo workers as a binary
::
# same as invoking
# echo hello
# echo world > stdout.log
ctx = start_processes(
name="echo"
entrypoint="echo",
log_dir="/tmp/foobar",
args={0: "hello", 1: "world"},
redirects={1: Std.OUT},
)
Just like ``torch.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""
from typing import Callable, Dict, Optional, Tuple, Union
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,
DefaultLogsSpecs,
LogsDest,
LogsSpecs,
MultiprocessContext,
PContext,
ProcessFailure,
RunProcsResult,
SignalException,
Std,
SubprocessContext,
to_map,
)
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
"start_processes",
"MultiprocessContext",
"PContext",
"ProcessFailure",
"RunProcsResult",
"SignalException",
"Std",
"LogsDest",
"LogsSpecs",
"DefaultLogsSpecs",
"SubprocessContext",
"to_map",
]
def start_processes(
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
start_method: str = "spawn",
) -> PContext:
"""
Start ``n`` copies of ``entrypoint`` processes with the provided options.
``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
The number of copies is determined by the number of entries for ``args`` and
``envs`` arguments, which need to have the same key set.
``args`` and ``envs`` parameters are the arguments and environment variables
to pass down to the entrypoint mapped by the replica index (local rank).
All local ranks must be accounted for.
That is, the keyset should be ``{0,1,...,(nprocs-1)}``.
.. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
If any other type is given, then it is cast to a string representation
(e.g. ``str(arg1)``). Furthermore, a binary failure will only write
an ``error.json`` error file if the main function is annotated with
``torch.distributed.elastic.multiprocessing.errors.record``. For function launches,
this is done by default and there is no need to manually annotate
with the ``@record`` annotation.
``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
the local rank to specify the redirect behavior for.
Any missing local ranks will default to ``Std.NONE``.
``tee`` acts like the unix "tee" command in that it redirects + prints to console.
To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
For each process, the ``log_dir`` will contain:
#. ``{local_rank}/error.json``: if the process failed, a file with the error info
#. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT``
#. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR``
.. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
Example:
::
log_dir = "/tmp/test"
# ok; two copies of foo: foo("bar0"), foo("bar1")
start_processes(
name="trainer",
entrypoint=foo,
args={0: ("bar0",), 1: ("bar1",)},
envs={0: {}, 1: {}},
logs_specs=DefaultLogsSpecs(log_dir=log_dir),
)
# invalid; envs missing for local rank 1
start_processes(
name="trainer",
entrypoint=foo,
args={0: ("bar0",), 1: ("bar1",)},
envs={0: {}},
logs_specs=DefaultLogsSpecs(log_dir=log_dir),
)
# ok; two copies of /usr/bin/touch: touch file1, touch file2
start_processes(
name="trainer",
entrypoint="/usr/bin/touch",
args={0: ("file1",), 1: ("file2",)},
envs={0: {}, 1: {}},
logs_specs=DefaultLogsSpecs(log_dir=log_dir),
)
# caution; arguments casted to string, runs:
# echo "1" "2" "3" and echo "[1, 2, 3]"
start_processes(
name="trainer",
entrypoint="/usr/bin/echo",
args={0: (1, 2, 3), 1: ([1, 2, 3],)},
envs={0: {}, 1: {}},
logs_specs=DefaultLogsSpecs(log_dir=log_dir),
)
Args:
name: a human readable short name that describes what the processes are
(used as header when tee'ing stdout/stderr outputs)
entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
args: arguments to each replica
envs: env vars to each replica
logs_specs: defines log file destinations and redirect/tee behavior
(see ``DefaultLogsSpecs`` for the default file layout)
log_line_prefixes: optional console prefix per local rank
start_method: multiprocessing start method (spawn, fork, forkserver)
ignored for binaries
nprocs = len(args)
_validate_full_rank(args, nprocs, "args")
_validate_full_rank(envs, nprocs, "envs")
context: PContext
if isinstance(entrypoint, str):
context = SubprocessContext(
name=name,
entrypoint=entrypoint,
args=args,
envs=envs,
logs_specs=logs_specs,
log_line_prefixes=log_line_prefixes,
)
else:
context = MultiprocessContext(
name=name,
entrypoint=entrypoint,
args=args,
envs=envs,
log_line_prefixes=log_line_prefixes,
start_method=start_method,
logs_specs=logs_specs,
)
try:
context.start()
return context
except Exception:
context.close()
raise
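# Illustrative usage (not part of the original file): launch two copies of a
# worker function with the logs_specs-based signature above. The worker
# function and values below are made up for the demo.
def _example_worker(a: int, b: int) -> int:
    # hypothetical worker fn; must live at module level so spawn can pickle it
    return a + b


if __name__ == "__main__":
    ctx = start_processes(
        name="example",
        entrypoint=_example_worker,
        args={0: (1, 2), 1: (3, 4)},
        envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
        logs_specs=DefaultLogsSpecs(),  # defaults to a tmp log dir
    )
    result = ctx.wait()  # blocks until both copies finish
    if result is not None and not result.is_failed():
        print(result.return_values)  # {0: 3, 1: 7}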

View File

@ -0,0 +1,923 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
redirect_stderr,
redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
get_subprocess_handler,
SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
__all__ = [
"DefaultLogsSpecs",
"SignalException",
"Std",
"to_map",
"RunProcsResult",
"PContext",
"get_std_cm",
"MultiprocessContext",
"SubprocessContext",
"LogsDest",
"LogsSpecs",
]
class SignalException(Exception):
"""
Raised inside the torchelastic agent process by the termination handler
when the process receives a death signal.
"""
def __init__(self, msg: str, sigval: signal.Signals) -> None:
super().__init__(msg)
self.sigval = sigval
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
"""Termination handler that raises exceptions on the main process.
When the process receives a death signal (SIGTERM, SIGINT), this termination
handler is invoked. It raises a ``SignalException`` that should be processed by the
user code. Python does not terminate the process after the termination handler finishes,
so the exception must not be silently swallowed; otherwise the process will never
be terminated.
"""
sigval = signal.Signals(signum)
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
def _get_kill_signal() -> signal.Signals:
"""Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGKILL
def _get_default_signal() -> signal.Signals:
"""Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGTERM
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
if actual_keys != expected_keys:
raise RuntimeError(
f"{what}, local rank mapping mismatch,"
f" expected: {expected_keys}, actual: {actual_keys}"
)
_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"
class Std(IntFlag):
NONE = 0
OUT = 1
ERR = 2
ALL = OUT | ERR
@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
"""
Example:
::
from_str("0") -> Std.NONE
from_str("1") -> Std.OUT
from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
Any other input raises an exception
"""
def to_std(v: str) -> Std: # type: ignore[return]
s = Std(int(v))
if s in Std:
return s
# return None -> should NEVER reach here since we regex check input
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {}
for m in vm.split(","):
i, v = m.split(":")
d[int(i)] = to_std(v)
return d
else:
raise ValueError(
f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
)
def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This convenience
function converts a value or a partial mapping into a full mapping.
Example:
::
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
"""
if isinstance(val_or_map, Std):
return dict.fromkeys(range(local_world_size), val_or_map)
else:
map = {}
for i in range(local_world_size):
map[i] = val_or_map.get(i, Std.NONE)
return map
@dataclass
class LogsDest:
"""
For each log type, holds mapping of local rank ids to file paths.
"""
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC):
"""
Defines logs processing and redirection for each worker process.
Args:
log_dir:
Base directory where logs will be written.
redirects:
Streams to redirect to files. Pass a single ``Std``
enum to redirect for all workers, or a mapping keyed
by local_rank to selectively redirect.
tee:
Streams to duplicate to stdout/stderr.
Pass a single ``Std`` enum to duplicate streams for all workers,
or a mapping keyed by local_rank to selectively duplicate.
"""
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
self._tee = tee
self._local_ranks_filter = local_ranks_filter
@abstractmethod
def reify(
self,
envs: Dict[int, Dict[str, str]],
) -> LogsDest:
"""
Given the environment variables, builds destination of log files for each of the local ranks.
Envs parameter contains env variables dict for each of the local ranks, where entries are defined in:
:func:`~torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
"""
@property
@abstractmethod
def root_log_dir(self) -> str:
pass
class DefaultLogsSpecs(LogsSpecs):
"""
Default LogsSpecs implementation:
- `log_dir` will be created if it doesn't exist
- Generates nested folders for each attempt and rank.
"""
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
log_dir = tempfile.mkdtemp(prefix="torchelastic_")
elif not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
else:
if os.path.isfile(log_dir):
raise NotADirectoryError(f"log_dir: {log_dir} is a file")
super().__init__(log_dir, redirects, tee, local_ranks_filter)
# initialized only once
self._run_log_dir = None
@property
def root_log_dir(self) -> str:
return str(self._root_log_dir)
def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
os.makedirs(base_log_dir, exist_ok=True)
dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
logger.info("log directory set to: %s", dir)
return dir
def reify(
self,
envs: Dict[int, Dict[str, str]],
) -> LogsDest:
"""
Uses following scheme to build log destination paths:
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
"""
nprocs = len(envs)
global_env = {}  # used only to query properties that do not depend on a rank
if nprocs > 0:
global_env = envs[0]
else:
logger.warning(
"Empty envs map provided when defining logging destinations."
)
# Keys are always defined, but values can be missing in unit tests
run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")
attempt_log_dir: str = ""
if self._root_log_dir != os.devnull:
if not self._run_log_dir:
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload]
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)
if self._root_log_dir == os.devnull:
attempt_log_dir = os.devnull
# create subdirs for each local rank in the logs_dir
# logs_dir
# |- 0
# |- error.json
# |- stdout.log
# |- stderr.log
# |- ...
# |- (nprocs-1)
redirs = to_map(self._redirects, nprocs)
ts = to_map(self._tee, nprocs)
# to tee stdout/stderr we first redirect into a file
# then tail -f stdout.log/stderr.log so add tee settings to redirects
for local_rank, tee_std in ts.items():
redirect_std = redirs[local_rank]
redirs[local_rank] = redirect_std | tee_std
SYS_STREAM = ""  # special case indicating output goes to the console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
error_files = {}
for local_rank in range(nprocs):
if attempt_log_dir == os.devnull:
tee_stdouts[local_rank] = os.devnull
tee_stderrs[local_rank] = os.devnull
error_files[local_rank] = os.devnull
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
else:
clogdir = os.path.join(attempt_log_dir, str(local_rank))
os.mkdir(clogdir)
rd = redirs[local_rank]
if (rd & Std.OUT) == Std.OUT:
stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
if (rd & Std.ERR) == Std.ERR:
stderrs[local_rank] = os.path.join(clogdir, "stderr.log")
t = ts[local_rank]
if t & Std.OUT == Std.OUT:
tee_stdouts[local_rank] = stdouts[local_rank]
if t & Std.ERR == Std.ERR:
tee_stderrs[local_rank] = stderrs[local_rank]
if (
self._local_ranks_filter
and local_rank not in self._local_ranks_filter
):
# If stream is tee'd, only write to file, but don't tail
if local_rank in tee_stdouts:
tee_stdouts.pop(local_rank, None)
if local_rank in tee_stderrs:
tee_stderrs.pop(local_rank, None)
# If stream is not redirected, don't print
if stdouts[local_rank] == SYS_STREAM:
stdouts[local_rank] = os.devnull
if stderrs[local_rank] == SYS_STREAM:
stderrs[local_rank] = os.devnull
error_file = os.path.join(clogdir, "error.json")
error_files[local_rank] = error_file
logger.info(
"Setting worker%s reply file to: %s", local_rank, error_file
)
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file
return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
def __repr__(self) -> str:
return (
f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, DefaultLogsSpecs):
return False
return (
self._root_log_dir == other._root_log_dir
and self._redirects == other._redirects
and self._tee == other._tee
and self._local_ranks_filter == other._local_ranks_filter
)
@dataclass
class RunProcsResult:
"""
Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.
Note the following:
1. All fields are mapped by local rank
2. ``return_values`` - only populated for functions (not the binaries).
3. ``stdouts`` - path to stdout.log (empty string if no redirect)
4. ``stderrs`` - path to stderr.log (empty string if no redirect)
"""
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
def is_failed(self) -> bool:
return len(self.failures) > 0
class PContext(abc.ABC):
"""
The base class that standardizes operations over a set of processes that are launched via different mechanisms.
The name ``PContext`` is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.
.. warning:: stdouts and stderrs should ALWAYS be a superset of
tee_stdouts and tee_stderrs (respectively); this is because
tee is implemented as a redirect plus ``tail -f <stdout/stderr.log>``
"""
def __init__(
self,
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
# all local ranks are accounted for
nprocs = len(args)
# TODO log_line_prefixes can be expanded too
logs_dest = logs_specs.reify(envs)
_validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
_validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")
self.entrypoint = entrypoint
self.args = args
self.envs = envs
self.stdouts = logs_dest.stdouts
self.stderrs = logs_dest.stderrs
self.error_files = logs_dest.error_files
self.nprocs = nprocs
self._stdout_tail = TailLog(
name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
)
self._stderr_tail = TailLog(
name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
)
def start(self) -> None:
"""Start processes using parameters defined in the constructor."""
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
else:
logger.warning(
"Failed to register signal handlers since torchelastic is running on a child thread. "
"This could lead to orphaned worker processes if the torchrun is terminated."
)
self._start()
self._stdout_tail.start()
self._stderr_tail.start()
@abc.abstractmethod
def _start(self) -> None:
"""Start processes using strategy defined in a particular context."""
raise NotImplementedError
@abc.abstractmethod
def _poll(self) -> Optional[RunProcsResult]:
"""
Poll the run status of the processes running under this context.
This method follows an "all-or-nothing" policy and returns
a ``RunProcsResult`` object if either all processes complete
successfully or any process fails. Returns ``None`` if
all processes are still running.
"""
raise NotImplementedError
def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
"""
Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
for the processes to be done. Returns ``None`` if the processes are still running
on timeout expiry. Negative timeout values are interpreted as "wait-forever".
A timeout value of zero simply queries the status of the processes (e.g. equivalent
to a poll).
.. note:: The multiprocessing library registers SIGTERM and SIGINT signal handlers that raise
``SignalException`` when the signals are received. It is up to the consumer of the code
to properly handle the exception. It is important not to swallow the exception, otherwise
the process will not terminate. An example of the typical workflow:
.. code-block:: python
pc = start_processes(...)
try:
pc.wait(1)
# ... do some other work
except SignalException as e:
pc.shutdown(e.sigval, timeout=30)
If SIGTERM or SIGINT occurs, the code above will try to shut down the child processes
by propagating the received signal. If the child processes do not terminate within the
timeout, they will be sent SIGKILL.
"""
if timeout == 0:
return self._poll()
if timeout < 0:
timeout = sys.maxsize
expiry = time.time() + timeout
while time.time() < expiry:
pr = self._poll()
if pr:
return pr
time.sleep(period)
return None
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError
@abc.abstractmethod
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
"""
raise NotImplementedError
def close(
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
Args:
death_sig: Death signal to terminate processes.
timeout: Time to wait for processes to finish, if process is
still alive after this time, it will be terminated via SIGKILL.
"""
if not death_sig:
death_sig = _get_default_signal()
self._close(death_sig=death_sig, timeout=timeout)
if self._stdout_tail:
self._stdout_tail.stop()
if self._stderr_tail:
self._stderr_tail.stop()
def get_std_cm(std_rd: str, redirect_fn):
if IS_WINDOWS or IS_MACOS or not std_rd:
return nullcontext()
else:
return redirect_fn(std_rd)
def _wrap(
local_rank: int,
fn: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
ret_vals: Dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank]
env_ = envs[local_rank]
ret_val_ = ret_vals[local_rank]
stdout_rd = stdout_redirects[local_rank]
stderr_rd = stderr_redirects[local_rank]
stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
stderr_cm = get_std_cm(stderr_rd, redirect_stderr)
for k, v in env_.items():
os.environ[k] = v
with stdout_cm, stderr_cm:
ret = record(fn)(*args_)
ret_val_.put(ret)
queue_finished_reading_event.wait()
class MultiprocessContext(PContext):
"""``PContext`` holding worker processes invoked as a function."""
def __init__(
self,
name: str,
entrypoint: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
super().__init__(
name,
entrypoint,
args,
envs,
logs_specs,
log_line_prefixes,
)
self.start_method = start_method
# each ret_val queue will always contain a single element.
self._ret_vals = {
local_rank: mp.get_context(self.start_method).SimpleQueue()
for local_rank in range(self.nprocs)
}
# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: the set() method should ONLY be invoked when all processes finished
# successfully. If any process died on event.wait(), calling set() will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()
def _start(self):
if self._pc:
raise ValueError(
"The process context already initialized."
" Most likely the start method got called twice."
)
self._pc = mp.start_processes(
fn=_wrap,
args=(
self.entrypoint,
self.args,
self.envs,
self.stdouts,
self.stderrs,
self._ret_vals,
self._worker_finished_event,
),
nprocs=self.nprocs,
join=False,
daemon=False,
start_method=self.start_method,
)
def _is_done(self) -> bool:
return len(self._return_values) == self.nprocs
def _poll(self) -> Optional[RunProcsResult]:
assert self._pc is not None # assertion for mypy type checker
try:
# torch.mp.ProcessContext throws an exception if some/all of the
# worker processes failed
# timeout < 0 checks worker status and returns immediately
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.
self._pc.join(-1)
# IMPORTANT: we use multiprocessing.Queue to carry worker return values
# back to the parent, the worker process will wait before terminating
# until all the buffered items are fed by the feeder thread to the underlying
# pipe. Hence to prevent deadlocks on large return values,
# we opportunistically try queue.get on each join call
# See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
for local_rank in range(0, self.nprocs):
return_queue = self._ret_vals[local_rank]
if not return_queue.empty():
# save the return values temporarily into a member var
self._return_values[local_rank] = return_queue.get()
if self._is_done():
# we should ALWAYS have ALL the return values when all the processes are done
self._worker_finished_event.set()
# At this point workers finished running the user function
# But the child process might still have not exited. Wait for them.
# pc.join() blocks [forever] until "a" proc exits. Loop until all of them exit.
while not self._pc.join():
logger.debug(
"entrypoint fn finished, waiting for all child procs to exit..."
)
_validate_full_rank(
self._return_values, self.nprocs, "return_value queue"
)
self.close()
return RunProcsResult(
return_values=self._return_values,
stdouts=self.stdouts,
stderrs=self.stderrs,
)
else:
return None
except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
failed_local_rank = e.error_index
# entrypoint for MultiprocessContext will always be a Callable
fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr]
failed_proc = self._pc.processes[failed_local_rank]
error_filepath = self.error_files[failed_local_rank]
logger.exception(
"failed (exitcode: %s)"
" local_rank: %s (pid: %s)"
" of fn: %s (start_method: %s)",
failed_proc.exitcode,
failed_local_rank,
e.pid,
fn_name,
self.start_method,
)
self.close()
return RunProcsResult(
failures={
failed_local_rank: ProcessFailure(
local_rank=failed_local_rank,
pid=e.pid,
exitcode=failed_proc.exitcode,
error_file=error_filepath,
)
},
stdouts=self.stdouts,
stderrs=self.stderrs,
)
def pids(self) -> Dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids()))
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if not self._pc:
return
for proc in self._pc.processes:
if proc.is_alive():
logger.warning(
"Closing process %s via signal %s", proc.pid, death_sig.name
)
try:
os.kill(proc.pid, death_sig)
except ProcessLookupError:
# If the process already exited for some reason,
# `ProcessLookupError` will be raised; it is safe to ignore it.
pass
end = time.monotonic() + timeout
for proc in self._pc.processes:
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
proc.join(time_to_wait)
for proc in self._pc.processes:
if proc.is_alive():
logger.warning(
"Unable to shutdown process %s via %s, forcefully exiting via %s",
proc.pid,
death_sig,
_get_kill_signal(),
)
try:
os.kill(proc.pid, _get_kill_signal())
except ProcessLookupError:
# If the process already exited for some reason,
# `ProcessLookupError` will be raised; it is safe to ignore it.
pass
proc.join()
class SubprocessContext(PContext):
"""``PContext`` holding worker processes invoked as a binary."""
def __init__(
self,
name: str,
entrypoint: str,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
super().__init__(
name,
entrypoint,
args,
envs,
logs_specs,
log_line_prefixes,
)
# state: the set of local ranks whose workers are still running
self._running_local_ranks: Set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
def _start(self):
if self.subprocess_handlers:
raise ValueError(
"The subprocess handlers already initialized. Most likely the start method got called twice."
)
self.subprocess_handlers = {
local_rank: get_subprocess_handler(
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
args=self.args[local_rank],
env=self.envs[local_rank],
stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank],
local_rank_id=local_rank,
)
for local_rank in range(self.nprocs)
}
def _poll(self) -> Optional[RunProcsResult]:
done_local_ranks = set()
for local_rank in self._running_local_ranks:
handler = self.subprocess_handlers[local_rank]
exitcode = handler.proc.poll()
if exitcode is not None:
done_local_ranks.add(local_rank)
if exitcode != 0: # failed or signaled
self._failures[local_rank] = ProcessFailure(
local_rank=local_rank,
pid=handler.proc.pid,
exitcode=exitcode,
error_file=self.error_files[local_rank],
)
# else: --> succeeded; nothing to do
self._running_local_ranks.difference_update(done_local_ranks)
# if ALL procs are finished or ANY have failed
if not self._running_local_ranks or self._failures:
self.close() # terminate all running procs
result = RunProcsResult(
failures=self._failures,
stdouts=self.stdouts,
stderrs=self.stderrs,
)
if result.is_failed():
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
logger.error(
"failed (exitcode: %s)"
" local_rank: %s (pid: %s)"
" of binary: %s",
first_failure.exitcode,
first_failure.local_rank,
first_failure.pid,
self.entrypoint,
)
else:
# Populate return values with dummy values. This provides consistency with MultiprocessContext
result.return_values = dict.fromkeys(range(self.nprocs))
return result
else: # there are no failures and procs still running
return None
def pids(self) -> Dict[int, int]:
return {
local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items()
}
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if not self.subprocess_handlers:
return
for handler in self.subprocess_handlers.values():
if handler.proc.poll() is None:
logger.warning(
"Sending process %s closing signal %s",
handler.proc.pid,
death_sig.name,
)
handler.close(death_sig=death_sig)
end = time.monotonic() + timeout
for handler in self.subprocess_handlers.values():
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
try:
handler.proc.wait(time_to_wait)
except subprocess.TimeoutExpired:
# Ignore the timeout expired exception, since
# the child process will be forcefully terminated via SIGKILL
pass
for handler in self.subprocess_handlers.values():
if handler.proc.poll() is None:
logger.warning(
"Unable to shutdown process %s via %s, forcefully exiting via %s",
handler.proc.pid,
death_sig,
_get_kill_signal(),
)
handler.close(death_sig=_get_kill_signal())
handler.proc.wait()
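# Illustrative usage (not part of the original file): parse a redirect setting
# the way a CLI flag might arrive, then expand it to a full rank mapping.
if __name__ == "__main__":
    per_rank = Std.from_str("0:3,1:0")  # {0: Std.ALL, 1: Std.NONE}
    full = to_map(per_rank, local_world_size=4)  # ranks 2 and 3 get Std.NONE
    print(full)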

View File

@ -0,0 +1,383 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Each host in a distributed PyTorch job runs with a single TorchElastic agent,
and multiple workers (as child processes of the TorchElastic agent).
Since the workers are user-provided (your PyTorch script/job), TorchElastic
has a way to propagate errors on the trainers through the agent and up to the
scheduler, which ultimately informs the end-user about the state of the job
and applies any retry policies.
TorchElastic categorizes errors into 3 categories:
+----------------+----------------+--------------------------------------------------------------+
| Category | Sub-Category | Description |
+================+================+==============================================================+
| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) |
| +----------------+--------------------------------------------------------------+
| | Worker Failure | any failures on the worker child process |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error | n/a | failures caused by the agent |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error | n/a | failures outside the domain of the agent and workers |
| | | (e.g. host failures) |
+----------------+----------------+--------------------------------------------------------------+
All errors other than "Worker Failure" are either raised canonically from the
agent process or implicitly or explicitly crash the agent process, so the
standard exception-handling strategies provided by the language (Python) apply.
Worker Failures are special because the exception/failure originates on a different
process from the agent so the error needs to be propagated inter-process
(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
to launch the workers which has a simple file based inter-process error propagation
built-in.
Any function or binary entrypoint decorated with :func:`record`
will write uncaught exceptions (with the trace information) to a file specified by the
environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
sets this env var on each child it launches, then aggregates the error files for all
children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
"""
import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from torch.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
__all__ = [
"ProcessFailure",
"ChildFailedError",
"record",
"ErrorHandler",
"get_error_handler",
]
logger = get_logger(__name__)
JSON = Dict
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
T = TypeVar("T")
@dataclass
class ProcessFailure:
"""
Represents the failed process result. When the worker process fails, it may record the failure root cause into the file.
Tries to read the failure timestamp from the provided ``error_file``,
if the ``error_file`` does not exist, the timestamp is the current
timestamp (seconds since epoch).
The ``message`` field is a concise explanation of the failure. If
the error file exists then the message is obtained from the error file.
Otherwise one is generated based on the failure signature.
.. note:: It is assumed that the ``error_file`` is written by
``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
Otherwise the behavior is undefined.
"""
local_rank: int
pid: int
exitcode: int
error_file: str
error_file_data: JSON = field(init=False)
message: str = field(init=False)
timestamp: int = field(init=False)
def __post_init__(self):
self.error_file_data = _EMPTY_ERROR_DATA
if os.path.isfile(self.error_file):
try:
with open(self.error_file) as fp:
self.error_file_data = json.load(fp)
logger.debug(
"User process failed with error data: %s",
json.dumps(self.error_file_data, indent=2),
)
self.message, self.timestamp = self._get_error_data(
self.error_file_data
)
except Exception:
logger.exception("Failed to parse reply file: %s", self.error_file)
raise
else:
self._set_no_reply_file()
# make up an informative message if not already present
if not self.message:
# signals typically do not generate an error file message
if self.exitcode < 0:
self.message = (
f"Signal {-self.exitcode} ({self.signal_name()})"
f" received by PID {self.pid}"
)
else:
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
message = error_file_data["message"]
if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0))
else:
timestamp = int(message["extraInfo"]["timestamp"])
return (message, timestamp)
def _set_no_reply_file(self):
self.error_file = _NOT_AVAILABLE
self.error_file_data = _EMPTY_ERROR_DATA
self.message = ""
self.timestamp = int(time.time())
def signal_name(self) -> str:
if self.exitcode < 0:
# We don't want to kill the parent process while trying to find the signal name;
# if the signal doesn't map to a known name, use <N/A>.
try:
return signal.Signals(-self.exitcode).name
except Exception:
return _NOT_AVAILABLE
else:
return _NOT_AVAILABLE
def timestamp_isoformat(self):
"""Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
GlobalRank = int
_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
time : ${time}
host : ${hostname}
rank : ${rank} (local_rank: ${local_rank})
exitcode : ${exitcode} (pid: ${pid})
error_file: ${error_file}
traceback : ${message}"""
# extra new lines before and after are intentional
_MSG_FORMAT_TEMPLATE = """
${boarder}
${title}
${section}
Failures:
${other_failures}
${section}
Root Cause (first observed failure):
${root_failure}
${boarder}"""
class ChildFailedError(Exception):
"""
Special exception type that can be raised from a function annotated with the
``@record`` decorator to have the child process's root exception propagate
up the stack as-is (e.g. without being wrapped in the parent's traceback).
Useful in cases where the parent is a simple nanny process
and the child (worker) processes are actually doing meaningful compute.
In this case, errors typically occur on the child process as the parent
is not doing anything non-trivial, and child errors should be propagated
to the scheduler for accurate root cause diagnostics.
.. note:: The propagation relies on error files rather than exception handling to
support both function and binary launches.
Example:
::
# process tree on a host (container)
0: scheduler-init-process:
|- 1: torchelastic_agent:
|- 2: trainer_0 (ok)
|- 3: trainer_1 (fail) -> error.json
|- ...
|- n+2: trainer_n (ok)
|- n+3: other processes
|- ...
In the example above, trainer 1's failure (written into error.json) is
the root cause and should be reported to the scheduler's init process.
The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
upon detecting trainer 1's failure which would propagate the contents
of trainer 1's error file to the scheduler's init process.
"""
def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
self.name = name
self.failures = failures
assert (
self.failures
)  # does not make sense to create a ChildFailedError with no failures
super().__init__(self.format_msg())
def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
return rank, self.failures[rank]
def format_msg(self, boarder_delim="=", section_delim="-"):
title = f"{self.name} FAILED"
root_rank, root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()):
fmt, w = self._format_failure(idx, rank, failure)
width = max(width, w)
if rank == root_rank:
root_failure_fmt = fmt
else:
other_failures_fmt.append(fmt)
# upper boundary on width
width = min(width, 60)
return Template(_MSG_FORMAT_TEMPLATE).substitute(
boarder=boarder_delim * width,
title=title,
section=section_delim * width,
root_failure=root_failure_fmt,
other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
)
def _format_failure(
self, idx: int, rank: int, failure: ProcessFailure
) -> Tuple[str, int]:
# failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
# or a dict (json) of the form
# {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
# so the display logic is:
# 1. if failure.message is not a dict (it is a str) just show it as is
# 2. else try to get the traceback (py_callstack)
# 3. if the traceback is not there, use the message
# 4. if the message is not there show <N/A>
msg = failure.message
if isinstance(failure.message, dict):
msg = (
failure.message.get("extraInfo", {})
.get("py_callstack", failure.message.get("message", "<N/A>"))
.replace("\n", "\n ") # to properly indent the traceback
)
fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
idx=idx,
time=failure.timestamp_isoformat(),
hostname=socket.getfqdn(),
rank=rank,
local_rank=failure.local_rank,
exitcode=failure.exitcode,
pid=failure.pid,
error_file=failure.error_file,
message=msg,
)
width = 0
for line in fmt.split("\n"):
width = max(width, len(line))
return fmt, width
def record(
fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
"""
Syntactic sugar to record errors/exceptions that happened in the decorated
function using the provided ``error_handler``.
Using this decorator is equivalent to:
::
error_handler = get_error_handler()
error_handler.initialize()
try:
foobar()
except ChildFailedError as e:
_, failure = e.get_first_failure()
error_handler.dump_error_file(failure.error_file, failure.exitcode)
raise
except Exception as e:
error_handler.record(e)
raise
.. important:: use this decorator once per process at the top level method,
typically this is the main method.
Example
::
@record
def main():
pass
if __name__=="__main__":
main()
"""
if not error_handler:
error_handler = get_error_handler()
def wrap(f):
@wraps(f)
def wrapper(*args, **kwargs):
assert error_handler is not None # assertion for mypy type checker
error_handler.initialize()
try:
return f(*args, **kwargs)
except SystemExit as se:
# For run_path based entrypoints, a SystemExit with code = 0 will never exit.
# Handle it here by returning None:
if se.code == 0:
return None
else:
raise
except ChildFailedError as e:
rank, failure = e.get_first_failure()
if failure.error_file != _NOT_AVAILABLE:
error_handler.dump_error_file(failure.error_file, failure.exitcode)
else:
logger.info(
"local_rank %s FAILED with no error file."
" Decorate your entrypoint fn with @record for traceback info."
" See: https://pytorch.org/docs/stable/elastic/errors.html",
rank,
)
raise
except Exception as e:
error_handler.record_exception(e)
raise
return wrapper
return wrap(fn)
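# Illustrative usage (not part of the original file): how an agent-like process
# might surface a worker failure. All values below are made up; the missing
# error file makes ProcessFailure synthesize a generic message.
if __name__ == "__main__":
    failure = ProcessFailure(
        local_rank=0,
        pid=4242,
        exitcode=1,
        error_file="/tmp/does_not_exist/error.json",
    )
    try:
        raise ChildFailedError("trainer", {0: failure})
    except ChildFailedError as e:
        print(e.format_msg())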

View File

@ -0,0 +1,166 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import faulthandler
import json
import logging
import os
import time
import traceback
import warnings
from typing import Any, Dict, Optional
__all__ = ["ErrorHandler"]
logger = logging.getLogger(__name__)
class ErrorHandler:
"""
Write the provided exception object along with some other metadata about
the error in a structured way in JSON format to an error file specified by the
environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment
variable is not set, then simply logs the contents of what would have been
written to the error file.
This handler may be subclassed to customize the handling of the error.
Subclasses should override ``initialize()`` and ``record_exception()``.
"""
def _get_error_file_path(self) -> Optional[str]:
"""
Return the error file path.
May return ``None`` to have the structured error be logged only.
"""
return os.environ.get("TORCHELASTIC_ERROR_FILE", None)
def initialize(self) -> None:
"""
Call prior to running code that we wish to capture errors/exceptions.
Typically registers signal/fault handlers. Users can override this
function to add custom initialization/registrations that aid in
propagation/information of errors/signals/exceptions/faults.
"""
try:
faulthandler.enable(all_threads=True)
except Exception as e:
warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}")
def _write_error_file(self, file_path: str, error_msg: str) -> None:
"""Write error message to the file."""
try:
with open(file_path, "w") as fp:
fp.write(error_msg)
except Exception as e:
warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}")
def record_exception(self, e: BaseException) -> None:
"""
Write a structured information about the exception into an error file in JSON format.
If the error file cannot be determined, then logs the content
that would have been written to the error file.
"""
file = self._get_error_file_path()
if file:
data = {
"message": {
"message": f"{type(e).__name__}: {e}",
"extraInfo": {
"py_callstack": traceback.format_exc(),
"timestamp": str(int(time.time())),
},
}
}
with open(file, "w") as fp:
json.dump(data, fp)
def override_error_code_in_rootcause_data(
self,
rootcause_error_file: str,
rootcause_error: Dict[str, Any],
error_code: int = 0,
):
"""Modify the rootcause_error read from the file, to correctly set the exit code."""
if "message" not in rootcause_error:
logger.warning(
"child error file (%s) does not have field `message`. \n"
"cannot override error code: %s",
rootcause_error_file,
error_code,
)
elif isinstance(rootcause_error["message"], str):
logger.warning(
"child error file (%s) has a new message format. \n"
"skipping error code override",
rootcause_error_file,
)
else:
rootcause_error["message"]["errorCode"] = error_code
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
"""Dump parent error file from child process's root cause error and error code."""
with open(rootcause_error_file) as fp:
rootcause_error = json.load(fp)
# Override error code since the child process cannot capture the error code if it
# is terminated by signals like SIGSEGV.
if error_code:
self.override_error_code_in_rootcause_data(
rootcause_error_file, rootcause_error, error_code
)
logger.debug(
"child error file (%s) contents:\n" "%s",
rootcause_error_file,
json.dumps(rootcause_error, indent=2),
)
my_error_file = self._get_error_file_path()
if my_error_file:
# Guard against existing error files.
# This can happen when the child is created using multiprocessing
# and the same env var (TORCHELASTIC_ERROR_FILE) is used on the
# parent and child to specify the error files (respectively).
# Because the env var on the child is set in the wrapper function
# and, by default, the child inherits the parent's env vars, if the
# child process receives a signal before the wrapper function kicks in
# and the signal handler writes to the error file, then the child
# will write to the parent's error file. In this case just log the
# original error file contents and overwrite the error file.
self._rm(my_error_file)
self._write_error_file(my_error_file, json.dumps(rootcause_error))
logger.info("dumped error file to parent's %s", my_error_file)
else:
logger.error(
"no error file defined for parent, to copy child error file (%s)",
rootcause_error_file,
)
def _rm(self, my_error_file):
if os.path.isfile(my_error_file):
# Log the contents of the original file.
with open(my_error_file) as fp:
try:
original = json.dumps(json.load(fp), indent=2)
logger.warning(
"%s already exists"
" and will be overwritten."
" Original contents:\n%s",
my_error_file,
original,
)
except json.decoder.JSONDecodeError:
logger.warning(
"%s already exists"
" and will be overwritten."
" Unable to load original contents:\n",
my_error_file,
)
os.remove(my_error_file)
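
# A minimal sketch of recording an exception by hand, assuming /tmp is
# writable; the error file path below is illustrative.
import os
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler

os.environ["TORCHELASTIC_ERROR_FILE"] = "/tmp/torchelastic_error.json"
handler = ErrorHandler()
handler.initialize()  # enables faulthandler for hard faults (e.g. SIGSEGV)
try:
    raise ValueError("boom")
except ValueError as e:
    handler.record_exception(e)  # writes the structured JSON to the error file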

View File

@ -0,0 +1,19 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Multiprocessing error-reporting module
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
__all__ = ["get_error_handler"]
def get_error_handler():
return ErrorHandler()

View File

@ -0,0 +1,104 @@
# mypy: allow-untyped-defs
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Taken and modified from original source:
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
import ctypes
import logging
import os
import sys
from contextlib import contextmanager
from functools import partial
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
def get_libc():
if IS_WINDOWS or IS_MACOS:
logger.warning(
"NOTE: Redirects are currently not supported in Windows or MacOs."
)
return None
else:
return ctypes.CDLL("libc.so.6")
libc = get_libc()
def _c_std(stream: str):
return ctypes.c_void_p.in_dll(libc, stream)
def _python_std(stream: str):
return {"stdout": sys.stdout, "stderr": sys.stderr}[stream]
_VALID_STD = {"stdout", "stderr"}
@contextmanager
def redirect(std: str, to_file: str):
"""
Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.
This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
See usage for details.
The directory of ``to_file`` is assumed to exist and the destination file
is overwritten if it already exists.
.. note:: Due to buffering, cross-source writes are not guaranteed to
appear in wall-clock order. For instance in the example below
it is possible for the C-outputs to appear before the python
outputs in the log file.
Usage:
::
# syntactic-sugar for redirect("stdout", "tmp/stdout.log")
with redirect_stdout("/tmp/stdout.log"):
print("python stdouts are redirected")
libc = ctypes.CDLL("libc.so.6")
libc.printf(b"c stdouts are also redirected"
os.system("echo system stdouts are also redirected")
print("stdout restored")
"""
if std not in _VALID_STD:
raise ValueError(
f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
)
c_std = _c_std(std)
python_std = _python_std(std)
std_fd = python_std.fileno()
def _redirect(dst):
libc.fflush(c_std)
python_std.flush()
os.dup2(dst.fileno(), std_fd)
with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
_redirect(dst)
try:
yield
finally:
_redirect(orig_std)
redirect_stdout = partial(redirect, "stdout")
redirect_stderr = partial(redirect, "stderr")

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import (
get_subprocess_handler,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["SubprocessHandler", "get_subprocess_handler"]

View File

@ -0,0 +1,34 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["get_subprocess_handler"]
def get_subprocess_handler(
entrypoint: str,
args: Tuple,
env: Dict[str, str],
stdout: str,
stderr: str,
local_rank_id: int,
):
return SubprocessHandler(
entrypoint=entrypoint,
args=args,
env=env,
stdout=stdout,
stderr=stderr,
local_rank_id=local_rank_id,
)

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
import subprocess
import sys
from typing import Any, Dict, Optional, Tuple
__all__ = ["SubprocessHandler"]
IS_WINDOWS = sys.platform == "win32"
def _get_default_signal() -> signal.Signals:
"""Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGTERM
class SubprocessHandler:
"""
Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
meta-objects associated with the process (e.g. stdout and stderr redirect fds).
"""
def __init__(
self,
entrypoint: str,
args: Tuple,
env: Dict[str, str],
stdout: Optional[str],
stderr: Optional[str],
local_rank_id: int,
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
# inherit parent environment vars
env_vars = os.environ.copy()
env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args])
self.local_rank_id = local_rank_id
self.proc: subprocess.Popen = self._popen(args_str, env_vars)
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
kwargs: Dict[str, Any] = {}
if not IS_WINDOWS:
kwargs["start_new_session"] = True
return subprocess.Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
# `Tuple[str, *Tuple[Any, ...]]`.
args=args,
env=env,
stdout=self._stdout,
stderr=self._stderr,
**kwargs,
)
def close(self, death_sig: Optional[signal.Signals] = None) -> None:
if not death_sig:
death_sig = _get_default_signal()
if IS_WINDOWS:
self.proc.send_signal(death_sig)
else:
os.killpg(self.proc.pid, death_sig)
if self._stdout:
self._stdout.close()
if self._stderr:
self._stderr.close()
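
# A minimal sketch: launch a long-running child whose stdout/stderr go to
# files, then terminate its process group. The entrypoint and paths are
# illustrative.
from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler

handler = SubprocessHandler(
    entrypoint="sleep",
    args=("30",),
    env={},
    stdout="/tmp/0_stdout.log",
    stderr="/tmp/0_stderr.log",
    local_rank_id=0,
)
handler.close()      # sends the default death signal (SIGTERM on unix)
handler.proc.wait()  # the wrapped subprocess.Popen is exposed as ``proc``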

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from concurrent.futures._base import Future
__all__ = ["tail_logfile", "TailLog"]
logger = logging.getLogger(__name__)
def tail_logfile(
header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
):
while not os.path.exists(file):
if finished.is_set():
return
time.sleep(interval_sec)
with open(file, errors="replace") as fp:
while True:
line = fp.readline()
if line:
dst.write(f"{header}{line}")
else: # reached EOF
if finished.is_set():
# log line producer is finished
break
else:
# log line producer is still going
# wait for a bit before looping again
time.sleep(interval_sec)
class TailLog:
"""
Tail the given log files.
The log files do not have to exist when the ``start()`` method is called. The tailer will gracefully wait until
the log files are created by the producer and will tail the contents of the
log files until the ``stop()`` method is called.
.. warning:: ``TailLog`` will wait indefinitely for the log file to be created!
Each log file's lines will be prefixed with a header of the form: ``[{name}{idx}]:``,
where ``name`` is user-provided and ``idx`` is the index of the log file
in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the
header for each log file.
Usage:
::
log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
tailer = TailLog("trainer", log_files, sys.stdout).start()
# actually run the trainers to produce 0_stdout.log and 1_stdout.log
run_trainers()
tailer.stop()
# once run_trainers() starts writing the ##_stdout.log files
# the tailer will print to sys.stdout:
# >>> [trainer0]:log_line1
# >>> [trainer1]:log_line1
# >>> [trainer0]:log_line2
# >>> [trainer0]:log_line3
# >>> [trainer1]:log_line2
.. note:: Due to buffering, log lines across files may not necessarily
be printed out in order. You should configure your application's
logger to suffix each log line with a proper timestamp.
"""
def __init__(
self,
name: str,
log_files: Dict[int, str],
dst: TextIO,
log_line_prefixes: Optional[Dict[int, str]] = None,
interval_sec: float = 0.1,
):
n = len(log_files)
self._threadpool = None
if n > 0:
self._threadpool = ThreadPoolExecutor(
max_workers=n,
thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
)
self._name = name
self._dst = dst
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._finished_events: Dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys()
}
self._futs: List[Future] = []
self._interval_sec = interval_sec
self._stopped = False
def start(self) -> "TailLog":
if not self._threadpool:
return self
for local_rank, file in self._log_files.items():
header = f"[{self._name}{local_rank}]:"
if self._log_line_prefixes and local_rank in self._log_line_prefixes:
header = self._log_line_prefixes[local_rank]
self._futs.append(
self._threadpool.submit(
tail_logfile,
header=header,
file=file,
dst=self._dst,
finished=self._finished_events[local_rank],
interval_sec=self._interval_sec,
)
)
return self
def stop(self) -> None:
for finished in self._finished_events.values():
finished.set()
for local_rank, f in enumerate(self._futs):
try:
f.result()
except Exception as e:
logger.error(
"error in log tailor for %s%s. %s: %s",
self._name,
local_rank,
e.__class__.__qualname__,
e,
)
if self._threadpool:
self._threadpool.shutdown(wait=True)
self._stopped = True
def stopped(self) -> bool:
return self._stopped
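
# A minimal sketch: tail two (hypothetical) worker log files into stdout,
# each line prefixed with a ``[worker{rank}]:`` header, until stop() is called.
import sys
from torch.distributed.elastic.multiprocessing.tail_log import TailLog

log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
tailer = TailLog("worker", log_files, sys.stdout).start()
# ... run the processes that produce the log files here ...
tailer.stop()  # sets the finished events and drains any remaining lines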

View File

@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
In the context of Torch Distributed Elastic we use the term *rendezvous* to
refer to a particular functionality that combines a **distributed
synchronization** primitive with **peer discovery**.
It is used by Torch Distributed Elastic to gather participants of a training
job (i.e. nodes) such that they all agree on the same list of participants and
everyone's roles, as well as make a consistent collective decision on when
training can begin/resume.
Torch Distributed Elastic rendezvous provides the following critical
functionalities:
**Barrier**:
Nodes performing rendezvous will all block until the rendezvous is considered
complete - this happens when at least ``min`` total number of nodes have joined
the rendezvous barrier (for the same job). This also implies the barrier is not
necessarily of fixed size.
There's an additional small waiting time after reaching ``min`` number of
nodes - this is used to ensure the rendezvous is not completed "too quickly"
(which could potentially exclude additional nodes attempting to join at
approximately the same time).
If ``max`` number of nodes is gathered at the barrier, the rendezvous is
completed immediately.
There's also an overall timeout which causes the rendezvous to fail if ``min``
number of nodes is never reached - this is meant to be a simple fail-safe to
help release partially allocated job resources, in case there's a problem with
the resource manager, and is meant to be interpreted as non-retryable.
**Exclusivity**:
A simple distributed barrier would not be sufficient, as we also need to ensure
that only one group of nodes exists at any given time (for a given job). In
other words, new nodes (i.e. joining late) should not be able to form a parallel
independent group of workers for the same job.
Torch Distributed Elastic rendezvous ensures that if a group of nodes has
already completed a rendezvous (and hence might already be training), then
additional "late" nodes attempting to rendezvous will only announce themselves
as waiting, and will have to wait until the (previously completed) existing
rendezvous is destroyed first.
**Consistency**:
When a rendezvous is completed, all its members will agree on the job membership
and everyone's role in it. This role is represented using an integer, called
rank, that is between 0 and world size.
Note that ranks are *not stable*, in the sense that the same node can be
assigned a different rank in the next (re-)rendezvous.
**Fault-tolerance**:
Torch Distributed Elastic rendezvous is designed to tolerate node failures
during the rendezvous process. Should a process crash (or lose network
connectivity, etc.) between joining the rendezvous and its completion, then
a re-rendezvous with the remaining healthy nodes will happen automatically.
A node can also fail *after* it has completed (or *has been observed* by other
nodes to have completed) the rendezvous - this scenario will be handled by the
Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a
re-rendezvous).
**Shared key-value store**:
When the rendezvous is completed, a shared key-value store is created and
returned. This store implements a ``torch.distributed.Store`` API (see
`distributed communication docs
<https://pytorch.org/docs/stable/distributed.html>`__).
This store is only shared by the members of the completed rendezvous. It
is intended to be used by Torch Distributed Elastic to exchange information
necessary to initialize job control and data-planes.
**Waiting workers and rendezvous closing**:
Torch Distributed Elastic rendezvous handler object provides additional
functionalities, which are technically not part of the rendezvous process:
1. Querying how many workers arrived late at the barrier and can participate in
the *next* rendezvous.
2. Setting the rendezvous *closed* to signal all nodes not to participate in
the next rendezvous.
**DynamicRendezvousHandler**:
Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
class that implements the rendezvous mechanism described above. It is a backend-
agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
to be specified during construction.
Torch distributed users can either implement their own backend type or use one
of the following implementations that come with PyTorch:
- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default
``TCPStore``) as the rendezvous backend. The main advantage of using a C10d
store is that it requires no 3rd-party dependency (such as etcd) to establish
a rendezvous.
- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy
:py:class:`.EtcdRendezvousHandler` class. Passing an
:py:class:`.EtcdRendezvousBackend` instance to
:py:class:`.DynamicRendezvousHandler` is functionally equivalent to
instantiating an :py:class:`.EtcdRendezvousHandler`.
::
store = TCPStore("localhost")
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
run_id="my_run_id",
store=store,
backend=backend,
min_nodes=2,
max_nodes=4
)
"""
from .api import (
rendezvous_handler_registry,
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousGracefulExitError,
RendezvousHandler,
RendezvousHandlerCreator,
RendezvousHandlerRegistry,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from .registry import _register_default_handlers
_register_default_handlers()
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
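
# A minimal single-node sketch: obtain a handler from the default registry and
# join the rendezvous barrier. The endpoint and run id are illustrative.
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

params = RendezvousParameters(
    backend="c10d",
    endpoint="localhost:29400",
    run_id="my_run_id",
    min_nodes=1,
    max_nodes=1,
)
handler = get_rendezvous_handler(params)
rdzv_info = handler.next_rendezvous()  # blocks until min_nodes have joined
print(rdzv_info.rank, rdzv_info.world_size)
handler.shutdown()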

View File

@ -0,0 +1,379 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional
from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port as _get_free_port
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
class RendezvousError(Exception):
"""Represents the base type for rendezvous errors."""
class RendezvousClosedError(RendezvousError):
"""Raised when a rendezvous is closed."""
class RendezvousTimeoutError(RendezvousError):
"""Raised when a rendezvous did not complete on time."""
class RendezvousConnectionError(RendezvousError):
"""Raised when the connection to a rendezvous backend has failed."""
class RendezvousStateError(RendezvousError):
"""Raised when the state of a rendezvous is corrupt."""
class RendezvousGracefulExitError(RendezvousError):
"""Raised when node wasn't not included in rendezvous and gracefully exits.
Exception is a mechanism to exit the stack, however does not mean a failure.
"""
@dataclass
class RendezvousStoreInfo:
"""Store address and port that can be used to bootstrap trainer distributed comms"""
MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
master_addr: str
master_port: int
@staticmethod
def build(
rank: int, store: Store, local_addr: Optional[str]
) -> "RendezvousStoreInfo":
"""Factory method, finds unused new port on rank0 host and addr/port info with all ranks.
If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor.
Args:
rank: rank of the current node
store: store to use for rendezvous
local_addr: address of the current node, if not provided will be resolved from hostname
"""
# TODO swap to collectives comms API
if rank == 0:
addr = local_addr or socket.getfqdn()
port = _get_free_port()
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type]
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type]
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
port = int(
store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
)
return RendezvousStoreInfo(master_addr=addr, master_port=port)
class RendezvousInfo:
"""Holds the information about the rendezvous."""
def __init__(
self,
store: Store,
rank: int,
world_size: int,
bootstrap_store_info: RendezvousStoreInfo,
):
self._store = store
self._rank = rank
self._world_size = world_size
self._bootstrap_store_info = bootstrap_store_info
@property
def store(self) -> Store:
"""Store used by torchelastic control plane"""
return self._store
@property
def rank(self) -> int:
"""Rank within a group"""
return self._rank
@property
def world_size(self) -> int:
"""Global group size"""
return self._world_size
@property
def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
"""Store information that can used by trainer code to bootstrap distributed comms."""
return self._bootstrap_store_info
class RendezvousHandler(ABC):
"""Main rendezvous interface.
Note:
Distributed Torch users normally **do not** need to implement their own
``RendezvousHandler``. An implementation based on C10d Store is already
provided, and is recommended for most users.
"""
@abstractmethod
def get_backend(self) -> str:
"""Return the name of the rendezvous backend."""
@property
def use_agent_store(self) -> bool:
"""Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user
applications and will be available during application lifecyle.
Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`.
Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store.
"""
return False
@abstractmethod
def next_rendezvous(self) -> RendezvousInfo:
"""Main entry-point into the rendezvous barrier.
Blocks until the rendezvous is complete and the current process is
included in the formed worker group, or a timeout occurs, or the
rendezvous was marked closed.
Returns:
Instance of :py:class:`RendezvousInfo`.
Raises:
RendezvousClosedError:
The rendezvous is closed.
RendezvousConnectionError:
The connection to the rendezvous backend has failed.
RendezvousStateError:
The rendezvous state is corrupt.
RendezvousTimeoutError:
The rendezvous did not complete on time.
"""
@abstractmethod
def is_closed(self) -> bool:
"""Check whether the rendezvous has been closed.
A closed rendezvous means all future attempts to re-rendezvous within the
same job will fail.
``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
propagation and should not be used for synchronization. The intention is
that if at least one node decides the job is finished, it will close the
rendezvous, and other nodes will soon observe this and stop running as
well.
"""
@abstractmethod
def set_closed(self):
"""Mark the rendezvous as closed."""
@abstractmethod
def num_nodes_waiting(self) -> int:
"""Return the number of nodes who arrived late at the rendezvous
barrier, hence were not included in the current worker group.
Callers should periodically call this method to check whether new
nodes are waiting to join the job and if so admit them by calling
:py:meth:`next_rendezvous()` (re-rendezvous).
"""
@abstractmethod
def get_run_id(self) -> str:
"""Return the run id of the rendezvous.
The run id is a user-defined id that uniquely identifies an instance of
a distributed application. It typically maps to a job id and is used to
allow nodes to join the correct distributed application.
"""
@abstractmethod
def shutdown(self) -> bool:
"""Close all resources that were open for the rendezvous.
Example::
rdzv_handler = ...
try:
store, rank, world_size = rdzv_handler.next_rendezvous()
finally:
rdzv_handler.shutdown()
"""
class RendezvousParameters:
"""Hold the parameters to construct a :py:class:`RendezvousHandler`.
Args:
backend:
The name of the backend to use to handle the rendezvous.
endpoint:
The endpoint of the rendezvous, usually in form <hostname>[:<port>].
run_id:
The id of the rendezvous.
min_nodes:
The minimum number of nodes to admit to the rendezvous.
max_nodes:
The maximum number of nodes to admit to the rendezvous.
local_addr:
The address of the local node.
**kwargs:
Additional parameters for the specified backend.
"""
def __init__(
self,
backend: str,
endpoint: str,
run_id: str,
min_nodes: int,
max_nodes: int,
local_addr: Optional[str] = None,
**kwargs,
):
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")
if min_nodes < 1:
raise ValueError(
f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
)
if max_nodes < min_nodes:
raise ValueError(
f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
f"equal to the minimum number of rendezvous nodes ({min_nodes})."
)
self.backend = backend
self.endpoint = endpoint
self.run_id = run_id
self.min_nodes = min_nodes
self.max_nodes = max_nodes
self.config = kwargs
self.local_addr = local_addr
def get(self, key: str, default: Any = None) -> Any:
"""Return the value for ``key`` if ``key`` exists, else ``default``."""
return self.config.get(key, default)
def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
"""Return the value for ``key`` as a ``bool``."""
value = self.get(key, default)
if value is None or isinstance(value, bool):
return value
if isinstance(value, int):
if value == 1:
return True
if value == 0:
return False
elif isinstance(value, str):
if value.lower() in ["1", "true", "t", "yes", "y"]:
return True
if value.lower() in ["0", "false", "f", "no", "n"]:
return False
raise ValueError(
f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
)
def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
"""Return the value for ``key`` as an ``int``."""
value = self.get(key, default)
if value is None:
return value
try:
return int(value)
except ValueError as e:
raise ValueError(
f"The rendezvous configuration option '{key}' does not represent a valid integer "
"value."
) from e
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
"""Represent a registry of :py:class:`RendezvousHandler` backends."""
_registry: Dict[str, RendezvousHandlerCreator]
def __init__(self) -> None:
self._registry = {}
def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
"""Register a new rendezvous backend.
Args:
backend:
The name of the backend.
creator:
The callback to invoke to construct the
:py:class:`RendezvousHandler`.
"""
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")
current_creator: Optional[RendezvousHandlerCreator]
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None
if current_creator is not None and current_creator != creator:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
f"is already registered with '{current_creator}'."
)
self._registry[backend] = creator
def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""Create a new :py:class:`RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError as e:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call `{self.register.__name__}`?"
) from e
handler = creator(params)
# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
f"backend '{params.backend}'."
)
return handler
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
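
# A minimal sketch of how backend-specific options flow through ``**kwargs``
# into ``RendezvousParameters.config`` and its typed accessors; all values
# below are illustrative.
from torch.distributed.elastic.rendezvous.api import RendezvousParameters

params = RendezvousParameters(
    backend="c10d",
    endpoint="localhost:29400",
    run_id="my_run_id",
    min_nodes=1,
    max_nodes=4,
    is_host="true",      # backend-specific kwargs land in ``params.config``
    read_timeout="60",
)
assert params.get_as_bool("is_host") is True   # parses "1"/"true"/"yes"/...
assert params.get_as_int("read_timeout") == 60
assert params.get("missing", "default") == "default"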

View File

@ -0,0 +1,273 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
import logging
import os
import tempfile
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, cast, Optional, Tuple
from torch.distributed import FileStore, Store, TCPStore
from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
from .api import (
RendezvousConnectionError,
RendezvousError,
RendezvousParameters,
RendezvousStateError,
)
from .dynamic_rendezvous import RendezvousBackend, Token
from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
logger = logging.getLogger(__name__)
# default port for the TCP store
DEFAULT_PORT = 29400
class C10dRendezvousBackend(RendezvousBackend):
"""Represents a C10d-backed rendezvous backend.
Args:
store:
The :py:class:`torch.distributed.Store` instance to use to
communicate with the C10d store.
run_id:
The run id of the rendezvous.
"""
# See the explanation in the __init__ method.
_NULL_SENTINEL = "Y2FuaW1hZGFt"
_store: Store
_key: str
def __init__(self, store: Store, run_id: str) -> None:
if not run_id:
raise ValueError("The run id must be a non-empty string.")
self._store = store
self._key = "torch.rendezvous." + run_id
# The read operation of a store blocks the caller until the specified
# key becomes available. This behavior makes it tricky to use a store
# as a regular key-value dictionary.
#
# As a workaround we initially set a sentinel value as the rendezvous
# state. Whenever this value gets returned we treat it as a None.
self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)
@property
def name(self) -> str:
"""See base class."""
return "c10d"
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
base64_state: bytes = self._call_store("get", self._key)
return self._decode_state(base64_state)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state_str: str = b64encode(state).decode()
if token:
# Shortcut if we know for sure that the token is not valid.
if not isinstance(token, bytes):
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6 does not support tuple unpacking in return
# statements.
return tmp
return None
token = token.decode()
else:
token = self._NULL_SENTINEL
base64_state: bytes = self._call_store(
"compare_set", self._key, token, base64_state_str
)
state_token_pair = self._decode_state(base64_state)
if state_token_pair is None:
return None
new_state, new_token = state_token_pair
# C10d Store's compare_set method does not offer an easy way to find out
# whether our write attempt was successful. As a brute-force solution we
# perform a bitwise comparison of our local state and the remote state.
return new_state, new_token, new_state == state
def _call_store(self, store_op: str, *args, **kwargs) -> Any:
try:
return getattr(self._store, store_op)(*args, **kwargs)
except (ValueError, RuntimeError, TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
if base64_state == self._NULL_SENTINEL.encode():
return None
try:
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
) from exc
return state, base64_state
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)
cfg_is_host = params.get_as_bool("is_host")
# If the user has explicitly specified whether our process should host the
# store, respect it.
if cfg_is_host is not None:
is_host = cfg_is_host
# Otherwise try to determine whether we are the host based on our hostname
# and IP address.
else:
is_host = _matches_machine_hostname(host)
# The timeout
read_timeout = cast(int, params.get_as_int("read_timeout", 60))
if read_timeout <= 0:
raise ValueError("The read timeout must be a positive integer.")
# In specific cases we attempt to instantiate the store twice. For details
# see the explanation in the except clause below.
for is_server in [is_host, False]:
try:
store = TCPStore(
host,
port,
is_master=is_server,
multi_tenant=True,
timeout=timedelta(seconds=read_timeout),
)
if is_server:
msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
construct_and_record_rdzv_event(
run_id=params.run_id, message=msg, node_state=NodeState.INIT
)
logger.info(msg)
break
except (ValueError, RuntimeError, TimeoutError) as exc:
# If we heuristically inferred the value of is_host as True and our
# first attempt to instantiate the TCP store has failed, try it one
# more time with is_host set to False. As an edge case there can be
# more than one process that is part of the same rendezvous on this
# machine and only one of them will eventually host the store.
if not is_server or cfg_is_host is not None:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
return store # type: ignore[possibly-undefined]
def _create_file_store(params: RendezvousParameters) -> FileStore:
# If a user specifies an endpoint, we treat it as a path to a file.
if params.endpoint:
path = params.endpoint
else:
try:
# The temporary file is readable and writable only by the user of
# this process.
_, path = tempfile.mkstemp()
except OSError as exc:
raise RendezvousError(
"The file creation for C10d store has failed. See inner exception for details."
) from exc
try:
store = FileStore(path)
except (ValueError, RuntimeError) as exc:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
return store
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
"""Create a new :py:class:`C10dRendezvousBackend` from the specified parameters.
+--------------+-----------------------------------------------------------+
| Parameter | Description |
+==============+===========================================================+
| store_type | The type of the C10d store. The currently supported types |
| | are "tcp" and "file" which correspond to |
| | :py:class:`torch.distributed.TCPStore` and |
| | :py:class:`torch.distributed.FileStore`, respectively. |
| | Defaults to "tcp". |
+--------------+-----------------------------------------------------------+
| read_timeout | The read timeout, in seconds, for store operations. |
| | Defaults to 60 seconds. |
| | |
| | Note this only applies to |
| | :py:class:`torch.distributed.TCPStore`. It is not relevant|
| | to :py:class:`torch.distributed.FileStore` which does not |
| | take in timeout as a parameter. |
+--------------+-----------------------------------------------------------+
| is_host | A boolean value indicating whether this backend instance |
| | will host the C10d store. If not specified it will be |
| | inferred heuristically by matching the hostname or the IP |
| | address of this machine against the specified rendezvous |
| | endpoint. Defaults to ``None``. |
| | |
| | Note that this configuration option only applies to |
| | :py:class:`torch.distributed.TCPStore`. In normal |
| | circumstances you can safely skip it; the only time when |
| | it is needed is if its value cannot be correctly |
| | determined (e.g. the rendezvous endpoint has a CNAME as |
| | the hostname or does not match the FQDN of the machine). |
+--------------+-----------------------------------------------------------+
"""
# As of today we only support TCPStore and FileStore. Other store types do
# not have the required functionality (e.g. compare_set) yet.
store_type = params.get("store_type", "tcp").strip().lower()
store: Store
try:
if store_type == "file":
store = _create_file_store(params)
elif store_type == "tcp":
store = _create_tcp_store(params)
else:
raise ValueError(
"Invalid store type given. Currently only supports file and tcp."
)
backend = C10dRendezvousBackend(store, params.run_id)
except Exception as e:
construct_and_record_rdzv_event(
message=f"{type(e).__name__}: {str(e)}",
run_id=params.run_id,
node_state=NodeState.FAILED,
)
raise
return backend, store
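
# A minimal sketch: create a C10d rendezvous backend whose TCPStore is hosted
# by this process. The endpoint and run id are illustrative.
from torch.distributed.elastic.rendezvous.api import RendezvousParameters
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend

params = RendezvousParameters(
    backend="c10d",
    endpoint="localhost:29400",
    run_id="my_run_id",
    min_nodes=1,
    max_nodes=1,
    is_host=True,  # skip the hostname heuristic and host the store here
)
backend, store = create_backend(params)
assert backend.get_state() is None  # the sentinel decodes to "no state yet"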

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,217 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
from base64 import b64decode, b64encode
from typing import cast, Optional, Tuple
import urllib3.exceptions # type: ignore[import]
from etcd import ( # type: ignore[import]
Client as EtcdClient,
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
EtcdResult,
)
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
"""Represents an etcd-based rendezvous backend.
Args:
client:
The ``etcd.Client`` instance to use to communicate with etcd.
run_id:
The run id of the rendezvous.
key_prefix:
The path under which to store the rendezvous state in etcd.
ttl:
The TTL of the rendezvous state. If not specified, defaults to two hours.
"""
_DEFAULT_TTL = 7200 # 2 hours
_client: EtcdClient
_key: str
_ttl: int
def __init__(
self,
client: EtcdClient,
run_id: str,
key_prefix: Optional[str] = None,
ttl: Optional[int] = None,
) -> None:
if not run_id:
raise ValueError("The run id must be a non-empty string.")
self._client = client
if key_prefix:
self._key = key_prefix + "/" + run_id
else:
self._key = run_id
if ttl and ttl > 0:
self._ttl = ttl
else:
self._ttl = self._DEFAULT_TTL
@property
def name(self) -> str:
"""See base class."""
return "etcd-v2"
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
try:
result = self._client.read(self._key)
except EtcdKeyNotFound:
return None
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
return self._decode_state(result)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state = b64encode(state).decode()
kwargs = {}
def get_state():
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6 does not support tuple unpacking in return
# statements.
return tmp
return None
if token:
try:
token = int(token)
except ValueError:
return get_state()
if token:
kwargs["prevIndex"] = token
else:
kwargs["prevExist"] = False
try:
result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
except (EtcdAlreadyExist, EtcdCompareFailed):
result = None
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
if result is None:
return get_state()
tmp = *self._decode_state(result), True
return tmp
def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
base64_state = result.value.encode()
try:
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
) from exc
return state, result.modifiedIndex
def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)
# The timeout
read_timeout = cast(int, params.get_as_int("read_timeout", 60))
if read_timeout <= 0:
raise ValueError("The read timeout must be a positive integer.")
# The communication protocol
protocol = params.get("protocol", "http").strip().lower()
if protocol != "http" and protocol != "https":
raise ValueError("The protocol must be HTTP or HTTPS.")
# The SSL client certificate
ssl_cert = params.get("ssl_cert")
if ssl_cert:
ssl_cert_key = params.get("ssl_cert_key")
if ssl_cert_key:
# The etcd client expects the certificate key as the second element
# of the `cert` tuple.
ssl_cert = (ssl_cert, ssl_cert_key)
# The root certificate
ca_cert = params.get("ca_cert")
try:
return EtcdClient(
host,
port,
read_timeout=read_timeout,
protocol=protocol,
cert=ssl_cert,
ca_cert=ca_cert,
allow_reconnect=True,
)
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
"""Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.
+--------------+-----------------------------------------------------------+
| Parameter | Description |
+==============+===========================================================+
| read_timeout | The read timeout, in seconds, for etcd operations. |
| | Defaults to 60 seconds. |
+--------------+-----------------------------------------------------------+
| protocol | The protocol to use to communicate with etcd. Valid |
| | values are "http" and "https". Defaults to "http". |
+--------------+-----------------------------------------------------------+
| ssl_cert | The path to the SSL client certificate to use along with |
| | HTTPS. Defaults to ``None``. |
+--------------+-----------------------------------------------------------+
| ssl_cert_key | The path to the private key of the SSL client certificate |
| | to use along with HTTPS. Defaults to ``None``. |
+--------------+-----------------------------------------------------------+
| ca_cert | The path to the root SSL authority certificate. Defaults |
| | to ``None``. |
+--------------+-----------------------------------------------------------+
"""
client = _create_etcd_client(params)
backend = EtcdRendezvousBackend(
client, params.run_id, key_prefix="/torch/elastic/rendezvous"
)
store = EtcdStore(client, "/torch/elastic/store")
return backend, store
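
# A minimal sketch, assuming an etcd server with the v2 API enabled is already
# reachable at the endpoint (e.g. one started with EtcdServer, shown below).
from torch.distributed.elastic.rendezvous.api import RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import create_backend

params = RendezvousParameters(
    backend="etcd-v2",
    endpoint="localhost:2379",
    run_id="my_run_id",
    min_nodes=1,
    max_nodes=2,
)
backend, store = create_backend(params)  # state kept under /torch/elastic/rendezvous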

View File

@ -0,0 +1,248 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional, TextIO, Union
try:
import etcd # type: ignore[import]
except ModuleNotFoundError:
pass
logger = logging.getLogger(__name__)
def find_free_port():
"""
Find a free port and bind a temporary socket to it so that the port can be "reserved" until used.
.. note:: the returned socket must be closed before using the port,
otherwise an ``address already in use`` error will happen.
The socket should be held and closed as close to the
consumer of the port as possible since otherwise, there
is a greater chance of a race condition where a different
process may see the port as being free and take it.
Returns: a socket bound to the reserved free port
Usage::
sock = find_free_port()
port = sock.getsockname()[1]
sock.close()
use_port(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
try:
s = socket.socket(family, type, proto)
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close() # type: ignore[possibly-undefined]
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
if subprocess and subprocess.poll() is None:
logger.info("stopping etcd server")
subprocess.terminate()
subprocess.wait()
if data_dir:
logger.info("deleting etcd data dir: %s", data_dir)
shutil.rmtree(data_dir, ignore_errors=True)
class EtcdServer:
"""
.. note:: tested on etcd server v3.4.3.
Starts and stops a local standalone etcd server on a random free
port. Useful for single node, multi-worker launches or testing,
where a sidecar etcd server is more convenient than having to
separately set up an etcd server.
This class registers a termination handler to shut down the etcd
subprocess on exit. This termination handler is NOT a substitute for
calling the ``stop()`` method.
The following fallback mechanism is used to find the etcd binary:
1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
2. Uses ``<this file root>/bin/etcd`` if one exists
3. Uses ``etcd`` from ``PATH``
Usage
::
server = EtcdServer(data_dir="/tmp/default.etcd")
server.start()
client = server.get_client()
# use client
server.stop()
Args:
data_dir: path of the etcd data directory (a temporary directory is
created if not specified; the directory is removed when the server stops)
"""
def __init__(self, data_dir: Optional[str] = None):
self._port = -1
self._host = "localhost"
root = os.path.dirname(__file__)
default_etcd_bin = os.path.join(root, "bin/etcd")
self._etcd_binary_path = os.environ.get(
"TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
)
if not os.path.isfile(self._etcd_binary_path):
self._etcd_binary_path = "etcd"
self._base_data_dir = (
data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
)
self._etcd_cmd = None
self._etcd_proc: Optional[subprocess.Popen] = None
def _get_etcd_server_process(self) -> subprocess.Popen:
if not self._etcd_proc:
raise RuntimeError(
"No etcd server process started. Call etcd_server.start() first"
)
else:
return self._etcd_proc
def get_port(self) -> int:
"""Return the port the server is running on."""
return self._port
def get_host(self) -> str:
"""Return the host the server is running on."""
return self._host
def get_endpoint(self) -> str:
"""Return the etcd server endpoint (host:port)."""
return f"{self._host}:{self._port}"
def start(
self,
timeout: int = 60,
num_retries: int = 3,
stderr: Union[int, TextIO, None] = None,
) -> None:
"""
Start the server and wait for it to be ready. When this function returns, the server is ready to take requests.
Args:
timeout: time (in seconds) to wait for the server to be ready
before giving up.
num_retries: number of retries to start the server. Each retry
will wait for max ``timeout`` before considering it as failed.
stderr: the standard error file handle. Valid values are
`subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
descriptor (a positive integer), an existing file object, and
`None`.
Raises:
TimeoutError: if the server is not ready within the specified timeout
"""
curr_retries = 0
while True:
try:
data_dir = os.path.join(self._base_data_dir, str(curr_retries))
os.makedirs(data_dir, exist_ok=True)
return self._start(data_dir, timeout, stderr)
except Exception as e:
curr_retries += 1
stop_etcd(self._etcd_proc)
logger.warning(
"Failed to start etcd server, got error: %s, retrying", str(e)
)
if curr_retries >= num_retries:
shutil.rmtree(self._base_data_dir, ignore_errors=True)
raise
atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
def _start(
self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
) -> None:
sock = find_free_port()
sock_peer = find_free_port()
self._port = sock.getsockname()[1]
peer_port = sock_peer.getsockname()[1]
etcd_cmd = shlex.split(
" ".join(
[
self._etcd_binary_path,
"--enable-v2",
"--data-dir",
data_dir,
"--listen-client-urls",
f"http://{self._host}:{self._port}",
"--advertise-client-urls",
f"http://{self._host}:{self._port}",
"--listen-peer-urls",
f"http://{self._host}:{peer_port}",
]
)
)
logger.info("Starting etcd server: [%s]", etcd_cmd)
sock.close()
sock_peer.close()
self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
self._wait_for_ready(timeout)
def get_client(self):
"""Return an etcd client object that can be used to make requests to this server."""
return etcd.Client(
host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
)
def _wait_for_ready(self, timeout: int = 60) -> None:
client = etcd.Client(
host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
)
max_time = time.time() + timeout
while time.time() < max_time:
if self._get_etcd_server_process().poll() is not None:
# etcd server process finished
exitcode = self._get_etcd_server_process().returncode
raise RuntimeError(
f"Etcd server process exited with the code: {exitcode}"
)
try:
logger.info("etcd server ready. version: %s", client.version)
return
except Exception:
time.sleep(1)
raise TimeoutError("Timed out waiting for etcd server to be ready!")
def stop(self) -> None:
"""Stop the server and cleans up auto generated resources (e.g. data dir)."""
logger.info("EtcdServer stop method called")
stop_etcd(self._etcd_proc, self._base_data_dir)

View File

@ -0,0 +1,207 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional
import etcd # type: ignore[import]
# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
from torch.distributed import Store
# Delay (sleep) for a small random amount to reduce CAS failures.
# This does not affect correctness, but will reduce requests to etcd server.
def cas_delay():
time.sleep(random.uniform(0, 0.1))
# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
"""
Implement a c10d Store interface by piggybacking on the rendezvous etcd instance.
This is the store object returned by ``EtcdRendezvous``.
"""
def __init__(
self,
etcd_client,
etcd_store_prefix,
# Default timeout same as in c10d/Store.hpp
timeout: Optional[datetime.timedelta] = None,
):
super().__init__() # required for pybind trampoline.
self.client = etcd_client
self.prefix = etcd_store_prefix
if timeout is not None:
self.set_timeout(timeout)
if not self.prefix.endswith("/"):
self.prefix += "/"
def set(self, key, value):
"""
Write a key/value pair into ``EtcdStore``.
Both key and value may be either Python ``str`` or ``bytes``.
"""
self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))
def get(self, key) -> bytes:
"""
Get a value by key, possibly doing a blocking wait.
If key is not immediately present, will do a blocking wait
for at most ``timeout`` duration or until the key is published.
Returns:
value ``(bytes)``
Raises:
LookupError - If key still not published after timeout
"""
b64_key = self.prefix + self._encode(key)
kvs = self._try_wait_get([b64_key])
if kvs is None:
raise LookupError(f"Key {key} not found in EtcdStore")
return self._decode(kvs[b64_key])
def add(self, key, num: int) -> int:
"""
Atomically increment a value by an integer amount.
The integer is represented as a string using base 10. If key is not present,
a default value of ``0`` will be assumed.
Returns:
the new (incremented) value
"""
b64_key = self._encode(key)
# c10d Store assumes value is an integer represented as a decimal string
try:
# Assume default value "0", if this key didn't yet:
node = self.client.write(
key=self.prefix + b64_key,
value=self._encode(str(num)), # i.e. 0 + num
prevExist=False,
)
return int(self._decode(node.value))
except etcd.EtcdAlreadyExist:
pass
while True:
# Note: c10d Store does not have a method to delete keys, so we
# can be sure it's still there.
node = self.client.get(key=self.prefix + b64_key)
new_value = self._encode(str(int(self._decode(node.value)) + num))
try:
node = self.client.test_and_set(
key=node.key, value=new_value, prev_value=node.value
)
return int(self._decode(node.value))
except etcd.EtcdCompareFailed:
cas_delay()
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
"""
Wait until all of the keys are published, or until timeout.
Raises:
LookupError - if timeout occurs
"""
b64_keys = [self.prefix + self._encode(key) for key in keys]
kvs = self._try_wait_get(b64_keys, override_timeout)
if kvs is None:
raise LookupError("Timeout while waiting for keys in EtcdStore")
# No return value on success
def check(self, keys) -> bool:
"""Check if all of the keys are immediately present (without waiting)."""
b64_keys = [self.prefix + self._encode(key) for key in keys]
kvs = self._try_wait_get(
b64_keys,
override_timeout=datetime.timedelta(microseconds=1), # as if no wait
)
return kvs is not None
#
# Encode key/value data in base64, so we can store arbitrary binary data
# in EtcdStore. Input can be `str` or `bytes`.
# In case of `str`, utf-8 encoding is assumed.
#
def _encode(self, value) -> str:
if type(value) == bytes:
return b64encode(value).decode()
elif type(value) == str:
return b64encode(value.encode()).decode()
raise ValueError("Value must be of type str or bytes")
#
# Decode a base64 string (of type `str` or `bytes`).
# Return type is `bytes`, which is more convenient with the Store interface.
#
def _decode(self, value) -> bytes:
if type(value) == bytes:
return b64decode(value)
elif type(value) == str:
return b64decode(value.encode())
raise ValueError("Value must be of type str or bytes")
#
# Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
# are published or timeout occurs.
# This is a helper method for the public interface methods.
#
# On success, a dictionary of {etcd key -> etcd value} is returned.
# On timeout, None is returned.
#
def _try_wait_get(self, b64_keys, override_timeout=None):
timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined]
deadline = time.time() + timeout.total_seconds()
while True:
# Read whole directory (of keys), filter only the ones waited for
all_nodes = self.client.get(key=self.prefix)
req_nodes = {
node.key: node.value
for node in all_nodes.children
if node.key in b64_keys
}
if len(req_nodes) == len(b64_keys):
# All keys are available
return req_nodes
watch_timeout = deadline - time.time()
if watch_timeout <= 0:
return None
try:
self.client.watch(
key=self.prefix,
recursive=True,
timeout=watch_timeout,
index=all_nodes.etcd_index + 1,
)
except etcd.EtcdWatchTimedOut:
if time.time() >= deadline:
return None
else:
continue
except etcd.EtcdEventIndexCleared:
continue
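
# A minimal sketch of the c10d Store interface on top of etcd, assuming a
# running etcd v2 server (see EtcdServer above) and the ``python-etcd`` client.
import etcd
from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore

client = etcd.Client(host="localhost", port=2379, version_prefix="/v2")
store = EtcdStore(client, "/torch/elastic/store")
store.set("status", b"ready")
assert store.get("status") == b"ready"  # blocks until the key is published
assert store.add("counter", 1) == 1     # atomic increment via test_and_set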

View File

@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import (
rendezvous_handler_registry as handler_registry,
RendezvousHandler,
RendezvousParameters,
)
from .dynamic_rendezvous import create_handler
__all__ = ["get_rendezvous_handler"]
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import static_tcp_rendezvous
return static_tcp_rendezvous.create_rdzv_handler(params)
def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import etcd_rendezvous
return etcd_rendezvous.create_rdzv_handler(params)
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
from .etcd_rendezvous_backend import create_backend
backend, store = create_backend(params)
return create_handler(store, backend, params)
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
from .c10d_rendezvous_backend import create_backend
backend, store = create_backend(params)
return create_handler(store, backend, params)
def _register_default_handlers() -> None:
handler_registry.register("etcd", _create_etcd_handler)
handler_registry.register("etcd-v2", _create_etcd_v2_handler)
handler_registry.register("c10d", _create_c10d_handler)
handler_registry.register("static", _create_static_handler)
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
"""
    Obtain a reference to a :py:class:`RendezvousHandler`.
Custom rendezvous handlers can be registered by
::
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
def create_my_rdzv(params: RendezvousParameters):
return MyCustomRdzv(params)
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
    params = RendezvousParameters(backend="my_rdzv_backend_name", ...)
    my_rdzv_handler = get_rendezvous_handler(params)
"""
return handler_registry.create_handler(params)

View File

@ -0,0 +1,128 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
from typing import cast, Optional
from torch.distributed import PrefixStore, Store, TCPStore
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStoreInfo,
)
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"]
logger = logging.getLogger(__name__)
_default_timeout_seconds = 600
class StaticTCPRendezvous(RendezvousHandler):
"""
Static rendezvous that is a wrapper around the TCPStore.
Creates TCPStore based on the input parameters with the
listener on the agent with group_rank=0
"""
def __init__(
self,
master_addr: str,
master_port: int,
rank: int,
world_size: int,
run_id: str,
timeout: int,
):
self.master_addr = master_addr
self.master_port = master_port
self.rank = rank
self.world_size = world_size
self.run_id = run_id
self.timeout = datetime.timedelta(seconds=timeout)
self._store: Optional[Store] = None
def get_backend(self) -> str:
return "static"
@property
def use_agent_store(self) -> bool:
return True
def next_rendezvous(self) -> RendezvousInfo:
logger.info("Creating TCPStore as the c10d::Store implementation")
is_master = self.rank == 0
if not self._store:
self._store = TCPStore( # type: ignore[call-arg]
self.master_addr,
self.master_port,
self.world_size,
is_master,
self.timeout,
multi_tenant=True,
)
store = PrefixStore(self.run_id, self._store)
# TCPStore server instance is used by trainer code
bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
return RendezvousInfo(
store,
self.rank,
self.world_size,
bootstrap_store_info,
)
def is_closed(self):
return False
def set_closed(self):
pass
def num_nodes_waiting(self):
return 0
def get_run_id(self) -> str:
return self.run_id
def shutdown(self) -> bool:
return True
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
if "rank" not in params.config:
raise ValueError(
"rank is absent in RendezvousParameters."
"Try add --node-rank to the cmd request"
)
endpoint = params.endpoint.strip()
if not endpoint:
raise ValueError(
"endpoint is absent in RendezvousParameters"
"Try add --master-port and --master-addr to the cmd request"
)
master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
if master_port == -1:
raise ValueError(
f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
)
world_size = params.max_nodes
rank = cast(int, params.config.get("rank"))
run_id = params.run_id
if "timeout" in params.config:
timeout = int(params.config["timeout"])
else:
timeout = _default_timeout_seconds
return StaticTCPRendezvous(
master_addr, master_port, rank, world_size, run_id, timeout
)
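
# --- Illustrative usage (added sketch, not part of the upstream file) ---
# How the parameters consumed by `create_rdzv_handler` are typically built.
# The endpoint, run id, and rank values are hypothetical placeholders, and we
# assume extra kwargs such as `rank` are carried in ``params.config``.
def _example_create_static_handler() -> RendezvousHandler:
    params = RendezvousParameters(
        backend="static",
        endpoint="10.0.0.1:29500",
        run_id="example_run",
        min_nodes=2,
        max_nodes=2,
        rank=0,
    )
    return create_rdzv_handler(params)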

View File

@ -0,0 +1,284 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Tuple, Union
__all__ = ["parse_rendezvous_endpoint"]
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string.
Args:
config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
config_str = config_str.strip()
if not config_str:
return config
key_values = config_str.split(",")
for kv in key_values:
key, *values = kv.split("=", 1)
key = key.strip()
if not key:
raise ValueError(
"The rendezvous configuration string must be in format "
"<key1>=<value1>,...,<keyN>=<valueN>."
)
value: Optional[str]
if values:
value = values[0].strip()
else:
value = None
if not value:
raise ValueError(
f"The rendezvous configuration option '{key}' must have a value specified."
)
config[key] = value
return config
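
# Illustrative example (added sketch, not part of the upstream file): the
# expected round-trip for a well-formed configuration string.
def _example_parse_rendezvous_config() -> None:
    config = _parse_rendezvous_config("join_timeout=600,last_call_timeout=30")
    assert config == {"join_timeout": "600", "last_call_timeout": "30"}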
def _try_parse_port(port_str: str) -> Optional[int]:
"""Try to extract the port number from ``port_str``."""
if port_str and re.match(r"^[0-9]{1,5}$", port_str):
return int(port_str)
return None
def parse_rendezvous_endpoint(
endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
"""Extract the hostname and the port number from a rendezvous endpoint.
Args:
endpoint:
A string in format <hostname>[:<port>].
default_port:
The port number to use if the endpoint does not include one.
Returns:
A tuple of hostname and port number.
"""
if endpoint is not None:
endpoint = endpoint.strip()
if not endpoint:
return ("localhost", default_port)
# An endpoint that starts and ends with brackets represents an IPv6 address.
if endpoint[0] == "[" and endpoint[-1] == "]":
host, *rest = endpoint, *[]
else:
host, *rest = endpoint.rsplit(":", 1)
# Sanitize the IPv6 address.
if len(host) > 1 and host[0] == "[" and host[-1] == "]":
host = host[1:-1]
if len(rest) == 1:
port = _try_parse_port(rest[0])
if port is None or port >= 2**16:
raise ValueError(
f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
"between 0 and 65536."
)
else:
port = default_port
if not re.match(r"^[\w\.:-]+$", host):
raise ValueError(
f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
"labels, an IPv4 address, or an IPv6 address."
)
return host, port
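
# Illustrative example (added sketch, not part of the upstream file): a few
# endpoint forms this parser accepts, including a bracketed IPv6 address.
def _example_parse_rendezvous_endpoint() -> None:
    assert parse_rendezvous_endpoint("node1:29500", default_port=-1) == ("node1", 29500)
    assert parse_rendezvous_endpoint("node1", default_port=29400) == ("node1", 29400)
    assert parse_rendezvous_endpoint("[::1]:29500", default_port=-1) == ("::1", 29500)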
def _matches_machine_hostname(host: str) -> bool:
"""Indicate whether ``host`` matches the hostname of this machine.
This function compares ``host`` to the hostname as well as to the IP
addresses of this machine. Note that it may return a false negative if this
machine has CNAME records beyond its FQDN or IP addresses assigned to
secondary NICs.
"""
if host == "localhost":
return True
try:
addr = ipaddress.ip_address(host)
except ValueError:
addr = None
if addr and addr.is_loopback:
return True
try:
host_addr_list = socket.getaddrinfo(
host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
except (ValueError, socket.gaierror) as _:
host_addr_list = []
host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list]
this_host = socket.gethostname()
if host == this_host:
return True
addr_list = socket.getaddrinfo(
this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
for addr_info in addr_list:
# If we have an FQDN in the addr_info, compare it to `host`.
if addr_info[3] and addr_info[3] == host:
return True
# Otherwise if `host` represents an IP address, compare it to our IP
# address.
if addr and addr_info[4][0] == str(addr):
return True
# If the IP address matches one of the provided host's IP addresses
if addr_info[4][0] in host_ip_list:
return True
return False
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
"""Suspend the current thread for ``seconds``.
Args:
seconds:
Either the delay, in seconds, or a tuple of a lower and an upper
bound within which a random delay will be picked.
"""
if isinstance(seconds, tuple):
seconds = random.uniform(*seconds)
# Ignore delay requests that are less than 10 milliseconds.
if seconds >= 0.01:
time.sleep(seconds)
class _PeriodicTimer:
"""Represent a timer that periodically runs a specified function.
Args:
interval:
The interval, in seconds, between each run.
function:
The function to run.
"""
    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
class _Context:
interval: float
function: Callable[..., None]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
stop_event: Event
_name: Optional[str]
_thread: Optional[Thread]
_finalizer: Optional[weakref.finalize]
# The context that is shared between the timer and the background thread.
_ctx: _Context
def __init__(
self,
interval: timedelta,
function: Callable[..., None],
*args: Any,
**kwargs: Any,
) -> None:
self._name = None
self._ctx = self._Context()
self._ctx.interval = interval.total_seconds()
self._ctx.function = function # type: ignore[assignment]
self._ctx.args = args or ()
self._ctx.kwargs = kwargs or {}
self._ctx.stop_event = Event()
self._thread = None
self._finalizer = None
@property
def name(self) -> Optional[str]:
"""Get the name of the timer."""
return self._name
def set_name(self, name: str) -> None:
"""Set the name of the timer.
        The specified name will be assigned to the background thread and is
        used for debugging and troubleshooting purposes.
"""
if self._thread:
raise RuntimeError("The timer has already started.")
self._name = name
def start(self) -> None:
"""Start the timer."""
if self._thread:
raise RuntimeError("The timer has already started.")
self._thread = Thread(
target=self._run,
name=self._name or "PeriodicTimer",
args=(self._ctx,),
daemon=True,
)
# We avoid using a regular finalizer (a.k.a. __del__) for stopping the
# timer as joining a daemon thread during the interpreter shutdown can
# cause deadlocks. The weakref.finalize is a superior alternative that
# provides a consistent behavior regardless of the GC implementation.
self._finalizer = weakref.finalize(
self, self._stop_thread, self._thread, self._ctx.stop_event
)
# We do not attempt to stop our background thread during the interpreter
# shutdown. At that point we do not even know whether it still exists.
self._finalizer.atexit = False
self._thread.start()
def cancel(self) -> None:
"""Stop the timer at the next opportunity."""
if self._finalizer:
self._finalizer()
@staticmethod
def _run(ctx) -> None:
while not ctx.stop_event.wait(ctx.interval):
ctx.function(*ctx.args, **ctx.kwargs)
@staticmethod
def _stop_thread(thread, stop_event):
stop_event.set()
thread.join()
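
# --- Illustrative usage (added sketch, not part of the upstream file) ---
# Emit a heartbeat every 5 seconds until the timer is cancelled. The message
# and sleep duration are arbitrary placeholders.
def _example_periodic_timer() -> None:
    def _heartbeat(msg: str) -> None:
        print(msg)

    timer = _PeriodicTimer(timedelta(seconds=5), _heartbeat, "still alive")
    timer.set_name("heartbeat-timer")
    timer.start()
    time.sleep(12)  # _heartbeat fires at ~5s and ~10s
    timer.cancel()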

View File

@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Expiration timers are set up on the same process as the agent and
used from your script to deal with stuck workers. When you enter
a code block that has the potential to get stuck, you can acquire
an expiration timer, which instructs the timer server to kill the
process if it does not release the timer by the self-imposed
expiration deadline.
Usage::

    import multiprocessing as mp

    import torchelastic.timer as timer
    import torchelastic.agent.server as agent

    def main():
        start_method = "spawn"
        message_queue = mp.get_context(start_method).Queue()
        server = timer.LocalTimerServer(message_queue, max_interval=0.01)
        server.start()  # non-blocking
        spec = WorkerSpec(
            fn=trainer_func,
            args=(message_queue,),
            ...<OTHER_PARAMS...>)
        agent = agent.LocalElasticAgent(spec, start_method)
        agent.run()

    def trainer_func(message_queue):
        timer.configure(timer.LocalTimerClient(message_queue))
        with timer.expires(after=60):  # 60 second expiry
            # do some work
            ...

In the example above, if ``trainer_func`` takes more than 60 seconds to
complete, then the worker process is killed and the agent retries the worker group.
"""
from .api import ( # noqa: F401
configure,
expires,
TimerClient,
TimerRequest,
TimerServer,
)
from .file_based_local_timer import ( # noqa: F401
FileTimerClient,
FileTimerRequest,
FileTimerServer,
)
from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401

View File

@ -0,0 +1,283 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
__all__ = [
"TimerRequest",
"TimerClient",
"RequestQueue",
"TimerServer",
"configure",
"expires",
]
logger = logging.getLogger(__name__)
class TimerRequest:
"""
Data object representing a countdown timer acquisition and release
that is used between the ``TimerClient`` and ``TimerServer``.
A negative ``expiration_time`` should be interpreted as a "release"
request.
    .. note:: the type of ``worker_id`` is implementation specific.
        It is whatever the TimerServer and TimerClient implementations
        use to uniquely identify a worker.
"""
__slots__ = ["worker_id", "scope_id", "expiration_time"]
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
self.worker_id = worker_id
self.scope_id = scope_id
self.expiration_time = expiration_time
def __eq__(self, other):
if isinstance(other, TimerRequest):
return (
self.worker_id == other.worker_id
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
)
return False
class TimerClient(abc.ABC):
"""
Client library to acquire and release countdown timers by communicating
with the TimerServer.
"""
@abc.abstractmethod
def acquire(self, scope_id: str, expiration_time: float) -> None:
"""
Acquires a timer for the worker that holds this client object
given the scope_id and expiration_time. Typically registers
the timer with the TimerServer.
"""
@abc.abstractmethod
def release(self, scope_id: str):
"""
Releases the timer for the ``scope_id`` on the worker this
client represents. After this method is
called, the countdown timer on the scope is no longer in effect.
"""
class RequestQueue(abc.ABC):
"""
Consumer queue holding timer acquisition/release requests
"""
@abc.abstractmethod
def size(self) -> int:
"""
Returns the size of the queue at the time this method is called.
Note that by the time ``get`` is called the size of the queue
may have increased. The size of the queue should not decrease
until the ``get`` method is called. That is, the following assertion
should hold:
size = q.size()
res = q.get(size, timeout=0)
assert size == len(res)
-- or --
size = q.size()
res = q.get(size * 2, timeout=1)
assert size <= len(res) <= size * 2
"""
@abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]:
"""
Gets up to ``size`` number of timer requests in a blocking fashion
(no more than ``timeout`` seconds).
"""
class TimerServer(abc.ABC):
"""
Entity that monitors active timers and expires them
in a timely fashion. This server is responsible for
reaping workers that have expired timers.
"""
def __init__(
self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
):
"""
:param request_queue: Consumer ``RequestQueue``
:param max_interval: max time (in seconds) to wait
for an item in the request_queue
:param daemon: whether to run the watchdog thread as a daemon
"""
super().__init__()
self._request_queue = request_queue
self._max_interval = max_interval
self._daemon = daemon
self._watchdog_thread: Optional[threading.Thread] = None
self._stop_signaled = False
@abc.abstractmethod
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
"""
Processes the incoming timer requests and registers them with the server.
        The timer request can either be an acquire-timer or release-timer request.
Timer requests with a negative expiration_time should be interpreted
as a release-timer request.
"""
@abc.abstractmethod
def clear_timers(self, worker_ids: Set[Any]) -> None:
"""
Clears all timers for the given ``worker_ids``.
"""
@abc.abstractmethod
def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
"""
Returns all expired timers for each worker_id. An expired timer
is a timer for which the expiration_time is less than or equal to
the provided deadline.
"""
@abc.abstractmethod
def _reap_worker(self, worker_id: Any) -> bool:
"""
Reaps the given worker. Returns True if the worker has been
successfully reaped, False otherwise. If any uncaught exception
is thrown from this method, the worker is considered reaped
and all associated timers will be removed.
"""
def _reap_worker_no_throw(self, worker_id: Any) -> bool:
"""
Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
thrown, then it considers the worker as reaped.
"""
try:
return self._reap_worker(worker_id)
except Exception:
logger.exception(
"Uncaught exception thrown from _reap_worker(), "
"check that the implementation correctly catches exceptions",
)
return True
def _watchdog_loop(self):
while not self._stop_signaled:
try:
self._run_watchdog()
except Exception:
logger.exception("Error running watchdog")
def _run_watchdog(self):
batch_size = max(1, self._request_queue.size())
timer_requests = self._request_queue.get(batch_size, self._max_interval)
self.register_timers(timer_requests)
now = time.time()
reaped_worker_ids = set()
for worker_id, expired_timers in self.get_expired_timers(now).items():
logger.info(
"Reaping worker_id=[%s]." " Expired timers: %s",
worker_id,
self._get_scopes(expired_timers),
)
if self._reap_worker_no_throw(worker_id):
logger.info("Successfully reaped worker=[%s]", worker_id)
reaped_worker_ids.add(worker_id)
else:
logger.error(
"Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
)
self.clear_timers(reaped_worker_ids)
def _get_scopes(self, timer_requests):
return [r.scope_id for r in timer_requests]
def start(self) -> None:
logger.info(
"Starting %s..." " max_interval=%s," " daemon=%s",
type(self).__name__,
self._max_interval,
self._daemon,
)
self._watchdog_thread = threading.Thread(
target=self._watchdog_loop, daemon=self._daemon
)
logger.info("Starting watchdog thread...")
self._watchdog_thread.start()
def stop(self) -> None:
logger.info("Stopping %s", type(self).__name__)
self._stop_signaled = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join(self._max_interval)
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
_timer_client: Optional[TimerClient] = None
def configure(timer_client: TimerClient):
"""
Configures a timer client. Must be called before using ``expires``.
"""
global _timer_client
_timer_client = timer_client
logger.info("Timer client configured to: %s", type(_timer_client).__name__)
@contextmanager
def expires(
after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
"""
Acquires a countdown timer that expires in ``after`` seconds from now,
unless the code-block that it wraps is finished within the timeframe.
When the timer expires, this worker is eligible to be reaped. The
exact meaning of "reaped" depends on the client implementation. In
most cases, reaping means to terminate the worker process.
Note that the worker is NOT guaranteed to be reaped at exactly
``time.now() + after``, but rather the worker is "eligible" for being
reaped and the ``TimerServer`` that the client talks to will ultimately
make the decision when and how to reap the workers with expired timers.
Usage::
torch.distributed.elastic.timer.configure(LocalTimerClient())
with expires(after=10):
torch.distributed.all_reduce(...)
"""
if client is None:
if _timer_client is None:
raise RuntimeError("Configure timer client before using countdown timers.")
client = _timer_client
if scope is None:
# grab the caller file + lineno
caller = getframeinfo(stack()[1][0])
scope = f"{caller.filename}#{caller.lineno}"
expiration = time.time() + after
client.acquire(scope, expiration)
try:
yield
finally:
client.release(scope)

View File

@ -0,0 +1,25 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
from torch.distributed.elastic.utils.logging import get_logger
logger = get_logger(__name__)
__all__ = ["log_debug_info_for_expired_timers"]
def log_debug_info_for_expired_timers(
run_id: str,
expired_timers: Dict[int, List[str]],
):
if expired_timers:
logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)

View File

@ -0,0 +1,393 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import json
import os
import select
import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set, Tuple
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
from torch.distributed.elastic.timer.debug_info_logging import (
log_debug_info_for_expired_timers,
)
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
logger = get_logger(__name__)
class FileTimerRequest(TimerRequest):
"""
Data object representing a countdown timer acquisition and release
that is used between the ``FileTimerClient`` and ``FileTimerServer``.
A negative ``expiration_time`` should be interpreted as a "release"
request.
``signal`` is the signal to reap the worker process from the server
process.
"""
__slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]
def __init__(
self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0
) -> None:
self.version = 1
self.worker_pid = worker_pid
self.scope_id = scope_id
self.expiration_time = expiration_time
self.signal = signal
def __eq__(self, other) -> bool:
if isinstance(other, FileTimerRequest):
return (
self.version == other.version
and self.worker_pid == other.worker_pid
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
and self.signal == other.signal
)
return False
def to_json(self) -> str:
return json.dumps(
{
"version": self.version,
"pid": self.worker_pid,
"scope_id": self.scope_id,
"expiration_time": self.expiration_time,
"signal": self.signal,
},
)
class FileTimerClient(TimerClient):
"""
Client side of ``FileTimerServer``. This client is meant to be used
on the same host that the ``FileTimerServer`` is running on and uses
pid to uniquely identify a worker.
    This client uses a named pipe to send timer requests to the
``FileTimerServer``. This client is a producer while the
``FileTimerServer`` is a consumer. Multiple clients can work with
the same ``FileTimerServer``.
Args:
file_path: str, the path of a FIFO special file. ``FileTimerServer``
must have created it by calling os.mkfifo().
signal: signal, the signal to use to kill the process. Using a
negative or zero signal will not kill the process.
"""
def __init__(
self,
file_path: str,
signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined]
) -> None:
super().__init__()
self._file_path = file_path
self.signal = signal
def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
try:
fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
return os.fdopen(fd, "wt")
except Exception:
return None
def _send_request(self, request: FileTimerRequest) -> None:
        # The server may have crashed or may not have started yet.
        # In that case, calling open() in blocking mode would block the client.
        # To avoid this, open the file in non-blocking mode, in which case an
        # OSError is raised if the server is not there.
file = self._open_non_blocking()
if file is None:
raise BrokenPipeError(
"Could not send the FileTimerRequest because FileTimerServer is not available."
)
with file:
json_request = request.to_json()
            # Writes of no more than select.PIPE_BUF bytes are guaranteed to be atomic.
if len(json_request) > select.PIPE_BUF:
raise RuntimeError(
f"FileTimerRequest larger than {select.PIPE_BUF} bytes "
f"is not supported: {json_request}"
)
file.write(json_request + "\n")
def acquire(self, scope_id: str, expiration_time: float) -> None:
self._send_request(
request=FileTimerRequest(
worker_pid=os.getpid(),
scope_id=scope_id,
expiration_time=expiration_time,
signal=self.signal,
),
)
def release(self, scope_id: str) -> None:
self._send_request(
request=FileTimerRequest(
worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0
),
)
class FileTimerServer:
"""
Server that works with ``FileTimerClient``. Clients are expected to be
running on the same host as the process that is running this server.
Each host in the job is expected to start its own timer server locally
and each server instance manages timers for local workers (running on
processes on the same host).
Args:
file_path: str, the path of a FIFO special file to be created.
max_interval: float, max interval in seconds for each watchdog loop.
        daemon: bool, whether to run the watchdog thread as a daemon.
            A daemon thread does not prevent the process from stopping.
log_event: Callable[[Dict[str, str]], None], an optional callback for
logging the events in JSON format.
"""
def __init__(
self,
file_path: str,
run_id: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
) -> None:
self._file_path = file_path
self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None
if os.path.exists(self._file_path):
os.remove(self._file_path)
os.mkfifo(self._file_path)
# For test only. Count the number of requests received.
self._request_count = 0
# For test only. Process all requests and stop the server.
self._run_once = False
self._log_event = (
log_event if log_event is not None else lambda name, request: None
)
self._last_progress_time = int(time.time())
def start(self) -> None:
logger.info(
"Starting %s... max_interval=%s, daemon=%s, file_path=%s",
type(self).__name__,
self._max_interval,
self._daemon,
self._file_path,
)
self._watchdog_thread = threading.Thread(
target=self._watchdog_loop, daemon=self._daemon
)
logger.info("Starting watchdog thread...")
self._watchdog_thread.start()
self._log_event("watchdog started", None)
def stop(self) -> None:
logger.info("Stopping %s", type(self).__name__)
self._stop_signaled = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join(self._max_interval)
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
self._log_event("watchdog stopped", None)
def run_once(self) -> None:
self._run_once = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join()
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
@staticmethod
def is_process_running(pid: int):
"""
function to check process is running or not
"""
try:
# Check if the process exists and we can send signals to it
os.kill(pid, 0)
return True
except OSError:
return False
def _watchdog_loop(self) -> None:
        # Opening the pipe in blocking mode would block the server thread.
        # This is fine for the following reasons:
        #  1. The no-client case usually does not happen.
        #  2. The watchdog loop runs in a separate daemon thread, which
        #     does not prevent the process from stopping.
with open(self._file_path) as fd:
while not self._stop_signaled:
try:
run_once = self._run_once
self._run_watchdog(fd)
if run_once:
break
self._last_progress_time = int(time.time())
except Exception:
logger.exception("Error running watchdog")
def _run_watchdog(self, fd: io.TextIOWrapper) -> None:
timer_requests = self._get_requests(fd, self._max_interval)
self.register_timers(timer_requests)
now = time.time()
reaped_worker_pids = set()
all_expired_timers = self.get_expired_timers(now)
log_debug_info_for_expired_timers(
self._run_id,
{
pid: self._get_scopes(expired_timers)
for pid, expired_timers in all_expired_timers.items()
},
)
for worker_pid, expired_timers in all_expired_timers.items():
logger.info(
"Reaping worker_pid=[%s]. Expired timers: %s",
worker_pid,
self._get_scopes(expired_timers),
)
reaped_worker_pids.add(worker_pid)
# In case we have multiple expired timers, we find the first timer
# with a valid signal (>0) in the expiration time order.
expired_timers.sort(key=lambda timer: timer.expiration_time)
signal = 0
expired_timer = None
for timer in expired_timers:
self._log_event("timer expired", timer)
if timer.signal > 0:
signal = timer.signal
expired_timer = timer
break
if signal <= 0:
logger.info(
"No signal specified with worker=[%s]. Do not reap it.", worker_pid
)
continue
if self._reap_worker(worker_pid, signal):
logger.info(
"Successfully reaped worker=[%s] with signal=%s", worker_pid, signal
)
self._log_event("kill worker process", expired_timer)
else:
logger.error(
"Error reaping worker=[%s]. Will retry on next watchdog.",
worker_pid,
)
self.clear_timers(reaped_worker_pids)
def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
return [r.scope_id for r in timer_requests]
def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]:
start = time.time()
requests = []
while not self._stop_signaled or self._run_once:
            # For a named pipe, readline() blocks while at least one writer has
            # the pipe open and returns only when flush() is called on the writer
            # side (flush() is called automatically inside close()).
            # After the last writer closes, readline() no longer blocks; it
            # returns an empty string at end-of-file.
            # Since the client side always opens the pipe, writes a message, and
            # closes it immediately, the readline() call below does not block for long.
json_request = fd.readline()
if len(json_request) == 0:
if self._run_once:
break
time.sleep(min(max_interval, 1))
else:
request = json.loads(json_request)
pid = request["pid"]
scope_id = request["scope_id"]
expiration_time = request["expiration_time"]
signal = request["signal"]
requests.append(
FileTimerRequest(
worker_pid=pid,
scope_id=scope_id,
expiration_time=expiration_time,
signal=signal,
)
)
now = time.time()
if now - start > max_interval:
break
return requests
def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_pid
scope_id = request.scope_id
expiration_time = request.expiration_time
self._request_count += 1
key = (pid, scope_id)
# negative expiration is a proxy for a release call
if expiration_time < 0:
if key in self._timers:
del self._timers[key]
else:
self._timers[key] = request
def clear_timers(self, worker_pids: Set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[int, List[FileTimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_pid, [])
expired_scopes.append(request)
return expired_timers
def _reap_worker(self, worker_pid: int, signal: int) -> bool:
try:
os.kill(worker_pid, signal)
return True
except ProcessLookupError:
logger.info("Process with pid=%s does not exist. Skipping", worker_pid)
return True
except Exception:
logger.exception("Error terminating pid=%s", worker_pid)
return False
def get_last_progress_time(self) -> int:
return self._last_progress_time
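
# --- Illustrative wiring (added sketch, not part of the upstream file) ---
# The FIFO path is a hypothetical placeholder; FileTimerServer creates the
# FIFO in __init__, and clients on the same host open it for writing.
def _example_file_timer_setup() -> FileTimerServer:
    server = FileTimerServer(
        file_path="/tmp/torchelastic_watchdog_pipe",
        run_id="example_run",
        max_interval=5,
    )
    server.start()  # non-blocking; the watchdog runs on a daemon thread
    # A worker process would then construct:
    #   client = FileTimerClient("/tmp/torchelastic_watchdog_pipe")
    return server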

View File

@ -0,0 +1,128 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"]
logger = logging.getLogger(__name__)
class LocalTimerClient(TimerClient):
"""
Client side of ``LocalTimerServer``. This client is meant to be used
on the same host that the ``LocalTimerServer`` is running on and uses
pid to uniquely identify a worker. This is particularly useful in situations
where one spawns a subprocess (trainer) per GPU on a host with multiple
GPU devices.
"""
def __init__(self, mp_queue):
super().__init__()
self._mp_queue = mp_queue
def acquire(self, scope_id, expiration_time):
pid = os.getpid()
acquire_request = TimerRequest(pid, scope_id, expiration_time)
self._mp_queue.put(acquire_request)
def release(self, scope_id):
pid = os.getpid()
release_request = TimerRequest(pid, scope_id, -1)
self._mp_queue.put(release_request)
class MultiprocessingRequestQueue(RequestQueue):
"""
A ``RequestQueue`` backed by python ``multiprocessing.Queue``
"""
def __init__(self, mp_queue: mp.Queue):
super().__init__()
self._mp_queue = mp_queue
def size(self) -> int:
return self._mp_queue.qsize()
def get(self, size, timeout: float) -> List[TimerRequest]:
requests = []
wait = timeout
for _ in range(0, size):
start = time.time()
try:
r = self._mp_queue.get(block=True, timeout=wait)
except Empty:
break
requests.append(r)
wait = wait - (time.time() - start)
if wait <= 0:
break
return requests
class LocalTimerServer(TimerServer):
"""
Server that works with ``LocalTimerClient``. Clients are expected to be
subprocesses to the parent process that is running this server. Each host
in the job is expected to start its own timer server locally and each
server instance manages timers for local workers (running on processes
on the same host).
"""
def __init__(
self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
):
super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
self._timers: Dict[Tuple[Any, str], TimerRequest] = {}
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_id
scope_id = request.scope_id
expiration_time = request.expiration_time
# negative expiration is a proxy for a release call
if expiration_time < 0:
self._timers.pop((pid, scope_id), None)
else:
self._timers[(pid, scope_id)] = request
def clear_timers(self, worker_ids: Set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_ids:
self._timers.pop((pid, scope_id))
def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[Any, List[TimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_id, [])
expired_scopes.append(request)
return expired_timers
def _reap_worker(self, worker_id: int) -> bool:
try:
os.kill(worker_id, signal.SIGKILL)
return True
except ProcessLookupError:
logger.info("Process with pid=%s does not exist. Skipping", worker_id)
return True
except Exception:
logger.exception("Error terminating pid=%s", worker_id)
return False
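
# --- Illustrative wiring (added sketch, not part of the upstream file) ---
# The multiprocessing queue is shared between the server (agent process)
# and the clients (worker subprocesses); the values here are placeholders.
def _example_local_timer_setup() -> Tuple[LocalTimerServer, mp.Queue]:
    mp_queue = mp.get_context("spawn").Queue()
    server = LocalTimerServer(mp_queue, max_interval=0.25)
    server.start()  # non-blocking; the watchdog runs on a daemon thread
    # Each worker subprocess would construct: LocalTimerClient(mp_queue)
    return server, mp_queue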

View File

@ -0,0 +1,9 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import socket
from string import Template
from typing import Any, List
def get_env_variable_or_raise(env_name: str) -> str:
r"""
    Tries to retrieve the environment variable. Raises ``ValueError``
    if the environment variable is not set.
Args:
env_name (str): Name of the env variable
"""
value = os.environ.get(env_name, None)
if value is None:
msg = f"Environment variable {env_name} expected, but not set"
raise ValueError(msg)
return value
def get_socket_with_port() -> socket.socket:
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
raise RuntimeError("Failed to create a socket")
class macros:
"""
Defines simple macros for caffe2.distributed.launch cmd args substitution
"""
local_rank = "${local_rank}"
@staticmethod
def substitute(args: List[Any], local_rank: str) -> List[str]:
args_sub = []
for arg in args:
if isinstance(arg, str):
sub = Template(arg).safe_substitute(local_rank=local_rank)
args_sub.append(sub)
else:
args_sub.append(arg)
return args_sub
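
# Illustrative example (added sketch, not part of the upstream file): only
# string arguments are substituted; other types pass through untouched.
def _example_macro_substitution() -> None:
    args = ["--local-rank", macros.local_rank, 42]
    assert macros.substitute(args, "3") == ["--local-rank", "3", 42]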

View File

@ -0,0 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .cycling_iterator import CyclingIterator # noqa: F401
from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401

View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
class CyclingIterator:
"""
An iterator decorator that cycles through the
underlying iterator "n" times. Useful to "unroll"
the dataset across multiple training epochs.
    The generator function is called as ``generator_fn(epoch)``
    to obtain the underlying iterator, where ``epoch`` is a
    number in ``[0, n)`` indicating the current cycle.
For example if ``generator_fn`` always returns ``[1,2,3]``
then ``CyclingIterator(n=2, generator_fn)`` will iterate through
``[1,2,3,1,2,3]``
"""
def __init__(self, n: int, generator_fn, start_epoch=0):
self._n = n
self._epoch = start_epoch
self._generator_fn = generator_fn
self._iter = generator_fn(self._epoch)
def __iter__(self):
return self
def __next__(self):
try:
return next(self._iter)
except StopIteration as eod: # eod == end of data
if self._epoch < self._n - 1:
self._epoch += 1
self._iter = self._generator_fn(self._epoch)
return self.__next__()
else:
raise eod
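
# Illustrative example (added sketch, not part of the upstream file):
# cycling a 3-element iterator twice yields 6 elements.
def _example_cycling_iterator() -> None:
    it = CyclingIterator(n=2, generator_fn=lambda epoch: iter([1, 2, 3]))
    assert list(it) == [1, 2, 3, 1, 2, 3]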

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.utils.data.distributed import DistributedSampler
class ElasticDistributedSampler(DistributedSampler):
"""
Sampler that restricts data loading to a subset of
the dataset for elastic training.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
start_index (optional): Which index of the dataset to start sampling from
"""
def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
if start_index >= len(dataset):
raise ValueError(
f"Start index {start_index} should be less than dataset size {len(dataset)}"
)
self.start_index = start_index
self.num_samples = int(
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = (
torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
.add(self.start_index)
.tolist()
)
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
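
# --- Illustrative usage (added sketch, not part of the upstream file) ---
# Assumes the default process group is already initialized so num_replicas
# and rank can be inferred; `dataset` and `resume_index` are hypothetical.
def _example_sampler_usage(dataset, resume_index: int):
    from torch.utils.data import DataLoader

    sampler = ElasticDistributedSampler(dataset, start_index=resume_index)
    return DataLoader(dataset, batch_size=32, sampler=sampler)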

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import os
import socket
from contextlib import closing
from typing import Optional
import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.elastic.utils.store import barrier
__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"]
logger = get_logger(__name__)
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"
_TCP_STORE_INIT = "_tcp_store/num_members"
def create_c10d_store(
is_server: bool,
server_addr: str,
server_port: int = -1,
world_size: int = 1,
timeout: float = (60 * 10), # 10 min
wait_for_workers: bool = True,
retries=3,
use_libuv: Optional[bool] = None,
):
if use_libuv is not None:
logger.warning(
"argument use_libuv is deprecated and ignored. Set USE_LIBUV environment "
'variable to "0" to disable libuv, or "1" to enable it. If the env var '
"is not set, libuv will be used by default."
)
# check os.environ for use_libuv
use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option
if server_port == -1 and world_size > 1:
raise ValueError(
f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
)
if server_port != -1:
logger.info("sever_port: %s, specified, ignoring retries", server_port)
# only retry when server_port is NOT static
attempt = retries if server_port == -1 else 1
while True:
if server_port != -1:
port = server_port
else:
port = get_free_port()
logger.info(
"Creating c10d store on %s:%s\n"
" world_size : %s\n"
" is_server : %s\n"
" timeout(sec): %s\n"
" use_libuv : %s\n",
server_addr,
port,
world_size,
is_server,
timeout,
use_libuv,
)
try:
store = dist.TCPStore(
host_name=server_addr,
port=port,
world_size=world_size,
is_master=is_server,
timeout=datetime.timedelta(seconds=timeout),
wait_for_workers=wait_for_workers,
use_libuv=use_libuv,
)
# skips full rank check when we don't have to wait for all workers
if wait_for_workers:
_check_full_rank(store, world_size, timeout=timeout)
logger.info("Successfully created c10d store")
return store
except RuntimeError as e:
# this is brittle, but the underlying exception type is not properly pybinded
# so we parse the error msg for now, interestingly this is how torch itself
# detects timeouts and port conflicts in their own unittests
# see - caffe2/torch/testing/_internal/common_utils.py
# TODO properly map the exceptions in pybind (c10d/init.cpp)
if str(e) == _ADDRESS_IN_USE: # this will only happen on the server
if attempt < retries:
logger.warning(
"port: %s already in use, attempt: [%s/%s]",
port,
attempt,
retries,
)
attempt += 1
else:
raise RuntimeError(
f"on {server_addr}, port: {port} already in use"
) from e
else:
raise
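
# --- Illustrative usage (added sketch, not part of the upstream file) ---
# The server (typically the agent on node 0) and the clients point at the
# same endpoint; the address and port here are hypothetical placeholders.
def _example_create_store(is_server: bool):
    return create_c10d_store(
        is_server=is_server,
        server_addr="10.0.0.1",
        server_port=29500,
        world_size=2,
        wait_for_workers=True,
    )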
def _check_full_rank(store, world_size, timeout):
try:
barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
except RuntimeError as e:
if str(e) == _SOCKET_TIMEOUT:
raise TimeoutError(
f"timed out waiting for all {world_size} members to join"
) from e
else:
raise
def get_free_port():
"""
Returns an unused port on localhost.
    This function finds an unused port on localhost by opening a socket,
    binding it to an ephemeral port, and then closing it.
Returns:
int: an unused port on localhost
Example:
>>> # xdoctest: +SKIP("Nondeterministic")
>>> get_free_port()
63976
    .. note::
The port returned by :func:`get_free_port` is not reserved and may be
taken by another process after this function returns.
"""
sock = get_socket_with_port()
with closing(sock):
return sock.getsockname()[1]
def get_socket_with_port() -> socket.socket:
"""
Returns a free port on localhost that is "reserved" by binding a temporary
socket on it. Close the socket before passing the port to the entity
that requires it. Usage example
::
        sock = get_socket_with_port()
with closing(sock):
port = sock.getsockname()[1]
sock.close()
# there is still a race-condition that some other process
# may grab this port before func() runs
func(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
logger.warning("Socket creation attempt failed.", exc_info=e)
raise RuntimeError("Failed to create a socket")

View File

@ -0,0 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
def get_log_level() -> str:
"""
Return default log level for pytorch.
"""
return "WARNING"

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import warnings
from typing import Optional
from torch.distributed.elastic.utils.log_level import get_log_level
def get_logger(name: Optional[str] = None):
"""
Util function to set up a simple logger that writes
into stderr. The loglevel is fetched from the LOGLEVEL
env. variable or WARNING as default. The function will use the
module name of the caller if no name is provided.
Args:
name: Name of the logger. If no name provided, the name will
be derived from the call stack.
"""
# Derive the name of the caller, if none provided
# Use depth=2 since this function takes up one level in the call stack
return _setup_logger(name or _derive_module_name(depth=2))
def _setup_logger(name: Optional[str] = None):
logger = logging.getLogger(name)
logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
return logger
def _derive_module_name(depth: int = 1) -> Optional[str]:
"""
Derives the name of the caller module from the stack frames.
Args:
depth: The position of the frame in the stack.
"""
try:
stack = inspect.stack()
assert depth < len(stack)
# FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
frame_info = stack[depth]
module = inspect.getmodule(frame_info[0])
if module:
module_name = module.__name__
else:
# inspect.getmodule(frame_info[0]) does NOT work (returns None) in
# binaries built with @mode/opt
# return the filename (minus the .py extension) as modulename
filename = frame_info[1]
module_name = os.path.splitext(os.path.basename(filename))[0]
return module_name
except Exception as e:
warnings.warn(
f"Error deriving logger module name, using <None>. Exception: {e}",
RuntimeWarning,
)
return None

View File

@ -0,0 +1,225 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional
import torch
DistStoreError = torch._C._DistStoreError
_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"
_TRACE = "/TRACE"
_TRACING_GATE = "/TRACING_GATE"
_MAX_TRACE_MISSING_RANKS = 16
__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]
@contextmanager
def store_timeout(store, timeout: float):
"""
This sets the timeout and then restores the old timeout when the context
manager exits.
Args:
store: the store to set the timeout on
timeout: the timeout to set
"""
old_timeout = store.timeout
store.set_timeout(timedelta(seconds=timeout))
yield
store.set_timeout(old_timeout)
def get_all(store, rank: int, prefix: str, world_size: int):
r"""
Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to ``world_size - 1``, and tries to retrieve the data.
The Rank0 process waits at the end to make sure all other processes
finished the procedure before exiting.
Usage
::
        values = get_all(store, rank, 'torchelastic/data', 3)
value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2
"""
data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
barrier_key = _barrier_nonblocking(
store=store,
world_size=world_size,
key_prefix=f"{prefix}/finished",
)
if rank == 0:
# Rank0 runs the TCPStore daemon, as a result it needs to exit last.
# Otherwise, the barrier may timeout if rank0 process finished the work
# before other processes finished `get_all` method
store.wait([barrier_key])
return data_arr
def synchronize(
store,
data: bytes,
rank: int,
world_size: int,
key_prefix: str,
timeout: float = 300,
) -> List[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents.
    Note: The data on the path is not deleted; as a result, there can be
    stale data if the same key_prefix is used twice.
Time complexity: O(N) per worker, O(N^2) globally.
"""
with store_timeout(store, timeout):
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, rank, key_prefix, world_size)
return agent_data
def _try_detecting_missing_ranks(
store,
world_size: int,
key_prefix: str,
rank: int,
rank_decoder: Callable[[int], str],
trace_timeout: float,
) -> Optional[Iterable[str]]:
store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")
def _find_missing_ranks():
missing_rank_info = set()
ranks_missing = 0
for i in range(1, world_size):
# reduce noise, assuming in general 8 ranks per node
# It is valuable to know that 1 or >1 nodes have timed-out.
if ranks_missing >= _MAX_TRACE_MISSING_RANKS:
break
try:
if ranks_missing == 0:
store.wait(
[f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout)
)
else:
                # use the shortest timeout; some ranks have already failed to check in
store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1))
except DistStoreError:
ranks_missing += 1
missing_rank_info.add(rank_decoder(i))
return missing_rank_info
def _checkin():
try:
store.wait([f"{key_prefix}{_TRACING_GATE}"])
return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"]
except DistStoreError:
# in case rank0 is the source of the timeout, original exception will be raised
return None
if rank == 0:
missing_rank_info = _find_missing_ranks()
store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>")
return missing_rank_info
else:
return _checkin()
def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
"""
Does all the non-blocking operations for a barrier and returns the final key
that can be waited on.
"""
num_members_key = key_prefix + _NUM_MEMBERS
last_member_key = key_prefix + _LAST_MEMBER_CHECKIN
idx = store.add(num_members_key, 1)
if idx == world_size:
store.set(last_member_key, "<val_ignored>")
return last_member_key
def barrier(
store,
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
rank: Optional[int] = None,
rank_tracing_decoder: Optional[Callable[[int], str]] = None,
trace_timeout: float = 10,
) -> None:
"""
A global lock between agents. This will pause all workers until at least
``world_size`` workers respond.
This uses a fast incrementing index to assign waiting ranks and a success
flag set by the last worker.
Time complexity: O(1) per worker, O(N) globally.
Optionally, passing rank will enable tracing of missing ranks on timeouts.
    `rank_tracing_decoder` lambda arg can be used to convert rank data
    into more meaningful information at the application level (e.g. hostname).
Note: Since the data is not removed from the store, the barrier can be used
once per unique ``key_prefix``.
"""
if rank is None:
assert rank_tracing_decoder is None, "Tracing requires rank information"
with store_timeout(store, barrier_timeout):
last_member_key = _barrier_nonblocking(
store=store, world_size=world_size, key_prefix=key_prefix
)
try:
store.wait([last_member_key])
except DistStoreError as e:
if rank is None:
raise e
else:
missing_ranks = _try_detecting_missing_ranks(
store,
world_size,
key_prefix,
rank,
rank_tracing_decoder or (lambda x: str(x)),
trace_timeout,
)
if missing_ranks is not None:
raise DistStoreError(
"Timed out waiting on barrier on "
"rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format(
rank,
key_prefix,
world_size,
f"[{', '.join(missing_ranks)}]",
barrier_timeout,
)
) from None
else:
raise e
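
# --- Illustrative usage (added sketch, not part of the upstream file) ---
# Two agents exchange payloads and then rendezvous on a barrier. `store` is
# assumed to be a c10d store (e.g. a TCPStore) shared by both agents.
def _example_synchronize_and_barrier(store, rank: int) -> None:
    payloads = synchronize(
        store,
        data=f"agent-{rank}".encode(),
        rank=rank,
        world_size=2,
        key_prefix="torchelastic/example/",
    )
    assert len(payloads) == 2
    barrier(store, world_size=2, key_prefix="torchelastic/example_barrier/", rank=rank)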