I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
In the context of Torch Distributed Elastic we use the term *rendezvous* to
refer to a particular functionality that combines a **distributed
synchronization** primitive with **peer discovery**.
It is used by Torch Distributed Elastic to gather participants of a training
job (i.e. nodes) such that they all agree on the same list of participants and
everyone's roles, as well as make a consistent collective decision on when
training can begin/resume.
Torch Distributed Elastic rendezvous provides the following critical
functionalities:
**Barrier**:
Nodes performing rendezvous will all block until the rendezvous is considered
complete - this happens when at least ``min`` total number of nodes have joined
the rendezvous barrier (for the same job). This also implies the barrier is not
necessarily of fixed size.
There's an additional small waiting time after reaching ``min`` number of
nodes - this is used to ensure the rendezvous is not completed "too quickly"
(which could potentially exclude additional nodes attempting to join at
approximately the same time).
If ``max`` number of nodes is gathered at the barrier, the rendezvous is
completed immediately.
There's also an overall timeout which causes the rendezvous to fail if ``min``
number of nodes is never reached - this is meant to be a simple fail-safe to
help release partially allocated job resources, in case there's a problem with
the resource manager, and is meant to be interpreted as non-retryable.
**Exclusivity**:
A simple distributed barrier would not be sufficient, as we also need to ensure
that only one group of nodes exists at any given time (for a given job). In
other words, new nodes (i.e. joining late) should not be able to form a parallel
independent group of workers for the same job.
Torch Distributed Elastic rendezvous ensures that if a group of nodes has
already completed a rendezvous (and hence might already be training), then
additional "late" nodes attempting to rendezvous will only announce themselves
as waiting, and will have to wait until the (previously completed) existing
rendezvous is destroyed first.
**Consistency**:
When a rendezvous is completed, all its members will agree on the job membership
and everyone's role in it. This role is represented using an integer, called
rank, that is between between 0 and world size.
Note that ranks are *not stable*, in the sense that the same node can be
assigned a different rank in the next (re-)rendezvous.
**Fault-tolerance**:
Torch Distributed Elastic rendezvous is designed to tolerate node failures
during the rendezvous process. Should a process crash (or lose network
connectivity, etc), between joining the rendezvous and it being completed, then
a re-rendezvous with remaining healthy nodes will happen automatically.
A node can also fail *after* it has completed (or *has been observered* by other
nodes to have completed) the rendezvous - this scenario will be handled by the
Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a
re-rendezvous).
**Shared key-value store**:
When the rendezvous is completed, a shared key-value store is created and
returned. This store implements a ``torch.distributed.Store`` API (see
`distributed communication docs
<https://pytorch.org/docs/stable/distributed.html>`__).
This store is only shared by the members of the completed rendezvous. It
is intended to be used by Torch Distributed Elastic to exchange information
necessary to initialize job control and data-planes.
**Waiting workers and rendezvous closing**:
Torch Distributed Elastic rendezvous handler object provides additional
functionalities, which are technically not part of the rendezvous process:
1. Querying how many workers arrived late at the barrier, who can participate in
*next* rendezvous.
2. Setting the rendezvous *closed* to signal all nodes not to participate in
next rendezvous.
**DynamicRendezvousHandler**:
Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
class that implements the rendezvous mechanism described above. It is a backend-
agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
to be specified during construction.
Torch distributed users can either implement their own backend type or use one
of the following implementations that come with PyTorch:
- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default
``TCPStore``) as the rendezvous backend. The main advantage of using a C10d
store is that it requires no 3rd-party dependency (such as etcd) to establish
a rendezvous.
- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy
:py:class:`.EtcdRendezvousHandler` class. Passing an
:py:class:`.EtcdRendezvousBackend` instance to
:py:class:`.DynamicRendezvousHandler` is functionally equivalent to
instantiating an :py:class:`.EtcdRendezvousHandler`.
::
store = TCPStore("localhost")
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
run_id="my_run_id",
store=store,
backend=backend,
min_nodes=2,
max_nodes=4
)
"""
from .api import (
rendezvous_handler_registry,
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousGracefulExitError,
RendezvousHandler,
RendezvousHandlerCreator,
RendezvousHandlerRegistry,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from .registry import _register_default_handlers
_register_default_handlers()
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]

View File

@ -0,0 +1,379 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional
from torch.distributed import Store
from torch.distributed.elastic.utils.distributed import get_free_port as _get_free_port
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
class RendezvousError(Exception):
"""Represents the base type for rendezvous errors."""
class RendezvousClosedError(RendezvousError):
"""Raised when a rendezvous is closed."""
class RendezvousTimeoutError(RendezvousError):
"""Raised when a rendezvous did not complete on time."""
class RendezvousConnectionError(RendezvousError):
"""Raised when the connection to a rendezvous backend has failed."""
class RendezvousStateError(RendezvousError):
"""Raised when the state of a rendezvous is corrupt."""
class RendezvousGracefulExitError(RendezvousError):
"""Raised when node wasn't not included in rendezvous and gracefully exits.
Exception is a mechanism to exit the stack, however does not mean a failure.
"""
@dataclass
class RendezvousStoreInfo:
"""Store address and port that can be used to bootstrap trainer distributed comms"""
MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
master_addr: str
master_port: int
@staticmethod
def build(
rank: int, store: Store, local_addr: Optional[str]
) -> "RendezvousStoreInfo":
"""Factory method, finds unused new port on rank0 host and addr/port info with all ranks.
If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor.
Args:
rank: rank of the current node
store: store to use for rendezvous
local_addr: address of the current node, if not provided will be resolved from hostname
"""
# TODO swap to collectives comms API
if rank == 0:
addr = local_addr or socket.getfqdn()
port = _get_free_port()
store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type]
store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type]
addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
port = int(
store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
)
return RendezvousStoreInfo(master_addr=addr, master_port=port)
class RendezvousInfo:
"""Holds the information about the rendezvous."""
def __init__(
self,
store: Store,
rank: int,
world_size: int,
bootstrap_store_info: RendezvousStoreInfo,
):
self._store = store
self._rank = rank
self._world_size = world_size
self._bootstrap_store_info = bootstrap_store_info
@property
def store(self) -> Store:
"""Store used by torchelastic control plane"""
return self._store
@property
def rank(self) -> int:
"""Rank within a group"""
return self._rank
@property
def world_size(self) -> int:
"""Global group size"""
return self._world_size
@property
def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
"""Store information that can used by trainer code to bootstrap distributed comms."""
return self._bootstrap_store_info
class RendezvousHandler(ABC):
"""Main rendezvous interface.
Note:
Distributed Torch users normally **do not** need to implement their own
``RendezvousHandler``. An implementation based on C10d Store is already
provided, and is recommended for most users.
"""
@abstractmethod
def get_backend(self) -> str:
"""Return the name of the rendezvous backend."""
@property
def use_agent_store(self) -> bool:
"""Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user
applications and will be available during application lifecyle.
Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`.
Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store.
"""
return False
@abstractmethod
def next_rendezvous(self) -> RendezvousInfo:
"""Main entry-point into the rendezvous barrier.
Blocks until the rendezvous is complete and the current process is
included in the formed worker group, or a timeout occurs, or the
rendezvous was marked closed.
Returns:
Instance of :py:class:`RendezvousInfo`.
Raises:
RendezvousClosedError:
The rendezvous is closed.
RendezvousConnectionError:
The connection to the rendezvous backend has failed.
RendezvousStateError:
The rendezvous state is corrupt.
RendezvousTimeoutError:
The rendezvous did not complete on time.
"""
@abstractmethod
def is_closed(self) -> bool:
"""Check whether the rendezvous has been closed.
A closed rendezvous means all future attempts to re-rendezvous within
same job will fail.
``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
propagation and should not be used for synchronization. The intention is
that if at least one node decides the job is finished, it will close the
rendezvous, and other nodes will soon observe this and stop running as
well.
"""
@abstractmethod
def set_closed(self):
"""Mark the rendezvous as closed."""
@abstractmethod
def num_nodes_waiting(self) -> int:
"""Return the number of nodes who arrived late at the rendezvous
barrier, hence were not included in the current worker group.
Callers should periodically call this method to check whether new
nodes are waiting to join the job and if so admit them by calling
:py:meth:`next_rendezvous()` (re-rendezvous).
"""
@abstractmethod
def get_run_id(self) -> str:
"""Return the run id of the rendezvous.
The run id is a user-defined id that uniquely identifies an instance of
a distributed application. It typically maps to a job id and is used to
allow nodes to join the correct distributed application.
"""
@abstractmethod
def shutdown(self) -> bool:
"""Close all resources that were open for the rendezvous.
Example::
rdzv_handler = ...
try:
store, rank, world_size = rdzv_handler.next_rendezvous()
finally:
rdzv_handler.shutdown()
"""
class RendezvousParameters:
"""Hold the parameters to construct a :py:class:`RendezvousHandler`.
Args:
backend:
The name of the backend to use to handle the rendezvous.
endpoint:
The endpoint of the rendezvous, usually in form <hostname>[:<port>].
run_id:
The id of the rendezvous.
min_nodes:
The minimum number of nodes to admit to the rendezvous.
max_nodes:
The maximum number of nodes to admit to the rendezvous.
local_addr:
The address of the local node.
**kwargs:
Additional parameters for the specified backend.
"""
def __init__(
self,
backend: str,
endpoint: str,
run_id: str,
min_nodes: int,
max_nodes: int,
local_addr: Optional[str] = None,
**kwargs,
):
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")
if min_nodes < 1:
raise ValueError(
f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
)
if max_nodes < min_nodes:
raise ValueError(
f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
f"equal to the minimum number of rendezvous nodes ({min_nodes})."
)
self.backend = backend
self.endpoint = endpoint
self.run_id = run_id
self.min_nodes = min_nodes
self.max_nodes = max_nodes
self.config = kwargs
self.local_addr = local_addr
def get(self, key: str, default: Any = None) -> Any:
"""Return the value for ``key`` if ``key`` exists, else ``default``."""
return self.config.get(key, default)
def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
"""Return the value for ``key`` as a ``bool``."""
value = self.get(key, default)
if value is None or isinstance(value, bool):
return value
if isinstance(value, int):
if value == 1:
return True
if value == 0:
return False
elif isinstance(value, str):
if value.lower() in ["1", "true", "t", "yes", "y"]:
return True
if value.lower() in ["0", "false", "f", "no", "n"]:
return False
raise ValueError(
f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
)
def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
"""Return the value for ``key`` as an ``int``."""
value = self.get(key, default)
if value is None:
return value
try:
return int(value)
except ValueError as e:
raise ValueError(
f"The rendezvous configuration option '{key}' does not represent a valid integer "
"value."
) from e
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
"""Represent a registry of :py:class:`RendezvousHandler` backends."""
_registry: Dict[str, RendezvousHandlerCreator]
def __init__(self) -> None:
self._registry = {}
def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
"""Register a new rendezvous backend.
Args:
backend:
The name of the backend.
creator:
The callback to invoke to construct the
:py:class:`RendezvousHandler`.
"""
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")
current_creator: Optional[RendezvousHandlerCreator]
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None
if current_creator is not None and current_creator != creator:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
f"is already registered with '{current_creator}'."
)
self._registry[backend] = creator
def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""Create a new :py:class:`RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError as e:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call `{self.register.__name__}`?"
) from e
handler = creator(params)
# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
f"backend '{params.backend}'."
)
return handler
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()

