I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,233 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Library that launches and manages ``n`` copies of worker subprocesses, specified either by a function or a binary.
For functions, it uses ``torch.multiprocessing`` (and therefore python
``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
``subprocess.Popen`` to create worker processes.
Usage 1: Launching two trainers as a function
::
from torch.distributed.elastic.multiprocessing import Std, start_processes
def trainer(a, b, c):
pass # train
# runs two trainers
# LOCAL_RANK=0 trainer(1,2,3)
# LOCAL_RANK=1 trainer(4,5,6)
ctx = start_processes(
name="trainer",
entrypoint=trainer,
args={0: (1,2,3), 1: (4,5,6)},
envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
)
# waits for all copies of trainer to finish
ctx.wait()
Usage 2: Launching 2 echo workers as a binary
::
# same as invoking
# echo hello
# echo world > stdout.log
ctx = start_processes(
name="echo",
entrypoint="echo",
log_dir="/tmp/foobar",
args={0: "hello", 1: "world"},
redirects={1: Std.OUT},
)
Just like ``torch.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""
from typing import Callable, Dict, Optional, Tuple, Union
from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,
DefaultLogsSpecs,
LogsDest,
LogsSpecs,
MultiprocessContext,
PContext,
ProcessFailure,
RunProcsResult,
SignalException,
Std,
SubprocessContext,
to_map,
)
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
"start_processes",
"MultiprocessContext",
"PContext",
"ProcessFailure",
"RunProcsResult",
"SignalException",
"Std",
"LogsDest",
"LogsSpecs",
"DefaultLogsSpecs",
"SubprocessContext",
"to_map",
]
def start_processes(
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
start_method: str = "spawn",
) -> PContext:
"""
Start ``n`` copies of ``entrypoint`` processes with the provided options.
``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
The number of copies is determined by the number of entries for ``args`` and
``envs`` arguments, which need to have the same key set.
``args`` and ``envs`` parameters are the arguments and environment variables
to pass down to the entrypoint mapped by the replica index (local rank).
All local ranks must be accounted for.
That is, the keyset should be ``{0,1,...,(nprocs-1)}``.
.. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
If any other type is given, then it is cast to a string representation
(e.g. ``str(arg1)``). Furthermore, a binary failure will only write
an ``error.json`` error file if the main function is annotated with
``torch.distributed.elastic.multiprocessing.errors.record``. For function launches,
this is done by default and there is no need to manually annotate
with the ``@record`` annotation.
``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
the local rank to specify the redirect behavior for.
Any missing local ranks will default to ``Std.NONE``.
``tee`` acts like the unix "tee" command in that it redirects + prints to console.
To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
For each process, the ``log_dir`` will contain:
#. ``{local_rank}/error.json``: if the process failed, a file with the error info
#. ``{local_rank}/stdout.log``: if ``redirect & Std.OUT == Std.OUT``
#. ``{local_rank}/stderr.log``: if ``redirect & Std.ERR == Std.ERR``
.. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
Example:
::
log_dir = "/tmp/test"
# ok; two copies of foo: foo("bar0"), foo("bar1")
start_processes(
name="trainer",
entrypoint=foo,
args={0: ("bar0",), 1: ("bar1",)},
envs={0: {}, 1: {}},
log_dir=log_dir
)
# invalid; envs missing for local rank 1
start_processes(
name="trainer",
entrypoint=foo,
args={0: ("bar0",), 1: ("bar1",)},
envs={0: {}},
log_dir=log_dir
)
# ok; two copies of /usr/bin/touch: touch file1, touch file2
start_processes(
name="trainer",
entrypoint="/usr/bin/touch",
args={0: ("file1",), 1: ("file2",)},
envs={0: {}, 1: {}},
log_dir=log_dir
)
# caution; arguments casted to string, runs:
# echo "1" "2" "3" and echo "[1, 2, 3]"
start_processes(
name="trainer",
entrypoint="/usr/bin/echo",
args={0: (1, 2, 3), 1: ([1, 2, 3],)},
envs={0: {}, 1: {}},
log_dir=log_dir
)
Args:
    name: a human readable short name that describes what the processes are
          (used as header when tee'ing stdout/stderr outputs)
    entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
    args: arguments to each replica
    envs: env vars to each replica
    logs_specs: defines the log directory, which std streams to redirect to a
                log file, which streams to tee to the console, and which ranks'
                logs to print (see ``LogsSpecs`` / ``DefaultLogsSpecs``)
    log_line_prefixes: optional custom prefix (per local rank) for tee'd log lines
    start_method: multiprocessing start method (spawn, fork, forkserver)
                  ignored for binaries
"""
nprocs = len(args)
_validate_full_rank(args, nprocs, "args")
_validate_full_rank(envs, nprocs, "envs")
context: PContext
if isinstance(entrypoint, str):
context = SubprocessContext(
name=name,
entrypoint=entrypoint,
args=args,
envs=envs,
logs_specs=logs_specs,
log_line_prefixes=log_line_prefixes,
)
else:
context = MultiprocessContext(
name=name,
entrypoint=entrypoint,
args=args,
envs=envs,
log_line_prefixes=log_line_prefixes,
start_method=start_method,
logs_specs=logs_specs,
)
try:
context.start()
return context
except Exception:
context.close()
raise
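The docstring examples above still use the older ``log_dir``/``redirects``/``tee`` keyword arguments; in the signature defined here those options live on the ``logs_specs`` object. A minimal, hedged sketch of the current calling convention (the ``trainer`` function and the ``/tmp/foobar`` path are illustrative only):

from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std, start_processes


def trainer(a, b, c):
    pass  # train


if __name__ == "__main__":
    ctx = start_processes(
        name="trainer",
        entrypoint=trainer,
        args={0: (1, 2, 3), 1: (4, 5, 6)},
        envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
        logs_specs=DefaultLogsSpecs(
            log_dir="/tmp/foobar",
            redirects=Std.ALL,   # write every worker's stdout/stderr to a log file
            tee={0: Std.ERR},    # additionally tail local rank 0's stderr to the console
        ),
    )
    result = ctx.wait()  # MultiprocessContext here; a str entrypoint would yield a SubprocessContext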

@@ -0,0 +1,923 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
redirect_stderr,
redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
get_subprocess_handler,
SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
__all__ = [
"DefaultLogsSpecs",
"SignalException",
"Std",
"to_map",
"RunProcsResult",
"PContext",
"get_std_cm",
"MultiprocessContext",
"SubprocessContext",
"LogsDest",
"LogsSpecs",
]
class SignalException(Exception):
"""
Exception raised inside the torchelastic agent process by the termination handler
when the process receives a death signal.
"""
def __init__(self, msg: str, sigval: signal.Signals) -> None:
super().__init__(msg)
self.sigval = sigval
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
"""Termination handler that raises exceptions on the main process.
When the process receives a death signal (SIGTERM, SIGINT), this termination handler will
be invoked. It raises the ``SignalException`` exception, which should be processed by the
user code. Python does not terminate the process after the termination handler is finished,
so the exception should not be silently ignored; otherwise the process will never
be terminated.
"""
sigval = signal.Signals(signum)
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
def _get_kill_signal() -> signal.Signals:
"""Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGKILL
def _get_default_signal() -> signal.Signals:
"""Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGTERM
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
if actual_keys != expected_keys:
raise RuntimeError(
f"{what}, local rank mapping mismatch,"
f" expected: {expected_keys}, actual: {actual_keys}"
)
_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"
class Std(IntFlag):
NONE = 0
OUT = 1
ERR = 2
ALL = OUT | ERR
@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
"""
Example:
::
from_str("0") -> Std.NONE
from_str("1") -> Std.OUT
from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
Any other input raises an exception
"""
def to_std(v: str) -> Std: # type: ignore[return]
s = Std(int(v))
if s in Std:
return s
# return None -> should NEVER reach here since we regex check input
if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
return to_std(vm)
elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
d: Dict[int, Std] = {}
for m in vm.split(","):
i, v = m.split(":")
d[int(i)] = to_std(v)
return d
else:
raise ValueError(
f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
)
def to_map(
val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
"""
Certain APIs take redirect settings either as a single value (e.g. apply to all
local ranks) or as an explicit user-provided mapping. This method is a convenience
method that converts a value or mapping into a mapping.
Example:
::
to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
"""
if isinstance(val_or_map, Std):
return dict.fromkeys(range(local_world_size), val_or_map)
else:
map = {}
for i in range(local_world_size):
map[i] = val_or_map.get(i, Std.NONE)
return map
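As a standalone sanity check (not part of this module), the docstring examples for ``Std.from_str`` and ``to_map`` above can be restated in runnable form:

from torch.distributed.elastic.multiprocessing.api import Std, to_map

assert Std.from_str("1") == Std.OUT
assert Std.from_str("0:3,1:0") == {0: Std.ALL, 1: Std.NONE}
assert to_map(Std.OUT, local_world_size=2) == {0: Std.OUT, 1: Std.OUT}
assert to_map({1: Std.OUT}, local_world_size=2) == {0: Std.NONE, 1: Std.OUT}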
@dataclass
class LogsDest:
"""
For each log type, holds mapping of local rank ids to file paths.
"""
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
tee_stdouts: Dict[int, str] = field(default_factory=dict)
tee_stderrs: Dict[int, str] = field(default_factory=dict)
error_files: Dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC):
"""
Defines logs processing and redirection for each worker process.
Args:
log_dir:
Base directory where logs will be written.
redirects:
Streams to redirect to files. Pass a single ``Std``
enum to redirect for all workers, or a mapping keyed
by local_rank to selectively redirect.
tee:
Streams to duplicate to stdout/stderr.
Pass a single ``Std`` enum to duplicate streams for all workers,
or a mapping keyed by local_rank to selectively duplicate.
"""
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
) -> None:
self._root_log_dir = log_dir
self._redirects = redirects
self._tee = tee
self._local_ranks_filter = local_ranks_filter
@abstractmethod
def reify(
self,
envs: Dict[int, Dict[str, str]],
) -> LogsDest:
"""
Given the environment variables, builds destination of log files for each of the local ranks.
Envs parameter contains env variables dict for each of the local ranks, where entries are defined in:
:func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
"""
@property
@abstractmethod
def root_log_dir(self) -> str:
pass
class DefaultLogsSpecs(LogsSpecs):
"""
Default LogsSpecs implementation:
- `log_dir` will be created if it doesn't exist
- Generates nested folders for each attempt and rank.
"""
def __init__(
self,
log_dir: Optional[str] = None,
redirects: Union[Std, Dict[int, Std]] = Std.NONE,
tee: Union[Std, Dict[int, Std]] = Std.NONE,
local_ranks_filter: Optional[Set[int]] = None,
) -> None:
if log_dir != os.devnull:
if not log_dir:
log_dir = tempfile.mkdtemp(prefix="torchelastic_")
elif not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
else:
if os.path.isfile(log_dir):
raise NotADirectoryError(f"log_dir: {log_dir} is a file")
super().__init__(log_dir, redirects, tee, local_ranks_filter)
# initialized only once
self._run_log_dir = None
@property
def root_log_dir(self) -> str:
return str(self._root_log_dir)
def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
os.makedirs(base_log_dir, exist_ok=True)
dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
logger.info("log directory set to: %s", dir)
return dir
def reify(
self,
envs: Dict[int, Dict[str, str]],
) -> LogsDest:
"""
Uses the following scheme to build log destination paths:
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
- `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
"""
nprocs = len(envs)
global_env = {} # used only to query properties that are not rank-dependent
if nprocs > 0:
global_env = envs[0]
else:
logger.warning(
"Empty envs map provided when defining logging destinations."
)
# Keys are always defined, but values can be missing in unit tests
run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")
attempt_log_dir: str = ""
if self._root_log_dir != os.devnull:
if not self._run_log_dir:
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload]
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)
if self._root_log_dir == os.devnull:
attempt_log_dir = os.devnull
# create subdirs for each local rank in the logs_dir
# logs_dir
# |- 0
# |- error.json
# |- stdout.log
# |- stderr.log
# |- ...
# |- (nprocs-1)
redirs = to_map(self._redirects, nprocs)
ts = to_map(self._tee, nprocs)
# to tee stdout/stderr we first redirect into a file
# then tail -f stdout.log/stderr.log so add tee settings to redirects
for local_rank, tee_std in ts.items():
redirect_std = redirs[local_rank]
redirs[local_rank] = redirect_std | tee_std
SYS_STREAM = "" # special case to indicate to output to console
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
tee_stdouts: Dict[int, str] = {}
tee_stderrs: Dict[int, str] = {}
error_files = {}
for local_rank in range(nprocs):
if attempt_log_dir == os.devnull:
tee_stdouts[local_rank] = os.devnull
tee_stderrs[local_rank] = os.devnull
error_files[local_rank] = os.devnull
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
else:
clogdir = os.path.join(attempt_log_dir, str(local_rank))
os.mkdir(clogdir)
rd = redirs[local_rank]
if (rd & Std.OUT) == Std.OUT:
stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
if (rd & Std.ERR) == Std.ERR:
stderrs[local_rank] = os.path.join(clogdir, "stderr.log")
t = ts[local_rank]
if t & Std.OUT == Std.OUT:
tee_stdouts[local_rank] = stdouts[local_rank]
if t & Std.ERR == Std.ERR:
tee_stderrs[local_rank] = stderrs[local_rank]
if (
self._local_ranks_filter
and local_rank not in self._local_ranks_filter
):
# If stream is tee'd, only write to file, but don't tail
if local_rank in tee_stdouts:
tee_stdouts.pop(local_rank, None)
if local_rank in tee_stderrs:
tee_stderrs.pop(local_rank, None)
# If stream is not redirected, don't print
if stdouts[local_rank] == SYS_STREAM:
stdouts[local_rank] = os.devnull
if stderrs[local_rank] == SYS_STREAM:
stderrs[local_rank] = os.devnull
error_file = os.path.join(clogdir, "error.json")
error_files[local_rank] = error_file
logger.info(
"Setting worker%s reply file to: %s", local_rank, error_file
)
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file
return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
def __repr__(self) -> str:
return (
f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, DefaultLogsSpecs):
return False
return (
self._root_log_dir == other._root_log_dir
and self._redirects == other._redirects
and self._tee == other._tee
and self._local_ranks_filter == other._local_ranks_filter
)
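A hedged, standalone sketch of how ``DefaultLogsSpecs`` resolves destinations; ``reify()`` is normally called by ``PContext``, and the run id and restart count below stand in for the ``TORCHELASTIC_*`` variables the agent sets on each worker. The exact paths depend on the temporary directory chosen at runtime.

from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std

specs = DefaultLogsSpecs(redirects=Std.ALL, tee={0: Std.ERR})  # log_dir=None -> a fresh tempdir
envs = {
    0: {"TORCHELASTIC_RUN_ID": "job42", "TORCHELASTIC_RESTART_COUNT": "0"},
    1: {"TORCHELASTIC_RUN_ID": "job42", "TORCHELASTIC_RESTART_COUNT": "0"},
}
dest = specs.reify(envs)
# dest.stdouts[1]   -> <tmp>/torchelastic_xxx/job42_xxx/attempt_0/1/stdout.log
# dest.tee_stderrs  -> {0: .../attempt_0/0/stderr.log}  (only rank 0 is tailed to the console)
# envs[0]["TORCHELASTIC_ERROR_FILE"] has been set to .../attempt_0/0/error.json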
@dataclass
class RunProcsResult:
"""
Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.
Note the following:
1. All fields are mapped by local rank
2. ``return_values`` - only populated for functions (not the binaries).
3. ``stdouts`` - path to stdout.log (empty string if no redirect)
4. ``stderrs`` - path to stderr.log (empty string if no redirect)
"""
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
stdouts: Dict[int, str] = field(default_factory=dict)
stderrs: Dict[int, str] = field(default_factory=dict)
def is_failed(self) -> bool:
return len(self.failures) > 0
class PContext(abc.ABC):
"""
The base class that standardizes operations over a set of processes that are launched via different mechanisms.
The name ``PContext`` is intentional, to disambiguate it from ``torch.multiprocessing.ProcessContext``.
.. warning:: stdouts and stderrs should ALWAYS be a superset of
             tee_stdouts and tee_stderrs (respectively), because
             tee is implemented as a redirect + tail -f <stdout/stderr.log>
"""
def __init__(
self,
name: str,
entrypoint: Union[Callable, str],
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
self.name = name
# validate that all mappings have the same number of keys and
# all local ranks are accounted for
nprocs = len(args)
# TODO log_line_prefixes can be expanded too
logs_dest = logs_specs.reify(envs)
_validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
_validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")
self.entrypoint = entrypoint
self.args = args
self.envs = envs
self.stdouts = logs_dest.stdouts
self.stderrs = logs_dest.stderrs
self.error_files = logs_dest.error_files
self.nprocs = nprocs
self._stdout_tail = TailLog(
name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
)
self._stderr_tail = TailLog(
name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
)
def start(self) -> None:
"""Start processes using parameters defined in the constructor."""
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGTERM, _terminate_process_handler)
signal.signal(signal.SIGINT, _terminate_process_handler)
if not IS_WINDOWS:
signal.signal(signal.SIGHUP, _terminate_process_handler)
signal.signal(signal.SIGQUIT, _terminate_process_handler)
else:
logger.warning(
    "Failed to register signal handlers since torchelastic is running on a child thread. "
    "This could lead to orphaned worker processes if torchrun is terminated."
)
self._start()
self._stdout_tail.start()
self._stderr_tail.start()
@abc.abstractmethod
def _start(self) -> None:
"""Start processes using strategy defined in a particular context."""
raise NotImplementedError
@abc.abstractmethod
def _poll(self) -> Optional[RunProcsResult]:
"""
Poll the run status of the processes running under this context.
This method follows an "all-or-nothing" policy and returns
a ``RunProcsResult`` object if either all processes complete
successfully or any process fails. Returns ``None`` if
all processes are still running.
"""
raise NotImplementedError
def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
"""
Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
for the processes to be done. Returns ``None`` if the processes are still running
on timeout expiry. Negative timeout values are interpreted as "wait-forever".
A timeout value of zero simply queries the status of the processes (e.g. equivalent
to a poll).
.. note:: The multiprocessing library registers SIGTERM and SIGINT signal handlers that raise
    ``SignalException`` when the signals are received. It is up to the consumer of the code
    to properly handle the exception. It is important not to swallow the exception, otherwise
    the process would not terminate. An example of the typical workflow:
.. code-block:: python
    pc = start_processes(...)
    try:
        pc.wait(1)
        # do some other work
    except SignalException as e:
        pc.close(e.sigval, timeout=30)
If SIGTERM or SIGINT occurs, the code above will try to shut down the child processes by
propagating the received signal. If the child processes do not terminate within the timeout,
they are sent SIGKILL.
"""
if timeout == 0:
return self._poll()
if timeout < 0:
timeout = sys.maxsize
expiry = time.time() + timeout
while time.time() < expiry:
pr = self._poll()
if pr:
return pr
time.sleep(period)
return None
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError
@abc.abstractmethod
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
"""
raise NotImplementedError
def close(
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
) -> None:
r"""
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
Args:
death_sig: Death signal to terminate processes.
timeout: Time to wait for processes to finish, if process is
still alive after this time, it will be terminated via SIGKILL.
"""
if not death_sig:
death_sig = _get_default_signal()
self._close(death_sig=death_sig, timeout=timeout)
if self._stdout_tail:
self._stdout_tail.stop()
if self._stderr_tail:
self._stderr_tail.stop()
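A hedged, self-contained version of the shutdown workflow sketched in the ``wait()`` docstring above; the ``worker`` entrypoint, poll period and timeout are illustrative.

from torch.distributed.elastic.multiprocessing import (
    DefaultLogsSpecs,
    SignalException,
    start_processes,
)


def worker() -> None:
    pass  # do some work


if __name__ == "__main__":
    pc = start_processes(
        name="worker",
        entrypoint=worker,
        args={0: ()},
        envs={0: {}},
        logs_specs=DefaultLogsSpecs(),
    )
    try:
        result = pc.wait(timeout=300, period=1)  # None means still running when the timeout expired
        if result is not None and result.is_failed():
            first = min(result.failures.values(), key=lambda f: f.timestamp)
            raise RuntimeError(f"worker {first.local_rank} failed: {first.message}")
    except SignalException as e:
        # raised by the termination handler registered in start() when SIGTERM/SIGINT arrives
        pc.close(e.sigval, timeout=30)
        raise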
def get_std_cm(std_rd: str, redirect_fn):
if IS_WINDOWS or IS_MACOS or not std_rd:
return nullcontext()
else:
return redirect_fn(std_rd)
def _wrap(
local_rank: int,
fn: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
stdout_redirects: Dict[int, str], # redirect file for stdout (empty string means console)
stderr_redirects: Dict[int, str], # redirect file for stderr (empty string means console)
ret_vals: Dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank]
env_ = envs[local_rank]
ret_val_ = ret_vals[local_rank]
stdout_rd = stdout_redirects[local_rank]
stderr_rd = stderr_redirects[local_rank]
stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
stderr_cm = get_std_cm(stderr_rd, redirect_stderr)
for k, v in env_.items():
os.environ[k] = v
with stdout_cm, stderr_cm:
ret = record(fn)(*args_)
ret_val_.put(ret)
queue_finished_reading_event.wait()
class MultiprocessContext(PContext):
"""``PContext`` holding worker processes invoked as a function."""
def __init__(
self,
name: str,
entrypoint: Callable,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
super().__init__(
name,
entrypoint,
args,
envs,
logs_specs,
log_line_prefixes,
)
self.start_method = start_method
# each ret_val queue will always contain a single element.
self._ret_vals = {
local_rank: mp.get_context(self.start_method).SimpleQueue()
for local_rank in range(self.nprocs)
}
# see comments in ``join()`` for what this is
self._return_values: Dict[int, Any] = {}
self._pc: Optional[mp.ProcessContext] = None
# Note: set method should ONLY be invoked for the use case when all processes finished
# successfully. If any process died on event.wait() calling set() method will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()
def _start(self):
if self._pc:
raise ValueError(
    "The process context is already initialized."
    " Most likely the start method got called twice."
)
self._pc = mp.start_processes(
fn=_wrap,
args=(
self.entrypoint,
self.args,
self.envs,
self.stdouts,
self.stderrs,
self._ret_vals,
self._worker_finished_event,
),
nprocs=self.nprocs,
join=False,
daemon=False,
start_method=self.start_method,
)
def _is_done(self) -> bool:
return len(self._return_values) == self.nprocs
def _poll(self) -> Optional[RunProcsResult]:
assert self._pc is not None # assertion for mypy type checker
try:
# torch.mp.ProcessContext Throws an Exception if some/all of
# worker processes failed
# timeout < 0 checks worker status and return immediately
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.
self._pc.join(-1)
# IMPORTANT: we use multiprocessing.Queue to carry worker return values
# back to the parent, the worker process will wait before terminating
# until all the buffered items are fed by the feeder thread to the underlying
# pipe. Hence to prevent deadlocks on large return values,
# we opportunistically try queue.get on each join call
# See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
for local_rank in range(0, self.nprocs):
return_queue = self._ret_vals[local_rank]
if not return_queue.empty():
# save the return values temporarily into a member var
self._return_values[local_rank] = return_queue.get()
if self._is_done():
# we should ALWAYS have ALL the return values when all the processes are done
self._worker_finished_event.set()
# At this point workers finished running the user function
# But the child process might still have not exited. Wait for them.
# pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits.
while not self._pc.join():
logger.debug(
"entrypoint fn finished, waiting for all child procs to exit..."
)
_validate_full_rank(
self._return_values, self.nprocs, "return_value queue"
)
self.close()
return RunProcsResult(
return_values=self._return_values,
stdouts=self.stdouts,
stderrs=self.stderrs,
)
else:
return None
except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
failed_local_rank = e.error_index
# entrypoint for MultiprocessContext will always be a Callable
fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr]
failed_proc = self._pc.processes[failed_local_rank]
error_filepath = self.error_files[failed_local_rank]
logger.exception(
"failed (exitcode: %s)"
" local_rank: %s (pid: %s)"
" of fn: %s (start_method: %s)",
failed_proc.exitcode,
failed_local_rank,
e.pid,
fn_name,
self.start_method,
)
self.close()
return RunProcsResult(
failures={
failed_local_rank: ProcessFailure(
local_rank=failed_local_rank,
pid=e.pid,
exitcode=failed_proc.exitcode,
error_file=error_filepath,
)
},
stdouts=self.stdouts,
stderrs=self.stderrs,
)
def pids(self) -> Dict[int, int]:
assert self._pc is not None # assertion for mypy type checking
return dict(enumerate(self._pc.pids()))
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if not self._pc:
return
for proc in self._pc.processes:
if proc.is_alive():
logger.warning(
"Closing process %s via signal %s", proc.pid, death_sig.name
)
try:
os.kill(proc.pid, death_sig)
except ProcessLookupError:
# If the process exited because of some reason,
# `ProcessLookupError` will be raised, it is safe to ignore it.
pass
end = time.monotonic() + timeout
for proc in self._pc.processes:
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
proc.join(time_to_wait)
for proc in self._pc.processes:
if proc.is_alive():
logger.warning(
"Unable to shutdown process %s via %s, forcefully exiting via %s",
proc.pid,
death_sig,
_get_kill_signal(),
)
try:
os.kill(proc.pid, _get_kill_signal())
except ProcessLookupError:
# If the process exited because of some reason,
# `ProcessLookupError` will be raised, it is safe to ignore it.
pass
proc.join()
class SubprocessContext(PContext):
"""``PContext`` holding worker processes invoked as a binary."""
def __init__(
self,
name: str,
entrypoint: str,
args: Dict[int, Tuple],
envs: Dict[int, Dict[str, str]],
logs_specs: LogsSpecs,
log_line_prefixes: Optional[Dict[int, str]] = None,
):
super().__init__(
name,
entrypoint,
args,
envs,
logs_specs,
log_line_prefixes,
)
# state; local ranks that are still running (removed from this set as they finish)
self._running_local_ranks: Set[int] = set(range(self.nprocs))
self._failures: Dict[int, ProcessFailure] = {}
self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
def _start(self):
if self.subprocess_handlers:
raise ValueError(
    "The subprocess handlers are already initialized. Most likely the start method got called twice."
)
self.subprocess_handlers = {
local_rank: get_subprocess_handler(
entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
args=self.args[local_rank],
env=self.envs[local_rank],
stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank],
local_rank_id=local_rank,
)
for local_rank in range(self.nprocs)
}
def _poll(self) -> Optional[RunProcsResult]:
done_local_ranks = set()
for local_rank in self._running_local_ranks:
handler = self.subprocess_handlers[local_rank]
exitcode = handler.proc.poll()
if exitcode is not None:
done_local_ranks.add(local_rank)
if exitcode != 0: # failed or signaled
self._failures[local_rank] = ProcessFailure(
local_rank=local_rank,
pid=handler.proc.pid,
exitcode=exitcode,
error_file=self.error_files[local_rank],
)
# else: --> succeeded; nothing to do
self._running_local_ranks.difference_update(done_local_ranks)
# if ALL procs are finished or ANY have failed
if not self._running_local_ranks or self._failures:
self.close() # terminate all running procs
result = RunProcsResult(
failures=self._failures,
stdouts=self.stdouts,
stderrs=self.stderrs,
)
if result.is_failed():
first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
logger.error(
"failed (exitcode: %s)"
" local_rank: %s (pid: %s)"
" of binary: %s",
first_failure.exitcode,
first_failure.local_rank,
first_failure.pid,
self.entrypoint,
)
else:
# Populate return_values with dummy (None) entries for consistency with MultiprocessContext
result.return_values = dict.fromkeys(range(self.nprocs))
return result
else: # there are no failures and procs still running
return None
def pids(self) -> Dict[int, int]:
return {
local_rank: sh.proc.pid
for local_rank, sh in self.subprocess_handlers.items()
}
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
if not self.subprocess_handlers:
return
for handler in self.subprocess_handlers.values():
if handler.proc.poll() is None:
logger.warning(
"Sending process %s closing signal %s",
handler.proc.pid,
death_sig.name,
)
handler.close(death_sig=death_sig)
end = time.monotonic() + timeout
for handler in self.subprocess_handlers.values():
time_to_wait = end - time.monotonic()
if time_to_wait <= 0:
break
try:
handler.proc.wait(time_to_wait)
except subprocess.TimeoutExpired:
# Ignore the timeout expired exception, since
# the child process will be forcefully terminated via SIGKILL
pass
for handler in self.subprocess_handlers.values():
if handler.proc.poll() is None:
logger.warning(
"Unable to shutdown process %s via %s, forcefully exiting via %s",
handler.proc.pid,
death_sig,
_get_kill_signal(),
)
handler.close(death_sig=_get_kill_signal())
handler.proc.wait()

@@ -0,0 +1,383 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Each host in a distributed PyTorch job runs with a single TorchElastic agent,
and multiple workers (as children processes of the TorchElastic agent).
Since the workers are user-provided (your PyTorch script/job), TorchElastic
has a way to propagate errors on the trainers through the agent and up to the
scheduler, which ultimately informs the end-user about the state of the job
and applies any retry policies.
TorchElastic categorizes errors into 3 categories:
+----------------+----------------+--------------------------------------------------------------+
| Category | Sub-Category | Description |
+================+================+==============================================================+
| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) |
| +----------------+--------------------------------------------------------------+
| | Worker Failure | any failures on the worker child process |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error | n/a | failures caused by the agent |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error | n/a | failures outside the domain of the agent and workers |
| | | (e.g. host failures) |
+----------------+----------------+--------------------------------------------------------------+
All errors other than "Worker Failure" are either raised canonically from the
agent process or implicitly or explicitly crash the agent process, so the
standard Python exception handling strategies apply.
Worker Failures are special because the exception/failure originates on a different
process from the agent so the error needs to be propagated inter-process
(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
to launch the workers, which has simple file-based inter-process error propagation
built in.
Any function or binary entrypoint decorated with :func:`record`
will write uncaught exceptions (with the trace information) to a file specified by the
environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
sets this env var on each child it launches, then aggregates the error files for all
children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
"""
import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from torch.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
__all__ = [
"ProcessFailure",
"ChildFailedError",
"record",
"ErrorHandler",
"get_error_handler",
]
logger = get_logger(__name__)
JSON = Dict
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
T = TypeVar("T")
@dataclass
class ProcessFailure:
"""
Represents a failed process result. When the worker process fails, it may record the failure root cause into the file.
Tries to read the failure timestamp from the provided ``error_file``,
if the ``error_file`` does not exist, the timestamp is the current
timestamp (seconds since epoch).
The ``message`` field is a concise explanation of the failure. If
the error file exists then the message is obtained from the error file.
Otherwise one is generated based on the failure signature.
.. note:: It is assumed that the ``error_file`` is written by
``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
Otherwise the behavior is undefined.
"""
local_rank: int
pid: int
exitcode: int
error_file: str
error_file_data: JSON = field(init=False)
message: str = field(init=False)
timestamp: int = field(init=False)
def __post_init__(self):
self.error_file_data = _EMPTY_ERROR_DATA
if os.path.isfile(self.error_file):
try:
with open(self.error_file) as fp:
self.error_file_data = json.load(fp)
logger.debug(
"User process failed with error data: %s",
json.dumps(self.error_file_data, indent=2),
)
self.message, self.timestamp = self._get_error_data(
self.error_file_data
)
except Exception:
logger.exception("Failed to parse reply file: %s", self.error_file)
raise
else:
self._set_no_reply_file()
# make up an informative message if not already present
if not self.message:
# signals typically do not generate an error file message
if self.exitcode < 0:
self.message = (
f"Signal {-self.exitcode} ({self.signal_name()})"
f" received by PID {self.pid}"
)
else:
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
message = error_file_data["message"]
if isinstance(message, str):
timestamp = int(error_file_data.get("timestamp", 0))
else:
timestamp = int(message["extraInfo"]["timestamp"])
return (message, timestamp)
def _set_no_reply_file(self):
self.error_file = _NOT_AVAILABLE
self.error_file_data = _EMPTY_ERROR_DATA
self.message = ""
self.timestamp = int(time.time())
def signal_name(self) -> str:
if self.exitcode < 0:
# We don't want to kill the parent process trying to find the signal name.
# if the signal doesn't map to a known name, use not available.
try:
return signal.Signals(-self.exitcode).name
except Exception:
return _NOT_AVAILABLE
else:
return _NOT_AVAILABLE
def timestamp_isoformat(self):
"""Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
GlobalRank = int
_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
time : ${time}
host : ${hostname}
rank : ${rank} (local_rank: ${local_rank})
exitcode : ${exitcode} (pid: ${pid})
error_file: ${error_file}
traceback : ${message}"""
# extra new lines before and after are intentional
_MSG_FORMAT_TEMPLATE = """
${boarder}
${title}
${section}
Failures:
${other_failures}
${section}
Root Cause (first observed failure):
${root_failure}
${boarder}"""
class ChildFailedError(Exception):
"""
Special exception type that can be raised from a function annotated with the
``@record`` decorator to have the child process' (root exception) propagate
up the stack as-is (e.g. without being wrapped in the parent's traceback).
Useful in cases where the parent is a simple nanny process
and the child (worker) processes are actually doing meaningful compute.
In this case, errors typically occur on the child process as the parent
is not doing anything non-trivial, and child errors should be propagated
to the scheduler for accurate root cause diagnostics.
.. note:: The propagation relies on error files rather than exception handling to
support both function and binary launches.
Example:
::
# process tree on a host (container)
0: scheduler-init-process:
|- 1: torchelastic_agent:
|- 2: trainer_0 (ok)
|- 3: trainer_1 (fail) -> error.json
|- ...
|- n+2: trainer_n (ok)
|- n+3: other processes
|- ...
In the example above, trainer 1's failure (written into error.json) is
the root cause and should be reported to the scheduler's init process.
The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
upon detecting trainer 1's failure which would propagate the contents
of trainer 1's error file to the scheduler's init process.
"""
def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
self.name = name
self.failures = failures
assert (
self.failures
) # does not make sense to create a ChildFailedError with no failures
super().__init__(self.format_msg())
def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
return rank, self.failures[rank]
def format_msg(self, boarder_delim="=", section_delim="-"):
title = f"{self.name} FAILED"
root_rank, root_failure = self.get_first_failure()
root_failure_fmt: str = ""
other_failures_fmt: List[str] = []
width = len(title)
for idx, (rank, failure) in enumerate(self.failures.items()):
fmt, w = self._format_failure(idx, rank, failure)
width = max(width, w)
if rank == root_rank:
root_failure_fmt = fmt
else:
other_failures_fmt.append(fmt)
# upper boundary on width
width = min(width, 60)
return Template(_MSG_FORMAT_TEMPLATE).substitute(
boarder=boarder_delim * width,
title=title,
section=section_delim * width,
root_failure=root_failure_fmt,
other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
)
def _format_failure(
self, idx: int, rank: int, failure: ProcessFailure
) -> Tuple[str, int]:
# failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
# or a dict (json) of the form
# {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
# so the display logic is:
# 1. if failure.message is not a dict (it is a str) just show it as is
# 2. else try to get the traceback (py_callstack)
# 3. if the traceback is not there, use the message
# 4. if the message is not there show <N/A>
msg = failure.message
if isinstance(failure.message, dict):
msg = (
failure.message.get("extraInfo", {})
.get("py_callstack", failure.message.get("message", "<N/A>"))
.replace("\n", "\n ") # to properly indent the traceback
)
fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
idx=idx,
time=failure.timestamp_isoformat(),
hostname=socket.getfqdn(),
rank=rank,
local_rank=failure.local_rank,
exitcode=failure.exitcode,
pid=failure.pid,
error_file=failure.error_file,
message=msg,
)
width = 0
for line in fmt.split("\n"):
width = max(width, len(line))
return fmt, width
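A hedged sketch of how the two classes above fit together; the PID and the nonexistent error file are made up, so ``ProcessFailure`` falls back to a message synthesized from the (negative, signal-style) exit code.

import signal

from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, ProcessFailure

failure = ProcessFailure(
    local_rank=1,
    pid=12345,                              # illustrative
    exitcode=-signal.SIGSEGV,               # negative exit code == killed by this signal
    error_file="/nonexistent/error.json",   # no reply file -> message is synthesized
)
err = ChildFailedError(name="trainer", failures={1: failure})
print(err)  # the "trainer FAILED" report; the root-cause section shows SIGSEGV on local_rank 1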
def record(
fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
"""
Syntactic sugar to record errors/exceptions that happened in the decorated
function using the provided ``error_handler``.
Using this decorator is equivalent to:
::
error_handler = get_error_handler()
error_handler.initialize()
try:
foobar()
except ChildFailedError as e:
_, failure = e.get_first_failure()
error_handler.dump_error_file(failure.error_file, failure.exitcode)
raise
except Exception as e:
error_handler.record(e)
raise
.. important:: use this decorator once per process at the top level method,
typically this is the main method.
Example
::
@record
def main():
pass
if __name__=="__main__":
main()
"""
if not error_handler:
error_handler = get_error_handler()
def wrap(f):
@wraps(f)
def wrapper(*args, **kwargs):
assert error_handler is not None # assertion for mypy type checker
error_handler.initialize()
try:
return f(*args, **kwargs)
except SystemExit as se:
# For run_path based entrypoints, SystemExit with code = 0 will never exit.
# Handling it here by returning a value:
if se.code == 0:
return None
else:
raise
except ChildFailedError as e:
rank, failure = e.get_first_failure()
if failure.error_file != _NOT_AVAILABLE:
error_handler.dump_error_file(failure.error_file, failure.exitcode)
else:
logger.info(
    "local_rank %s FAILED with no error file."
    " Decorate your entrypoint fn with @record for traceback info."
    " See: https://pytorch.org/docs/stable/elastic/errors.html",
    rank,
)
raise
except Exception as e:
error_handler.record_exception(e)
raise
return wrapper
return wrap(fn)
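A hedged, self-contained illustration of the file-based propagation described in the module docstring: with ``TORCHELASTIC_ERROR_FILE`` pointed at a temporary path of our choosing (the agent normally sets it per worker), an uncaught exception in a ``@record``-decorated entrypoint is serialized to that file.

import json
import os
import tempfile

from torch.distributed.elastic.multiprocessing.errors import record

error_file = os.path.join(tempfile.mkdtemp(), "error.json")
os.environ["TORCHELASTIC_ERROR_FILE"] = error_file


@record
def main() -> None:
    raise ValueError("boom")


try:
    main()
except ValueError:
    pass

with open(error_file) as fp:
    data = json.load(fp)
print(data["message"]["message"])     # "ValueError: boom"
# data["message"]["extraInfo"] holds the python callstack and a unix timestamp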

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import faulthandler
import json
import logging
import os
import time
import traceback
import warnings
from typing import Any, Dict, Optional
__all__ = ["ErrorHandler"]
logger = logging.getLogger(__name__)
class ErrorHandler:
"""
Write the provided exception object along with some other metadata about
the error in a structured way in JSON format to an error file specified by the
environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment
variable is not set, then simply logs the contents of what would have been
written to the error file.
This handler may be subclassed to customize the handling of the error.
Subclasses should override ``initialize()`` and ``record_exception()``.
"""
def _get_error_file_path(self) -> Optional[str]:
"""
Return the error file path.
May return ``None`` to have the structured error be logged only.
"""
return os.environ.get("TORCHELASTIC_ERROR_FILE", None)
def initialize(self) -> None:
"""
Call prior to running code that we wish to capture errors/exceptions.
Typically registers signal/fault handlers. Users can override this
function to add custom initialization/registrations that aid in
propagation/information of errors/signals/exceptions/faults.
"""
try:
faulthandler.enable(all_threads=True)
except Exception as e:
warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}")
def _write_error_file(self, file_path: str, error_msg: str) -> None:
"""Write error message to the file."""
try:
with open(file_path, "w") as fp:
fp.write(error_msg)
except Exception as e:
warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}")
def record_exception(self, e: BaseException) -> None:
"""
Write a structured information about the exception into an error file in JSON format.
If the error file cannot be determined, then logs the content
that would have been written to the error file.
"""
file = self._get_error_file_path()
if file:
data = {
"message": {
"message": f"{type(e).__name__}: {e}",
"extraInfo": {
"py_callstack": traceback.format_exc(),
"timestamp": str(int(time.time())),
},
}
}
with open(file, "w") as fp:
json.dump(data, fp)
def override_error_code_in_rootcause_data(
self,
rootcause_error_file: str,
rootcause_error: Dict[str, Any],
error_code: int = 0,
):
"""Modify the rootcause_error read from the file, to correctly set the exit code."""
if "message" not in rootcause_error:
logger.warning(
"child error file (%s) does not have field `message`. \n"
"cannot override error code: %s",
rootcause_error_file,
error_code,
)
elif isinstance(rootcause_error["message"], str):
logger.warning(
"child error file (%s) has a new message format. \n"
"skipping error code override",
rootcause_error_file,
)
else:
rootcause_error["message"]["errorCode"] = error_code
def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
"""Dump parent error file from child process's root cause error and error code."""
with open(rootcause_error_file) as fp:
rootcause_error = json.load(fp)
# Override error code since the child process cannot capture the error code if it
# is terminated by signals like SIGSEGV.
if error_code:
self.override_error_code_in_rootcause_data(
rootcause_error_file, rootcause_error, error_code
)
logger.debug(
"child error file (%s) contents:\n" "%s",
rootcause_error_file,
json.dumps(rootcause_error, indent=2),
)
my_error_file = self._get_error_file_path()
if my_error_file:
# Guard against existing error files
# This can happen when the child is created using multiprocessing
# and the same env var (TORCHELASTIC_ERROR_FILE) is used on the
# parent and child to specify the error files (respectively)
# because the env vars on the child is set in the wrapper function
# and by default the child inherits the parent's env vars, if the child
# process receives a signal before the wrapper function kicks in
# and the signal handler writes to the error file, then the child
# will write to the parent's error file. In this case just log the
# original error file contents and overwrite the error file.
self._rm(my_error_file)
self._write_error_file(my_error_file, json.dumps(rootcause_error))
logger.info("dumped error file to parent's %s", my_error_file)
else:
logger.error(
"no error file defined for parent, to copy child error file (%s)",
rootcause_error_file,
)
def _rm(self, my_error_file):
if os.path.isfile(my_error_file):
# Log the contents of the original file.
with open(my_error_file) as fp:
try:
original = json.dumps(json.load(fp), indent=2)
logger.warning(
"%s already exists"
" and will be overwritten."
" Original contents:\n%s",
my_error_file,
original,
)
except json.decoder.JSONDecodeError:
logger.warning(
"%s already exists"
" and will be overwritten."
" Unable to load original contents:\n",
my_error_file,
)
os.remove(my_error_file)
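A hedged sketch of driving ``ErrorHandler`` directly (the ``@record`` decorator wires this up automatically); the temporary path stands in for the agent-provided ``TORCHELASTIC_ERROR_FILE``.

import json
import os
import tempfile

from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler

os.environ["TORCHELASTIC_ERROR_FILE"] = os.path.join(tempfile.mkdtemp(), "error.json")

handler = ErrorHandler()
handler.initialize()  # enables faulthandler
try:
    raise RuntimeError("worker blew up")
except RuntimeError as e:
    handler.record_exception(e)

with open(os.environ["TORCHELASTIC_ERROR_FILE"]) as fp:
    print(json.dumps(json.load(fp), indent=2))  # structured message + py_callstack + timestamp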

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Multiprocessing error-reporting module
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
__all__ = ["get_error_handler"]
def get_error_handler():
return ErrorHandler()

@@ -0,0 +1,104 @@
# mypy: allow-untyped-defs
# !/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Taken and modified from original source:
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
import ctypes
import logging
import os
import sys
from contextlib import contextmanager
from functools import partial
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
def get_libc():
if IS_WINDOWS or IS_MACOS:
logger.warning(
    "NOTE: Redirects are currently not supported in Windows or MacOS."
)
return None
else:
return ctypes.CDLL("libc.so.6")
libc = get_libc()
def _c_std(stream: str):
return ctypes.c_void_p.in_dll(libc, stream)
def _python_std(stream: str):
return {"stdout": sys.stdout, "stderr": sys.stderr}[stream]
_VALID_STD = {"stdout", "stderr"}
@contextmanager
def redirect(std: str, to_file: str):
"""
Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.
This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
See usage for details.
The directory of ``to_file`` is assumed to exist and the destination file
is overwritten if it already exists.
.. note:: Due to buffering, cross-source writes are not guaranteed to
          appear in wall-clock order. For instance in the example below
          it is possible for the C-outputs to appear before the python
          outputs in the log file.
Usage:
::
# syntactic-sugar for redirect("stdout", "/tmp/stdout.log")
with redirect_stdout("/tmp/stdout.log"):
    print("python stdouts are redirected")
    libc = ctypes.CDLL("libc.so.6")
    libc.printf(b"c stdouts are also redirected")
    os.system("echo system stdouts are also redirected")
print("stdout restored")
"""
if std not in _VALID_STD:
raise ValueError(
f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
)
c_std = _c_std(std)
python_std = _python_std(std)
std_fd = python_std.fileno()
def _redirect(dst):
libc.fflush(c_std)
python_std.flush()
os.dup2(dst.fileno(), std_fd)
with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
_redirect(dst)
try:
yield
finally:
_redirect(orig_std)
redirect_stdout = partial(redirect, "stdout")
redirect_stderr = partial(redirect, "stderr")

@@ -0,0 +1,16 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import (
get_subprocess_handler,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["SubprocessHandler", "get_subprocess_handler"]

@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["get_subprocess_handler"]
def get_subprocess_handler(
entrypoint: str,
args: Tuple,
env: Dict[str, str],
stdout: str,
stderr: str,
local_rank_id: int,
):
return SubprocessHandler(
entrypoint=entrypoint,
args=args,
env=env,
stdout=stdout,
stderr=stderr,
local_rank_id=local_rank_id,
)

@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
import subprocess
import sys
from typing import Any, Dict, Optional, Tuple
__all__ = ["SubprocessHandler"]
IS_WINDOWS = sys.platform == "win32"
def _get_default_signal() -> signal.Signals:
"""Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
if IS_WINDOWS:
return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
else:
return signal.SIGTERM
class SubprocessHandler:
"""
Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
meta-objects associated to the process (e.g. stdout and stderr redirect fds).
"""
def __init__(
self,
entrypoint: str,
args: Tuple,
env: Dict[str, str],
stdout: Optional[str],
stderr: Optional[str],
local_rank_id: int,
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
# inherit parent environment vars
env_vars = os.environ.copy()
env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args])
self.local_rank_id = local_rank_id
self.proc: subprocess.Popen = self._popen(args_str, env_vars)
def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
kwargs: Dict[str, Any] = {}
if not IS_WINDOWS:
kwargs["start_new_session"] = True
return subprocess.Popen(
# pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
# _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
# `Tuple[str, *Tuple[Any, ...]]`.
args=args,
env=env,
stdout=self._stdout,
stderr=self._stderr,
**kwargs,
)
def close(self, death_sig: Optional[signal.Signals] = None) -> None:
if not death_sig:
death_sig = _get_default_signal()
if IS_WINDOWS:
self.proc.send_signal(death_sig)
else:
os.killpg(self.proc.pid, death_sig)
if self._stdout:
self._stdout.close()
if self._stderr:
self._stderr.close()
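A hedged sketch of the wrapper in isolation; the binary and log paths are illustrative, and ``SubprocessContext`` (not user code) is the normal caller.

from torch.distributed.elastic.multiprocessing.subprocess_handler import get_subprocess_handler

handler = get_subprocess_handler(
    entrypoint="/bin/echo",
    args=("hello", "world"),
    env={},
    stdout="/tmp/echo_stdout.log",
    stderr="/tmp/echo_stderr.log",
    local_rank_id=0,
)
exitcode = handler.proc.wait(timeout=5)
print(exitcode, open("/tmp/echo_stdout.log").read())  # 0 hello world
# handler.close() would SIGTERM the process group; SubprocessContext only calls it
# while the process is still alive.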

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from concurrent.futures._base import Future
__all__ = ["tail_logfile", "TailLog"]
logger = logging.getLogger(__name__)
def tail_logfile(
header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
):
while not os.path.exists(file):
if finished.is_set():
return
time.sleep(interval_sec)
with open(file, errors="replace") as fp:
while True:
line = fp.readline()
if line:
dst.write(f"{header}{line}")
else: # reached EOF
if finished.is_set():
# log line producer is finished
break
else:
# log line producer is still going
# wait for a bit before looping again
time.sleep(interval_sec)
class TailLog:
"""
Tail the given log files.
The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until
the log files are created by the producer and will tail the contents of the
log files until the ``stop()`` method is called.
.. warning:: ``TailLog`` will wait indefinitely for the log file to be created!
Each log file's line will be prefixed with a header of the form: ``[{name}{idx}]:``,
where the ``name`` is user-provided and ``idx`` is the index of the log file
in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the
header for each log file.
Usage:
::
log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
tailer = TailLog("trainer", log_files, sys.stdout).start()
# actually run the trainers to produce 0_stdout.log and 1_stdout.log
run_trainers()
tailer.stop()
# once run_trainers() starts writing the ##_stdout.log files
# the tailer will print to sys.stdout:
# >>> [trainer0]:log_line1
# >>> [trainer1]:log_line1
# >>> [trainer0]:log_line2
# >>> [trainer0]:log_line3
# >>> [trainer1]:log_line2
.. note:: Due to buffering, log lines between files may not necessarily
be printed out in order. You should configure your application's
logger to suffix each log line with a proper timestamp.
"""
def __init__(
self,
name: str,
log_files: Dict[int, str],
dst: TextIO,
log_line_prefixes: Optional[Dict[int, str]] = None,
interval_sec: float = 0.1,
):
n = len(log_files)
self._threadpool = None
if n > 0:
self._threadpool = ThreadPoolExecutor(
max_workers=n,
thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
)
self._name = name
self._dst = dst
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._finished_events: Dict[int, Event] = {
local_rank: Event() for local_rank in log_files.keys()
}
self._futs: List[Future] = []
self._interval_sec = interval_sec
self._stopped = False
def start(self) -> "TailLog":
if not self._threadpool:
return self
for local_rank, file in self._log_files.items():
header = f"[{self._name}{local_rank}]:"
if self._log_line_prefixes and local_rank in self._log_line_prefixes:
header = self._log_line_prefixes[local_rank]
self._futs.append(
self._threadpool.submit(
tail_logfile,
header=header,
file=file,
dst=self._dst,
finished=self._finished_events[local_rank],
interval_sec=self._interval_sec,
)
)
return self
def stop(self) -> None:
for finished in self._finished_events.values():
finished.set()
for local_rank, f in enumerate(self._futs):
try:
f.result()
except Exception as e:
logger.error(
    "error in log tailer for %s%s. %s: %s",
self._name,
local_rank,
e.__class__.__qualname__,
e,
)
if self._threadpool:
self._threadpool.shutdown(wait=True)
self._stopped = True
def stopped(self) -> bool:
return self._stopped
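A hedged, self-contained demo of the tailer described above; it writes the two log files itself instead of launching real trainers, and the temp directory is chosen at runtime.

import os
import sys
import tempfile

from torch.distributed.elastic.multiprocessing.tail_log import TailLog

tmpdir = tempfile.mkdtemp()
log_files = {0: os.path.join(tmpdir, "0_stdout.log"), 1: os.path.join(tmpdir, "1_stdout.log")}

tailer = TailLog("trainer", log_files, sys.stdout).start()
for rank, path in log_files.items():
    with open(path, "w") as fp:
        fp.write(f"hello from rank {rank}\n")
tailer.stop()  # prints the lines prefixed with [trainer0]: / [trainer1]: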