View File

@ -0,0 +1,273 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
import logging
import os
import tempfile
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, cast, Optional, Tuple
from torch.distributed import FileStore, Store, TCPStore
from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
from .api import (
RendezvousConnectionError,
RendezvousError,
RendezvousParameters,
RendezvousStateError,
)
from .dynamic_rendezvous import RendezvousBackend, Token
from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
logger = logging.getLogger(__name__)
# default port for the TCP store
DEFAULT_PORT = 29400
class C10dRendezvousBackend(RendezvousBackend):
"""Represents a C10d-backed rendezvous backend.
Args:
store:
The :py:class:`torch.distributed.Store` instance to use to
communicate with the C10d store.
run_id:
The run id of the rendezvous.
"""
# See the explanation in the __init__ method.
_NULL_SENTINEL = "Y2FuaW1hZGFt"
_store: Store
_key: str
def __init__(self, store: Store, run_id: str) -> None:
if not run_id:
raise ValueError("The run id must be a non-empty string.")
self._store = store
self._key = "torch.rendezvous." + run_id
# The read operation of a store blocks the caller until the specified
# key becomes available. This behavior makes it tricky to use a store
# as a regular key-value dictionary.
#
# As a workaround we initially set a sentinel value as the rendezvous
# state. Whenever this value gets returned we treat it as a None.
self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)
@property
def name(self) -> str:
"""See base class."""
return "c10d"
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
base64_state: bytes = self._call_store("get", self._key)
return self._decode_state(base64_state)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state_str: str = b64encode(state).decode()
if token:
# Shortcut if we know for sure that the token is not valid.
if not isinstance(token, bytes):
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6 does not support tuple unpacking in return
# statements.
return tmp
return None
token = token.decode()
else:
token = self._NULL_SENTINEL
base64_state: bytes = self._call_store(
"compare_set", self._key, token, base64_state_str
)
state_token_pair = self._decode_state(base64_state)
if state_token_pair is None:
return None
new_state, new_token = state_token_pair
# C10d Store's compare_set method does not offer an easy way to find out
# whether our write attempt was successful. As a brute-force solution we
# perform a bitwise comparison of our local state and the remote state.
return new_state, new_token, new_state == state
def _call_store(self, store_op: str, *args, **kwargs) -> Any:
try:
return getattr(self._store, store_op)(*args, **kwargs)
except (ValueError, RuntimeError, TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
if base64_state == self._NULL_SENTINEL.encode():
return None
try:
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
) from exc
return state, base64_state
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)
cfg_is_host = params.get_as_bool("is_host")
# If the user has explicitly specified whether our process should host the
# the store, respect it.
if cfg_is_host is not None:
is_host = cfg_is_host
# Otherwise try to determine whether we are the host based on our hostname
# and IP address.
else:
is_host = _matches_machine_hostname(host)
# The timeout
read_timeout = cast(int, params.get_as_int("read_timeout", 60))
if read_timeout <= 0:
raise ValueError("The read timeout must be a positive integer.")
# In specific cases we attempt to instantiate the store twice. For details
# see the explanation in the except clause below.
for is_server in [is_host, False]:
try:
store = TCPStore(
host,
port,
is_master=is_server,
multi_tenant=True,
timeout=timedelta(seconds=read_timeout),
)
if is_server:
msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
construct_and_record_rdzv_event(
run_id=params.run_id, message=msg, node_state=NodeState.INIT
)
logger.info(msg)
break
except (ValueError, RuntimeError, TimeoutError) as exc:
# If we heuristically inferred the value of is_host as True and our
# first attempt to instantiate the TCP store has failed, try it one
# more time with is_host set to False. As an edge case there can be
# more than one process that is part of the same rendezvous on this
# machine and only one of them will eventually host the store.
if not is_server or cfg_is_host is not None:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
return store # type: ignore[possibly-undefined]
def _create_file_store(params: RendezvousParameters) -> FileStore:
# If a user specifies an endpoint, we treat it as a path to a file.
if params.endpoint:
path = params.endpoint
else:
try:
# The temporary file is readable and writable only by the user of
# this process.
_, path = tempfile.mkstemp()
except OSError as exc:
raise RendezvousError(
"The file creation for C10d store has failed. See inner exception for details."
) from exc
try:
store = FileStore(path)
except (ValueError, RuntimeError) as exc:
raise RendezvousConnectionError(
"The connection to the C10d store has failed. See inner exception for details."
) from exc
return store
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
"""Create a new :py:class:`C10dRendezvousBackend` from the specified parameters.
+--------------+-----------------------------------------------------------+
| Parameter | Description |
+==============+===========================================================+
| store_type | The type of the C10d store. The currently supported types |
| | are "tcp" and "file" which correspond to |
| | :py:class:`torch.distributed.TCPStore` and |
| | :py:class:`torch.distributed.FileStore`, respectively. |
| | Defaults to "tcp". |
+--------------+-----------------------------------------------------------+
| read_timeout | The read timeout, in seconds, for store operations. |
| | Defaults to 60 seconds. |
| | |
| | Note this only applies to |
| | :py:class:`torch.distributed.TCPStore`. It is not relevant|
| | to :py:class:`torch.distributed.FileStore` which does not |
| | take in timeout as a parameter. |
+--------------+-----------------------------------------------------------+
| is_host | A boolean value indicating whether this backend instance |
| | will host the C10d store. If not specified it will be |
| | inferred heuristically by matching the hostname or the IP |
| | address of this machine against the specified rendezvous |
| | endpoint. Defaults to ``None``. |
| | |
| | Note that this configuration option only applies to |
| | :py:class:`torch.distributed.TCPStore`. In normal |
| | circumstances you can safely skip it; the only time when |
| | it is needed is if its value cannot be correctly |
| | determined (e.g. the rendezvous endpoint has a CNAME as |
| | the hostname or does not match the FQDN of the machine). |
+--------------+-----------------------------------------------------------+
"""
# As of today we only support TCPStore and FileStore. Other store types do
# not have the required functionality (e.g. compare_set) yet.
store_type = params.get("store_type", "tcp").strip().lower()
store: Store
try:
if store_type == "file":
store = _create_file_store(params)
elif store_type == "tcp":
store = _create_tcp_store(params)
else:
raise ValueError(
"Invalid store type given. Currently only supports file and tcp."
)
backend = C10dRendezvousBackend(store, params.run_id)
except Exception as e:
construct_and_record_rdzv_event(
message=f"{type(e).__name__}: {str(e)}",
run_id=params.run_id,
node_state=NodeState.FAILED,
)
raise
return backend, store

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,217 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
from base64 import b64decode, b64encode
from typing import cast, Optional, Tuple
import urllib3.exceptions # type: ignore[import]
from etcd import ( # type: ignore[import]
Client as EtcdClient,
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
EtcdResult,
)
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
"""Represents an etcd-based rendezvous backend.
Args:
client:
The ``etcd.Client`` instance to use to communicate with etcd.
run_id:
The run id of the rendezvous.
key_prefix:
The path under which to store the rendezvous state in etcd.
ttl:
The TTL of the rendezvous state. If not specified, defaults to two hours.
"""
_DEFAULT_TTL = 7200 # 2 hours
_client: EtcdClient
_key: str
_ttl: int
def __init__(
self,
client: EtcdClient,
run_id: str,
key_prefix: Optional[str] = None,
ttl: Optional[int] = None,
) -> None:
if not run_id:
raise ValueError("The run id must be a non-empty string.")
self._client = client
if key_prefix:
self._key = key_prefix + "/" + run_id
else:
self._key = run_id
if ttl and ttl > 0:
self._ttl = ttl
else:
self._ttl = self._DEFAULT_TTL
@property
def name(self) -> str:
"""See base class."""
return "etcd-v2"
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
try:
result = self._client.read(self._key)
except EtcdKeyNotFound:
return None
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
return self._decode_state(result)
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state = b64encode(state).decode()
kwargs = {}
def get_state():
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6 does not support tuple unpacking in return
# statements.
return tmp
return None
if token:
try:
token = int(token)
except ValueError:
return get_state()
if token:
kwargs["prevIndex"] = token
else:
kwargs["prevExist"] = False
try:
result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
except (EtcdAlreadyExist, EtcdCompareFailed):
result = None
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
if result is None:
return get_state()
tmp = *self._decode_state(result), True
return tmp
def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
base64_state = result.value.encode()
try:
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
) from exc
return state, result.modifiedIndex
def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)
# The timeout
read_timeout = cast(int, params.get_as_int("read_timeout", 60))
if read_timeout <= 0:
raise ValueError("The read timeout must be a positive integer.")
# The communication protocol
protocol = params.get("protocol", "http").strip().lower()
if protocol != "http" and protocol != "https":
raise ValueError("The protocol must be HTTP or HTTPS.")
# The SSL client certificate
ssl_cert = params.get("ssl_cert")
if ssl_cert:
ssl_cert_key = params.get("ssl_cert_key")
if ssl_cert_key:
# The etcd client expects the certificate key as the second element
# of the `cert` tuple.
ssl_cert = (ssl_cert, ssl_cert_key)
# The root certificate
ca_cert = params.get("ca_cert")
try:
return EtcdClient(
host,
port,
read_timeout=read_timeout,
protocol=protocol,
cert=ssl_cert,
ca_cert=ca_cert,
allow_reconnect=True,
)
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"The connection to etcd has failed. See inner exception for details."
) from exc
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
"""Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.
+--------------+-----------------------------------------------------------+
| Parameter | Description |
+==============+===========================================================+
| read_timeout | The read timeout, in seconds, for etcd operations. |
| | Defaults to 60 seconds. |
+--------------+-----------------------------------------------------------+
| protocol | The protocol to use to communicate with etcd. Valid |
| | values are "http" and "https". Defaults to "http". |
+--------------+-----------------------------------------------------------+
| ssl_cert | The path to the SSL client certificate to use along with |
| | HTTPS. Defaults to ``None``. |
+--------------+-----------------------------------------------------------+
| ssl_cert_key | The path to the private key of the SSL client certificate |
| | to use along with HTTPS. Defaults to ``None``. |
+--------------+-----------------------------------------------------------+
| ca_cert | The path to the rool SSL authority certificate. Defaults |
| | to ``None``. |
+--------------+-----------------------------------------------------------+
"""
client = _create_etcd_client(params)
backend = EtcdRendezvousBackend(
client, params.run_id, key_prefix="/torch/elastic/rendezvous"
)
store = EtcdStore(client, "/torch/elastic/store")
return backend, store

View File

@ -0,0 +1,248 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional, TextIO, Union
try:
import etcd # type: ignore[import]
except ModuleNotFoundError:
pass
logger = logging.getLogger(__name__)
def find_free_port():
"""
Find a free port and binds a temporary socket to it so that the port can be "reserved" until used.
.. note:: the returned socket must be closed before using the port,
otherwise a ``address already in use`` error will happen.
The socket should be held and closed as close to the
consumer of the port as possible since otherwise, there
is a greater chance of race-condition where a different
process may see the port as being free and take it.
Returns: a socket binded to the reserved free port
Usage::
sock = find_free_port()
port = sock.getsockname()[1]
sock.close()
use_port(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
try:
s = socket.socket(family, type, proto)
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close() # type: ignore[possibly-undefined]
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
if subprocess and subprocess.poll() is None:
logger.info("stopping etcd server")
subprocess.terminate()
subprocess.wait()
if data_dir:
logger.info("deleting etcd data dir: %s", data_dir)
shutil.rmtree(data_dir, ignore_errors=True)
class EtcdServer:
"""
.. note:: tested on etcd server v3.4.3.
Starts and stops a local standalone etcd server on a random free
port. Useful for single node, multi-worker launches or testing,
where a sidecar etcd server is more convenient than having to
separately setup an etcd server.
This class registers a termination handler to shutdown the etcd
subprocess on exit. This termination handler is NOT a substitute for
calling the ``stop()`` method.
The following fallback mechanism is used to find the etcd binary:
1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
2. Uses ``<this file root>/bin/etcd`` if one exists
3. Uses ``etcd`` from ``PATH``
Usage
::
server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd")
server.start()
client = server.get_client()
# use client
server.stop()
Args:
etcd_binary_path: path of etcd server binary (see above for fallback path)
"""
def __init__(self, data_dir: Optional[str] = None):
self._port = -1
self._host = "localhost"
root = os.path.dirname(__file__)
default_etcd_bin = os.path.join(root, "bin/etcd")
self._etcd_binary_path = os.environ.get(
"TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
)
if not os.path.isfile(self._etcd_binary_path):
self._etcd_binary_path = "etcd"
self._base_data_dir = (
data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
)
self._etcd_cmd = None
self._etcd_proc: Optional[subprocess.Popen] = None
def _get_etcd_server_process(self) -> subprocess.Popen:
if not self._etcd_proc:
raise RuntimeError(
"No etcd server process started. Call etcd_server.start() first"
)
else:
return self._etcd_proc
def get_port(self) -> int:
"""Return the port the server is running on."""
return self._port
def get_host(self) -> str:
"""Return the host the server is running on."""
return self._host
def get_endpoint(self) -> str:
"""Return the etcd server endpoint (host:port)."""
return f"{self._host}:{self._port}"
def start(
self,
timeout: int = 60,
num_retries: int = 3,
stderr: Union[int, TextIO, None] = None,
) -> None:
"""
Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests.
Args:
timeout: time (in seconds) to wait for the server to be ready
before giving up.
num_retries: number of retries to start the server. Each retry
will wait for max ``timeout`` before considering it as failed.
stderr: the standard error file handle. Valid values are
`subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
descriptor (a positive integer), an existing file object, and
`None`.
Raises:
TimeoutError: if the server is not ready within the specified timeout
"""
curr_retries = 0
while True:
try:
data_dir = os.path.join(self._base_data_dir, str(curr_retries))
os.makedirs(data_dir, exist_ok=True)
return self._start(data_dir, timeout, stderr)
except Exception as e:
curr_retries += 1
stop_etcd(self._etcd_proc)
logger.warning(
"Failed to start etcd server, got error: %s, retrying", str(e)
)
if curr_retries >= num_retries:
shutil.rmtree(self._base_data_dir, ignore_errors=True)
raise
atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
def _start(
self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
) -> None:
sock = find_free_port()
sock_peer = find_free_port()
self._port = sock.getsockname()[1]
peer_port = sock_peer.getsockname()[1]
etcd_cmd = shlex.split(
" ".join(
[
self._etcd_binary_path,
"--enable-v2",
"--data-dir",
data_dir,
"--listen-client-urls",
f"http://{self._host}:{self._port}",
"--advertise-client-urls",
f"http://{self._host}:{self._port}",
"--listen-peer-urls",
f"http://{self._host}:{peer_port}",
]
)
)
logger.info("Starting etcd server: [%s]", etcd_cmd)
sock.close()
sock_peer.close()
self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
self._wait_for_ready(timeout)
def get_client(self):
"""Return an etcd client object that can be used to make requests to this server."""
return etcd.Client(
host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
)
def _wait_for_ready(self, timeout: int = 60) -> None:
client = etcd.Client(
host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
)
max_time = time.time() + timeout
while time.time() < max_time:
if self._get_etcd_server_process().poll() is not None:
# etcd server process finished
exitcode = self._get_etcd_server_process().returncode
raise RuntimeError(
f"Etcd server process exited with the code: {exitcode}"
)
try:
logger.info("etcd server ready. version: %s", client.version)
return
except Exception:
time.sleep(1)
raise TimeoutError("Timed out waiting for etcd server to be ready!")
def stop(self) -> None:
"""Stop the server and cleans up auto generated resources (e.g. data dir)."""
logger.info("EtcdServer stop method called")
stop_etcd(self._etcd_proc, self._base_data_dir)

View File

@ -0,0 +1,207 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional
import etcd # type: ignore[import]
# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
from torch.distributed import Store
# Delay (sleep) for a small random amount to reduce CAS failures.
# This does not affect correctness, but will reduce requests to etcd server.
def cas_delay():
time.sleep(random.uniform(0, 0.1))
# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
"""
Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.
This is the store object returned by ``EtcdRendezvous``.
"""
def __init__(
self,
etcd_client,
etcd_store_prefix,
# Default timeout same as in c10d/Store.hpp
timeout: Optional[datetime.timedelta] = None,
):
super().__init__() # required for pybind trampoline.
self.client = etcd_client
self.prefix = etcd_store_prefix
if timeout is not None:
self.set_timeout(timeout)
if not self.prefix.endswith("/"):
self.prefix += "/"
def set(self, key, value):
"""
Write a key/value pair into ``EtcdStore``.
Both key and value may be either Python ``str`` or ``bytes``.
"""
self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))
def get(self, key) -> bytes:
"""
Get a value by key, possibly doing a blocking wait.
If key is not immediately present, will do a blocking wait
for at most ``timeout`` duration or until the key is published.
Returns:
value ``(bytes)``
Raises:
LookupError - If key still not published after timeout
"""
b64_key = self.prefix + self._encode(key)
kvs = self._try_wait_get([b64_key])
if kvs is None:
raise LookupError(f"Key {key} not found in EtcdStore")
return self._decode(kvs[b64_key])
def add(self, key, num: int) -> int:
"""
Atomically increment a value by an integer amount.
The integer is represented as a string using base 10. If key is not present,
a default value of ``0`` will be assumed.
Returns:
the new (incremented) value
"""
b64_key = self._encode(key)
# c10d Store assumes value is an integer represented as a decimal string
try:
# Assume default value "0", if this key didn't yet:
node = self.client.write(
key=self.prefix + b64_key,
value=self._encode(str(num)), # i.e. 0 + num
prevExist=False,
)
return int(self._decode(node.value))
except etcd.EtcdAlreadyExist:
pass
while True:
# Note: c10d Store does not have a method to delete keys, so we
# can be sure it's still there.
node = self.client.get(key=self.prefix + b64_key)
new_value = self._encode(str(int(self._decode(node.value)) + num))
try:
node = self.client.test_and_set(
key=node.key, value=new_value, prev_value=node.value
)
return int(self._decode(node.value))
except etcd.EtcdCompareFailed:
cas_delay()
def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
"""
Wait until all of the keys are published, or until timeout.
Raises:
LookupError - if timeout occurs
"""
b64_keys = [self.prefix + self._encode(key) for key in keys]
kvs = self._try_wait_get(b64_keys, override_timeout)
if kvs is None:
raise LookupError("Timeout while waiting for keys in EtcdStore")
# No return value on success
def check(self, keys) -> bool:
"""Check if all of the keys are immediately present (without waiting)."""
b64_keys = [self.prefix + self._encode(key) for key in keys]
kvs = self._try_wait_get(
b64_keys,
override_timeout=datetime.timedelta(microseconds=1), # as if no wait
)
return kvs is not None
#
# Encode key/value data in base64, so we can store arbitrary binary data
# in EtcdStore. Input can be `str` or `bytes`.
# In case of `str`, utf-8 encoding is assumed.
#
def _encode(self, value) -> str:
if type(value) == bytes:
return b64encode(value).decode()
elif type(value) == str:
return b64encode(value.encode()).decode()
raise ValueError("Value must be of type str or bytes")
#
# Decode a base64 string (of type `str` or `bytes`).
# Return type is `bytes`, which is more convenient with the Store interface.
#
def _decode(self, value) -> bytes:
if type(value) == bytes:
return b64decode(value)
elif type(value) == str:
return b64decode(value.encode())
raise ValueError("Value must be of type str or bytes")
#
# Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
# are published or timeout occurs.
# This is a helper method for the public interface methods.
#
# On success, a dictionary of {etcd key -> etcd value} is returned.
# On timeout, None is returned.
#
def _try_wait_get(self, b64_keys, override_timeout=None):
timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined]
deadline = time.time() + timeout.total_seconds()
while True:
# Read whole directory (of keys), filter only the ones waited for
all_nodes = self.client.get(key=self.prefix)
req_nodes = {
node.key: node.value
for node in all_nodes.children
if node.key in b64_keys
}
if len(req_nodes) == len(b64_keys):
# All keys are available
return req_nodes
watch_timeout = deadline - time.time()
if watch_timeout <= 0:
return None
try:
self.client.watch(
key=self.prefix,
recursive=True,
timeout=watch_timeout,
index=all_nodes.etcd_index + 1,
)
except etcd.EtcdWatchTimedOut:
if time.time() >= deadline:
return None
else:
continue
except etcd.EtcdEventIndexCleared:
continue

View File

@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import (
rendezvous_handler_registry as handler_registry,
RendezvousHandler,
RendezvousParameters,
)
from .dynamic_rendezvous import create_handler
__all__ = ["get_rendezvous_handler"]
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import static_tcp_rendezvous
return static_tcp_rendezvous.create_rdzv_handler(params)
def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import etcd_rendezvous
return etcd_rendezvous.create_rdzv_handler(params)
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
from .etcd_rendezvous_backend import create_backend
backend, store = create_backend(params)
return create_handler(store, backend, params)
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
from .c10d_rendezvous_backend import create_backend
backend, store = create_backend(params)
return create_handler(store, backend, params)
def _register_default_handlers() -> None:
handler_registry.register("etcd", _create_etcd_handler)
handler_registry.register("etcd-v2", _create_etcd_v2_handler)
handler_registry.register("c10d", _create_c10d_handler)
handler_registry.register("static", _create_static_handler)
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
"""
Obtain a reference to a :py:class`RendezvousHandler`.
Custom rendezvous handlers can be registered by
::
from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
def create_my_rdzv(params: RendezvousParameters):
return MyCustomRdzv(params)
rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
"""
return handler_registry.create_handler(params)

View File

@ -0,0 +1,128 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
from typing import cast, Optional
from torch.distributed import PrefixStore, Store, TCPStore
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStoreInfo,
)
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"]
logger = logging.getLogger(__name__)
_default_timeout_seconds = 600
class StaticTCPRendezvous(RendezvousHandler):
"""
Static rendezvous that is a wrapper around the TCPStore.
Creates TCPStore based on the input parameters with the
listener on the agent with group_rank=0
"""
def __init__(
self,
master_addr: str,
master_port: int,
rank: int,
world_size: int,
run_id: str,
timeout: int,
):
self.master_addr = master_addr
self.master_port = master_port
self.rank = rank
self.world_size = world_size
self.run_id = run_id
self.timeout = datetime.timedelta(seconds=timeout)
self._store: Optional[Store] = None
def get_backend(self) -> str:
return "static"
@property
def use_agent_store(self) -> bool:
return True
def next_rendezvous(self) -> RendezvousInfo:
logger.info("Creating TCPStore as the c10d::Store implementation")
is_master = self.rank == 0
if not self._store:
self._store = TCPStore( # type: ignore[call-arg]
self.master_addr,
self.master_port,
self.world_size,
is_master,
self.timeout,
multi_tenant=True,
)
store = PrefixStore(self.run_id, self._store)
# TCPStore server instance is used by trainer code
bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
return RendezvousInfo(
store,
self.rank,
self.world_size,
bootstrap_store_info,
)
def is_closed(self):
return False
def set_closed(self):
pass
def num_nodes_waiting(self):
return 0
def get_run_id(self) -> str:
return self.run_id
def shutdown(self) -> bool:
return True
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
if "rank" not in params.config:
raise ValueError(
"rank is absent in RendezvousParameters."
"Try add --node-rank to the cmd request"
)
endpoint = params.endpoint.strip()
if not endpoint:
raise ValueError(
"endpoint is absent in RendezvousParameters"
"Try add --master-port and --master-addr to the cmd request"
)
master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
if master_port == -1:
raise ValueError(
f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
)
world_size = params.max_nodes
rank = cast(int, params.config.get("rank"))
run_id = params.run_id
if "timeout" in params.config:
timeout = int(params.config["timeout"])
else:
timeout = _default_timeout_seconds
return StaticTCPRendezvous(
master_addr, master_port, rank, world_size, run_id, timeout
)

View File

@ -0,0 +1,284 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Tuple, Union
__all__ = ["parse_rendezvous_endpoint"]
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string.
Args:
config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
config_str = config_str.strip()
if not config_str:
return config
key_values = config_str.split(",")
for kv in key_values:
key, *values = kv.split("=", 1)
key = key.strip()
if not key:
raise ValueError(
"The rendezvous configuration string must be in format "
"<key1>=<value1>,...,<keyN>=<valueN>."
)
value: Optional[str]
if values:
value = values[0].strip()
else:
value = None
if not value:
raise ValueError(
f"The rendezvous configuration option '{key}' must have a value specified."
)
config[key] = value
return config
def _try_parse_port(port_str: str) -> Optional[int]:
"""Try to extract the port number from ``port_str``."""
if port_str and re.match(r"^[0-9]{1,5}$", port_str):
return int(port_str)
return None
def parse_rendezvous_endpoint(
endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
"""Extract the hostname and the port number from a rendezvous endpoint.
Args:
endpoint:
A string in format <hostname>[:<port>].
default_port:
The port number to use if the endpoint does not include one.
Returns:
A tuple of hostname and port number.
"""
if endpoint is not None:
endpoint = endpoint.strip()
if not endpoint:
return ("localhost", default_port)
# An endpoint that starts and ends with brackets represents an IPv6 address.
if endpoint[0] == "[" and endpoint[-1] == "]":
host, *rest = endpoint, *[]
else:
host, *rest = endpoint.rsplit(":", 1)
# Sanitize the IPv6 address.
if len(host) > 1 and host[0] == "[" and host[-1] == "]":
host = host[1:-1]
if len(rest) == 1:
port = _try_parse_port(rest[0])
if port is None or port >= 2**16:
raise ValueError(
f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
"between 0 and 65536."
)
else:
port = default_port
if not re.match(r"^[\w\.:-]+$", host):
raise ValueError(
f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
"labels, an IPv4 address, or an IPv6 address."
)
return host, port
def _matches_machine_hostname(host: str) -> bool:
"""Indicate whether ``host`` matches the hostname of this machine.
This function compares ``host`` to the hostname as well as to the IP
addresses of this machine. Note that it may return a false negative if this
machine has CNAME records beyond its FQDN or IP addresses assigned to
secondary NICs.
"""
if host == "localhost":
return True
try:
addr = ipaddress.ip_address(host)
except ValueError:
addr = None
if addr and addr.is_loopback:
return True
try:
host_addr_list = socket.getaddrinfo(
host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
except (ValueError, socket.gaierror) as _:
host_addr_list = []
host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list]
this_host = socket.gethostname()
if host == this_host:
return True
addr_list = socket.getaddrinfo(
this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
for addr_info in addr_list:
# If we have an FQDN in the addr_info, compare it to `host`.
if addr_info[3] and addr_info[3] == host:
return True
# Otherwise if `host` represents an IP address, compare it to our IP
# address.
if addr and addr_info[4][0] == str(addr):
return True
# If the IP address matches one of the provided host's IP addresses
if addr_info[4][0] in host_ip_list:
return True
return False
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
"""Suspend the current thread for ``seconds``.
Args:
seconds:
Either the delay, in seconds, or a tuple of a lower and an upper
bound within which a random delay will be picked.
"""
if isinstance(seconds, tuple):
seconds = random.uniform(*seconds)
# Ignore delay requests that are less than 10 milliseconds.
if seconds >= 0.01:
time.sleep(seconds)
class _PeriodicTimer:
"""Represent a timer that periodically runs a specified function.
Args:
interval:
The interval, in seconds, between each run.
function:
The function to run.
"""
# The state of the timer is hold in a separate context object to avoid a
# reference cycle between the timer and the background thread.
class _Context:
interval: float
function: Callable[..., None]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
stop_event: Event
_name: Optional[str]
_thread: Optional[Thread]
_finalizer: Optional[weakref.finalize]
# The context that is shared between the timer and the background thread.
_ctx: _Context
def __init__(
self,
interval: timedelta,
function: Callable[..., None],
*args: Any,
**kwargs: Any,
) -> None:
self._name = None
self._ctx = self._Context()
self._ctx.interval = interval.total_seconds()
self._ctx.function = function # type: ignore[assignment]
self._ctx.args = args or ()
self._ctx.kwargs = kwargs or {}
self._ctx.stop_event = Event()
self._thread = None
self._finalizer = None
@property
def name(self) -> Optional[str]:
"""Get the name of the timer."""
return self._name
def set_name(self, name: str) -> None:
"""Set the name of the timer.
The specified name will be assigned to the background thread and serves
for debugging and troubleshooting purposes.
"""
if self._thread:
raise RuntimeError("The timer has already started.")
self._name = name
def start(self) -> None:
"""Start the timer."""
if self._thread:
raise RuntimeError("The timer has already started.")
self._thread = Thread(
target=self._run,
name=self._name or "PeriodicTimer",
args=(self._ctx,),
daemon=True,
)
# We avoid using a regular finalizer (a.k.a. __del__) for stopping the
# timer as joining a daemon thread during the interpreter shutdown can
# cause deadlocks. The weakref.finalize is a superior alternative that
# provides a consistent behavior regardless of the GC implementation.
self._finalizer = weakref.finalize(
self, self._stop_thread, self._thread, self._ctx.stop_event
)
# We do not attempt to stop our background thread during the interpreter
# shutdown. At that point we do not even know whether it still exists.
self._finalizer.atexit = False
self._thread.start()
def cancel(self) -> None:
"""Stop the timer at the next opportunity."""
if self._finalizer:
self._finalizer()
@staticmethod
def _run(ctx) -> None:
while not ctx.stop_event.wait(ctx.interval):
ctx.function(*ctx.args, **ctx.kwargs)
@staticmethod
def _stop_thread(thread, stop_event):
stop_event.set()
thread.join()