I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,249 @@
# mypy: allow-untyped-defs
import logging
import os
import threading
import warnings
from datetime import timedelta
from typing import Generator, Tuple
from urllib.parse import urlparse
import torch
import torch.distributed as dist
__all__ = ["is_available"]
logger = logging.getLogger(__name__)
_init_counter = 0
_init_counter_lock = threading.Lock()
def is_available() -> bool:
return hasattr(torch._C, "_rpc_init")
if is_available() and not torch._C._rpc_init():
raise RuntimeError("Failed to initialize torch.distributed.rpc")
if is_available():
import numbers
import torch.distributed.autograd as dist_autograd
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import ( # noqa: F401
_cleanup_python_rpc_handler,
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_WORKER_THREADS,
_DEFAULT_RPC_TIMEOUT_SEC,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
_disable_jit_rref_pickle,
_disable_server_process_global_profiler,
_enable_jit_rref_pickle,
_enable_server_process_global_profiler,
_get_current_rpc_agent,
_invoke_remote_builtin,
_invoke_remote_python_udf,
_invoke_remote_torchscript,
_invoke_rpc_builtin,
_invoke_rpc_python_udf,
_invoke_rpc_torchscript,
_is_current_rpc_agent_set,
_reset_current_rpc_agent,
_rref_context_get_debug_info,
_set_and_start_rpc_agent,
_set_profiler_node_id,
_set_rpc_timeout,
_TensorPipeRpcBackendOptionsBase,
_UNSET_RPC_TIMEOUT,
enable_gil_profiling,
get_rpc_timeout,
PyRRef,
RemoteProfilerManager,
RpcAgent,
RpcBackendOptions,
TensorPipeAgent,
WorkerInfo,
)
from . import api, backend_registry, functions
from .api import * # noqa: F401,F403
from .backend_registry import BackendType
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .server_process_global_profiler import _server_process_global_profile
rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
__all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
__all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605
def init_rpc(
name,
backend=None,
rank=-1,
world_size=None,
rpc_backend_options=None,
):
r"""
Initializes RPC primitives such as the local RPC agent
and distributed autograd, which immediately makes the current
process ready to send and receive RPCs.
Args:
name (str): a globally unique name of this node. (e.g.,
``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
The name can only contain numbers, letters, underscores, colons,
and/or dashes, and must be shorter than 128 characters.
backend (BackendType, optional): The type of RPC backend
implementation. The only supported value is
``BackendType.TENSORPIPE`` (the default).
See :ref:`rpc-backends` for more information.
rank (int): a globally unique id/rank of this node.
world_size (int): The number of workers in the group.
rpc_backend_options (RpcBackendOptions, optional): The options
passed to the RpcAgent constructor. It must be an agent-specific
subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
and contains agent-specific initialization configurations. By
default, for all agents, it sets the default timeout to 60
seconds and performs the rendezvous with an underlying process
group initialized using ``init_method = "env://"``,
meaning that environment variables ``MASTER_ADDR`` and
``MASTER_PORT`` need to be set properly. See
:ref:`rpc-backends` for more information and find which options
are available.
"""
torch._C._log_api_usage_once("torch.distributed.init_rpc")
if backend is not None and not isinstance(
backend, backend_registry.BackendType
):
raise TypeError("Argument backend must be a member of BackendType")
if rpc_backend_options is not None and not isinstance(
rpc_backend_options, RpcBackendOptions
):
raise TypeError(
"Argument rpc_backend_options must be an instance of RpcBackendOptions"
)
# Try to detect the backend from the options
if backend is None and rpc_backend_options is not None:
for candidate_backend in BackendType:
if isinstance(
rpc_backend_options,
type(
backend_registry.construct_rpc_backend_options(
candidate_backend
)
),
):
backend = candidate_backend
break
else:
raise TypeError(
f"Could not infer backend for options {rpc_backend_options}"
)
# Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined]
logger.warning(
"RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined]
"corresponding to %(backend)s, hence that backend will be used "
"instead of the default BackendType.TENSORPIPE. To silence this "
"warning pass `backend=%(backend)s` explicitly.",
{"backend": backend},
)
if backend is None:
backend = BackendType.TENSORPIPE # type: ignore[attr-defined]
if rpc_backend_options is None:
# default construct a set of RPC backend options.
rpc_backend_options = backend_registry.construct_rpc_backend_options(
backend
)
# Create store, performs rendezvous for static RPC group.
if not world_size:
# If world_size was not set at construction time and is also not set via
# environment variables, the store is created for the dynamic group setting.
store = dist._create_store_from_options(rpc_backend_options, rank)
else:
# This rendezvous state sometimes is destroyed before all processes
# finishing handshaking. To avoid that issue, we make it global to
# keep it alive.
global rendezvous_iterator
rendezvous_iterator = dist.rendezvous(
rpc_backend_options.init_method, rank=rank, world_size=world_size
)
store, _, _ = next(rendezvous_iterator)
# Use same timeout as RPC.
store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))
# Use a PrefixStore to distinguish multiple invocations.
with _init_counter_lock:
global _init_counter
store = dist.PrefixStore(f"rpc_prefix_{_init_counter}", store)
_init_counter += 1
# Initialize autograd before RPC since _init_rpc_backend guarantees all
# processes sync via the store. If we initialize autograd after RPC,
# there could be a race where some nodes might have initialized autograd
# and others might not have. As a result, a node calling
# torch.distributed.autograd.backward() would run into errors since
# other nodes might not have been initialized.
dist_autograd._init(rank)
_set_profiler_node_id(rank)
# Initialize RPC.
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
store: dist.Store,
name: str,
rank: numbers.Integral,
# world_size can be None for a dynamic group
world_size: (numbers.Integral, type(None)),
rpc_backend_options: RpcBackendOptions,
}
for arg, arg_type in type_mapping.items():
if not isinstance(arg, arg_type): # type: ignore[arg-type]
raise RuntimeError(
f"Argument {arg} must be of type {arg_type} but got type {type(arg)}"
)
def _init_rpc_backend(
backend=BackendType.TENSORPIPE, # type: ignore[attr-defined]
store=None,
name=None,
rank=-1,
world_size=None,
rpc_backend_options=None,
):
_validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)
if _is_current_rpc_agent_set():
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
rpc_agent = backend_registry.init_backend(
backend,
store=store,
name=name,
rank=rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options,
)
api._init_rpc_states(rpc_agent)
@api._require_initialized
def _get_debug_info():
info = _rref_context_get_debug_info()
info.update(api._get_current_rpc_agent().get_debug_info())
info.update(dist_autograd._get_debug_info())
return info
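A minimal usage sketch of the ``init_rpc`` entry point defined above, assuming two processes with ranks 0 and 1 and that ``MASTER_ADDR``/``MASTER_PORT`` are exported as described in the docstring; the worker names, RANK lookup, and thread count are illustrative only:

# Hedged sketch: run once per process, with RANK set to 0 or 1 in the environment.
import os
import torch.distributed.rpc as rpc

rank = int(os.environ.get("RANK", "0"))  # illustrative rank source
opts = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=8,  # illustrative value
    rpc_timeout=60,
)
rpc.init_rpc(
    name=f"worker{rank}",      # globally unique worker name
    rank=rank,
    world_size=2,
    rpc_backend_options=opts,  # defaults to init_method="env://"
)
rpc.shutdown()                 # graceful shutdown once all RPCs are done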

View File

@ -0,0 +1,20 @@
# mypy: allow-untyped-defs
import torch
def is_available():
return hasattr(torch._C, "_faulty_agent_init")
if is_available() and not torch._C._faulty_agent_init():
raise RuntimeError("Failed to initialize torch.distributed.rpc._testing")
if is_available():
# Registers FAULTY_TENSORPIPE RPC backend.
from torch._C._distributed_rpc_testing import (
FaultyTensorPipeAgent,
FaultyTensorPipeRpcBackendOptions,
)
from . import faulty_agent_backend_registry

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
import torch.distributed as dist
import torch.distributed.rpc as rpc
def _faulty_tensorpipe_construct_rpc_backend_options_handler(
rpc_timeout,
init_method,
num_worker_threads,
messages_to_fail,
messages_to_delay,
num_fail_sends,
**kwargs,
):
from . import FaultyTensorPipeRpcBackendOptions
return FaultyTensorPipeRpcBackendOptions(
num_worker_threads=num_worker_threads,
rpc_timeout=rpc_timeout,
init_method=init_method,
messages_to_fail=messages_to_fail,
messages_to_delay=messages_to_delay,
num_fail_sends=num_fail_sends,
)
def _faulty_tensorpipe_init_backend_handler(
store, name, rank, world_size, rpc_backend_options
):
from torch.distributed.rpc import api
from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions
if not isinstance(store, dist.Store):
raise TypeError(f"`store` must be a c10d::Store. {store}")
if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions):
raise TypeError(
f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}"
)
agent = FaultyTensorPipeAgent(
store,
name,
rank,
world_size,
rpc_backend_options,
{}, # reverse_device_map
[], # devices
)
api._init_rpc_states(agent)
return agent
rpc.backend_registry.register_backend(
"FAULTY_TENSORPIPE",
_faulty_tensorpipe_construct_rpc_backend_options_handler,
_faulty_tensorpipe_init_backend_handler,
)
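Since registration happens on import of ``torch.distributed.rpc._testing``, the faulty backend can then be selected through the regular ``init_rpc`` call. A hedged sketch, intended only for fault-injection tests; the option values are illustrative:

import torch.distributed.rpc as rpc
import torch.distributed.rpc._testing  # noqa: F401, registers FAULTY_TENSORPIPE
from torch.distributed.rpc._testing import FaultyTensorPipeRpcBackendOptions

opts = FaultyTensorPipeRpcBackendOptions(
    num_worker_threads=8,
    rpc_timeout=60,
    init_method="env://",
    messages_to_fail=[],   # illustrative: no message types forced to fail
    messages_to_delay={},  # illustrative: no artificial per-message delays
    num_fail_sends=0,
)
rpc.init_rpc(
    "worker0",
    backend=rpc.backend_registry.BackendType.FAULTY_TENSORPIPE,
    rank=0,
    world_size=2,
    rpc_backend_options=opts,
)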

View File

@ -0,0 +1,47 @@
# mypy: allow-untyped-defs
import logging
from contextlib import contextmanager
from typing import cast
from . import api, TensorPipeAgent
logger = logging.getLogger(__name__)
@contextmanager
def _group_membership_management(store, name, is_join):
token_key = "RpcGroupManagementToken"
join_or_leave = "join" if is_join else "leave"
my_token = f"Token_for_{name}_{join_or_leave}"
while True:
# Retrieve token from store to signal start of rank join/leave critical section
returned = store.compare_set(token_key, "", my_token).decode()
if returned == my_token:
# Yield to the function this context manager wraps
yield
# Finished, now exit and release token
# Update the store to signal the end of the rank join/leave critical section
store.set(token_key, "")
# Others will wait for this token to be set before they execute
store.set(my_token, "Done")
break
else:
# Wait on the store until the current token holder finishes
try:
store.wait([returned])
except RuntimeError:
logger.error(
"Group membership token %s timed out waiting for %s to be released.",
my_token,
returned,
)
raise
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
ret = agent._update_group_membership(
worker_info, my_devices, reverse_device_map, is_join
)
return ret
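The context manager above implements a simple token-based critical section on top of the c10d store. A standalone sketch of the same compare-and-set handshake, assuming a local ``TCPStore`` purely for illustration (host, port, and token names are made up):

import torch.distributed as dist

store = dist.TCPStore("localhost", 29501, world_size=1, is_master=True)
token_key = "RpcGroupManagementToken"
my_token = "Token_for_workerA_join"

# Atomically claim the token if it is currently unset (empty value).
returned = store.compare_set(token_key, "", my_token).decode()
if returned == my_token:
    # ... critical section: join/leave bookkeeping happens here ...
    store.set(token_key, "")     # release the token for the next worker
    store.set(my_token, "Done")  # unblock anyone waiting on this holder
else:
    # Another worker holds the token; wait for it to signal completion.
    store.wait([returned])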

View File

@ -0,0 +1,965 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import collections
import contextlib
import functools
import inspect
import logging
import threading
from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar
import torch
from torch._C._distributed_rpc import (
_cleanup_python_rpc_handler,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
_get_current_rpc_agent,
_invoke_remote_builtin,
_invoke_remote_python_udf,
_invoke_remote_torchscript,
_invoke_rpc_builtin,
_invoke_rpc_python_udf,
_invoke_rpc_torchscript,
_is_current_rpc_agent_set,
_reset_current_rpc_agent,
_set_and_start_rpc_agent,
get_rpc_timeout,
PyRRef,
RemoteProfilerManager,
TensorPipeAgent,
WorkerInfo,
)
from torch.futures import Future
from ._utils import _group_membership_management, _update_group_membership
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
from .internal import (
_build_rpc_profiling_key,
_internal_rpc_pickler,
PythonUDF,
RPCExecMode,
)
__all__ = [
"shutdown",
"get_worker_info",
"remote",
"rpc_sync",
"rpc_async",
"RRef",
"AllGatherStates",
"method_factory",
"new_method",
]
logger = logging.getLogger(__name__)
# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
# make sure there are no references to any RRef in the application code and
# that Python GC has done its job to delete those RRefs. This could result in bad
# debugging experiences, especially for large applications. Therefore, by
# default, we are going to ignore RRef leaks during shutdown. This is usually
# fine as shutdown means applications have done training and no longer care
# about states.
#
# To enable RRef leak checking, set _ignore_rref_leak to False.
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler
@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
r"""
rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
"""
global _default_pickler
_default_pickler = rpc_pickler
try:
yield
finally:
_default_pickler = _internal_rpc_pickler
def _require_initialized(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not _is_current_rpc_agent_set():
raise RuntimeError(
"RPC has not been initialized. Call "
"torch.distributed.rpc.init_rpc first."
)
return func(*args, **kwargs)
return wrapper
class AllGatherStates:
def __init__(self):
# Each `gathered_objects` is an empty dict at the beginning.
# The leader worker is elected as the first worker in a sorted worker
# name list. Whenever there is a worker entering `_all_gather()`, it
# runs `_gather_to_leader()` on the leader to add its own name and
# data obj to this dict. The leader also adds its own name to the dict
# on calling `_all_gather()`.
# Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
# will broadcast the gathered dict to all follower workers and set their
# `gathered_objects` field and the `proceed_signal` field.
self.gathered_objects = {}
# Each worker waits on this signal until it receives all gathered
# objects.
self.proceed_signal = threading.Event()
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer.
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
AllGatherStates
)
def _init_rpc_states(agent):
worker_infos = agent.get_worker_infos()
global _ALL_WORKER_NAMES
_ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
# NB: backend implementation might have already set the rpc_agent.
if not _is_current_rpc_agent_set():
_set_and_start_rpc_agent(agent)
def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
with _all_gather_dict_lock:
if not worker_names:
worker_names = _ALL_WORKER_NAMES
assert (
worker_name in worker_names
), f"{worker_name} is not expected by leader."
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
worker_name not in states.gathered_objects
), f"{worker_name} reported intent sequence id {sequence_id} twice. "
states.gathered_objects[worker_name] = obj
if worker_names == set(states.gathered_objects.keys()):
states.proceed_signal.set()
def _broadcast_to_followers(sequence_id, objects_map):
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
not states.proceed_signal.is_set()
), f"Termination signal sequence id {sequence_id} got set twice."
states.gathered_objects = objects_map
states.proceed_signal.set()
_thread_local_var = threading.local()
@contextlib.contextmanager
def _wait_all():
r"""
A context manager that collects all futures returned by ``rpc_async`` and
waits on them at the context manager's exit, relieving the user of the need
to call wait explicitly.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> with rpc._wait_all():
>>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
>>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
>>> #fut_1 and fut_2 are waited on
"""
_thread_local_var.future_list = []
try:
yield
finally:
try:
torch.futures.wait_all(_thread_local_var.future_list)
finally:
del _thread_local_var.future_list
@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
r"""
This is similar to torch.distributed.all_gather(), but uses RPC. It
picks the worker with the smallest name (in alphabetical order) as the leader.
All followers then send their data ``obj`` to the leader. After the leader
has received everything, it broadcasts the results back to all followers. This
function blocks until all workers have received the gathered results.
"""
if not worker_names:
assert (
_ALL_WORKER_NAMES is not None
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
worker_names = _ALL_WORKER_NAMES
leader_name = min(worker_names)
self_name = _get_current_rpc_agent().get_worker_info().name
with _all_gather_dict_lock:
concat_names = "".join(sorted(worker_names))
sequence_num = _all_gather_sequence_id.get(concat_names, 0)
_all_gather_sequence_id[concat_names] = sequence_num + 1
sequence_id = concat_names + str(sequence_num)
is_leader = leader_name == self_name
if timeout == UNSET_RPC_TIMEOUT:
# Timeout is specified by agent for RPC calls
rpc_timeout = get_rpc_timeout()
# No timeout for signal
signal_timeout = None
elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
# No timeout for RPC
rpc_timeout = timeout
# No timeout for signal
signal_timeout = None
else:
# Signal and RPC timeout use the same timeout
signal_timeout = rpc_timeout = timeout
# Phase 1: Followers send their objects to the leader
if is_leader:
_gather_to_leader(sequence_id, self_name, obj, worker_names)
else:
rpc_sync(
leader_name,
_gather_to_leader,
args=(sequence_id, self_name, obj, worker_names),
timeout=rpc_timeout,
)
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id]
# Timeout is either set by function parameter or None (which is indefinite)
states.proceed_signal.wait(timeout=signal_timeout)
# Phase 2: Leader broadcasts gathered results to all followers
# Leader's signal is the first to be unblocked, after receiving all
# followers' data objects.
if is_leader:
worker_name_to_response_future_dict = {}
for follower_name in worker_names - {leader_name}:
fut = rpc_async(
follower_name,
_broadcast_to_followers,
args=(sequence_id, states.gathered_objects),
timeout=rpc_timeout,
)
worker_name_to_response_future_dict[follower_name] = fut
errors = []
for follower_name, fut in worker_name_to_response_future_dict.items():
try:
fut.wait()
except RuntimeError as ex:
errors.append((follower_name, ex))
if errors:
raise RuntimeError(
f"Followers {[e[0] for e in errors]} timed out in _all_gather "
f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
)
# Clean up for the states using the sequence_id
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states.pop(sequence_id)
return states.gathered_objects
@_require_initialized
def _barrier(worker_names):
r"""
Synchronizes local and remote RPC processes.
This blocks until all local and remote RPC processes specified in ``worker_names``
reach this method and all outstanding work completes.
Args:
worker_names (List[str]): The set of workers to synchronize.
"""
try:
_all_gather(None, set(worker_names))
except RuntimeError as ex:
logger.error("Failed to complete barrier, got error %s", ex)
@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
r"""
Block until all local and remote RPC processes reach this method and wait
for all outstanding work to complete. Every RPC process must call this
method before exit to perform a graceful shutdown. This should be used to
terminate the RPC framework, and there is no guarantee that the RPC
framework will work after this method returns.
"""
try:
_all_gather(None, timeout=timeout)
except RuntimeError as ex:
logger.error(
"Failed to respond to 'Shutdown Proceed' in time, got error %s", ex
)
raise ex
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
r"""
Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
stops the local agent from accepting outstanding requests, and shuts
down the RPC framework by terminating all RPC threads. If ``graceful=True``,
this will block until all local and remote RPC processes reach this method
and wait for all outstanding work to complete. Otherwise, if
``graceful=False``, this is a local shutdown, and it does not wait for other
RPC processes to reach this method.
.. warning::
For :class:`~torch.futures.Future` objects returned by
:meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
be called after ``shutdown()``.
Args:
graceful (bool): Whether to do a graceful shutdown or not. If True,
this will 1) wait until there are no pending system
messages for ``UserRRefs`` and delete them; 2) block
until all local and remote RPC processes have reached
this method and wait for all outstanding work to
complete.
Example::
Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
on both workers. Refer to :meth:`~torch.distributed.init_process_group`
API for more details. For example,
export MASTER_ADDR=localhost
export MASTER_PORT=5678
Then run the following code in two different processes:
>>> # xdoctest: +SKIP
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()
"""
if graceful:
try:
agent = _get_current_rpc_agent()
if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
_wait_all_workers(timeout)
_delete_all_user_and_unforked_owner_rrefs()
agent.join(shutdown=True, timeout=timeout)
else:
# This is a dynamic group so we need to grab the token for the operation
my_worker_info = agent.get_worker_info()
my_name = my_worker_info.name
with _group_membership_management(agent.store, my_name, False):
all_worker_infos = agent.get_worker_infos()
for worker in all_worker_infos:
if worker.name != my_name:
rpc_sync(
worker.name,
_update_group_membership,
args=(my_worker_info, [], {}, False),
)
agent.join(shutdown=True, timeout=timeout)
finally:
# In case of errors, continue to complete the local shutdown.
_finalize_shutdown()
else:
_finalize_shutdown()
def _finalize_shutdown():
try:
# This raises a `TORCH_CHECK()` exception on RRef leak detected.
_destroy_rref_context(_ignore_rref_leak)
finally:
_get_current_rpc_agent().shutdown()
# Clean up the Python RPC handler in shutdown(); see comments in
# PythonRpcHandler::cleanup(). Call it from the Python API because the
# cleanup() function has a Python dependency: it assumes the Python
# interpreter exists.
# Whether or not an RRef leak exception is raised, this clean-up code
# must run to avoid a destruction segfault in Python 3.5.
#
# future.wait() should not be called after shutdown().
# pythonRpcHandler is cleaned up in shutdown(); after
# shutdown(), Python objects returned from RPC Python calls can no longer be
# resolved.
_cleanup_python_rpc_handler()
_reset_current_rpc_agent()
@_require_initialized
def get_worker_info(worker_name=None):
r"""
Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
expensive string on every invocation.
Args:
worker_name (str): the string name of a worker. If ``None``, return the
id of the current worker. (default ``None``)
Returns:
:class:`~torch.distributed.rpc.WorkerInfo` instance for the given
``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
current worker if ``worker_name`` is ``None``.
"""
if worker_name is not None:
return _get_current_rpc_agent().get_worker_info(worker_name)
else:
return _get_current_rpc_agent().get_worker_info()
def _to_worker_info(to):
if isinstance(to, WorkerInfo):
return to
elif isinstance(to, (str, int)):
return get_worker_info(to)
else:
raise ValueError(f"Cannot get WorkerInfo from name {to}")
def _rref_typeof_on_owner(rref, blocking: bool = True):
rref_type = type(rref.local_value())
if blocking:
return rref_type
else:
# Wrap the result in a completed Future. This is so that if ``blocking=False``
# is specified, we return a future regardless of whether this call is on the
# user or the owner.
future = Future[type]()
future.set_result(rref_type)
return future
def _rref_typeof_on_user(
rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
):
fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
if blocking:
return fut.wait()
else:
return fut
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
if TYPE_CHECKING:
class RRef(PyRRef[T], Generic[T]):
pass
else:
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, Generic[T]):
pass
except TypeError:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
# Mypy doesn't understand __class__ (mypy bug #4177)
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type]
pass
# Combine the implementation class and the type class.
# Types for classes expecting a certain generic parameter (mypy bug #7791)
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type]
pass
# Install docstrings from `PyRRef` to `RRef`.
#
# This is needed because pybind11 generates the parameter
# `self` as type `rpc.PyRRef`, so `:inherited-members:`
# under `.. autoclass:: RRef` does not work.
# We therefore do the following surgery to replace `rpc.PyRRef` with `rpc.RRef`.
#
def method_factory(method_name, docstring):
def method(self, *args, **kwargs):
return getattr(super(RRef, self), method_name)(*args, **kwargs)
if method.__doc__:
method.__doc__ = docstring
return method
for method_name, method in inspect.getmembers(PyRRef):
# Ignore magic methods, except "__str__".
if method_name.startswith("_") and method_name != "__str__":
continue
# Get pybind11 generated docstring.
# It's like,
"""
to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
Blocking call that copies the value of the RRef from the owner
to the local node and returns it. If the current node is the
owner, returns a reference to the local value.
"""
docstring = getattr(method, "__doc__", None)
assert docstring is not None, "RRef user-facing methods should all have docstrings."
# Do surgery on pybind11 generated docstrings.
docstring = docstring.replace(
"torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
)
# Attach user-facing RRef method with modified docstring.
new_method = method_factory(method_name, docstring)
setattr(RRef, method_name, new_method)
@_require_initialized
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
r"""
Make a remote call to run ``func`` on worker ``to`` and return an
:class:`~torch.distributed.rpc.RRef` to the result value immediately.
Worker ``to`` will be the owner of the returned
:class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
a user. The owner manages the global reference count of its
:class:`~torch.distributed.rpc.RRef`, and the owner
:class:`~torch.distributed.rpc.RRef` is only destructed when globally there
are no living references to it.
Args:
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (Callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.
args (tuple): the argument tuple for the ``func`` invocation.
kwargs (dict): is a dictionary of keyword arguments for the ``func``
invocation.
timeout (float, optional): timeout in seconds for this remote call. If the
creation of this
:class:`~torch.distributed.rpc.RRef` on worker
``to`` is not successfully processed on this
worker within this timeout, then the next time
there is an attempt to use the RRef (such as
``to_here()``), a timeout will be raised
indicating this failure. A value of 0 indicates
an infinite timeout, i.e. a timeout error will
never be raised. If not provided, the default
value set during initialization or with
``_set_rpc_timeout`` is used.
Returns:
A user :class:`~torch.distributed.rpc.RRef` instance to the result
value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
to retrieve the result value locally.
.. warning ::
The ``remote`` API does not copy storages of argument tensors until
sending them over the wire, which could be done by a different thread
depending on the RPC backend type. The caller should make sure that the
contents of those tensors stay intact until the returned RRef is
confirmed by the owner, which can be checked using the
:meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.
.. warning ::
Errors such as timeouts for the ``remote`` API are handled on a
best-effort basis. When remote calls initiated by ``remote`` fail, such as
with a timeout error, errors are handled and set on the resulting RRef
asynchronously. If the RRef has not been
used by the application before this handling (such as ``to_here`` or
fork call), then future uses of the ``RRef`` will appropriately raise
errors. However, it is possible that the user application will use the
``RRef`` before the errors are handled. In this case, errors may not be
raised as they have not yet been handled.
Example::
Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
on both workers. Refer to :meth:`~torch.distributed.init_process_group`
API for more details. For example,
export MASTER_ADDR=localhost
export MASTER_PORT=5678
Then run the following code in two different processes:
>>> # xdoctest: +SKIP
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Below is an example of running a TorchScript function using RPC.
>>> # On both workers:
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>> return torch.add(tensor, scalar)
>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rref.to_here()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
torch._C._log_api_usage_once("torch.distributed.rpc_remote")
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(
should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
)
with ctx_manager as rf:
args = args if args else ()
kwargs = kwargs if kwargs else {}
is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
if is_async_exec:
wrapped = func._wrapped_async_rpc_function
if isinstance(wrapped, torch.jit.ScriptFunction):
func = wrapped
if qualified_name is not None:
rref = _invoke_remote_builtin(
dst_worker_info, qualified_name, timeout, *args, **kwargs
)
elif isinstance(func, torch.jit.ScriptFunction):
rref = _invoke_remote_torchscript(
dst_worker_info.name,
torch._jit_internal._qualified_name(func),
timeout,
is_async_exec,
*args,
**kwargs,
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
rref = _invoke_remote_python_udf(
dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
)
# attach profiling information
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None
fut = rf._call_end_callbacks_on_future(rref._get_future())
rref._set_profiling_future(fut)
return rref
def _invoke_rpc(
to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
):
if not callable(func):
raise TypeError("function should be callable.")
qualified_name = torch.jit._builtins._find_builtin(func)
dst_worker_info = _to_worker_info(to)
should_profile = _get_should_profile()
ctx_manager = _enable_rpc_profiler(
should_profile, qualified_name, func, rpc_type, dst_worker_info
)
with ctx_manager as rf:
args = args if args else ()
kwargs = kwargs if kwargs else {}
is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
if is_async_exec:
wrapped = func._wrapped_async_rpc_function
if isinstance(wrapped, torch.jit.ScriptFunction):
func = wrapped
if qualified_name is not None:
fut = _invoke_rpc_builtin(
dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
)
elif isinstance(func, torch.jit.ScriptFunction):
fut = _invoke_rpc_torchscript(
dst_worker_info.name,
torch._jit_internal._qualified_name(func),
args,
kwargs,
rpc_timeout,
is_async_exec,
)
else:
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf(
dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
)
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None
# Schedule profiling callbacks to run when the future completes.
# This returns a future that is completed when the original future
# completes and the profiling callbacks have been completed as well,
# to guarantee that fut.wait() completes the profiling. This new
# future will contain the same value as the original future.
fut = rf._call_end_callbacks_on_future(fut)
return fut
@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
r"""
Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
messages are sent and received in parallel to execution of Python code. This
method is thread-safe.
Args:
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (Callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.
args (tuple): the argument tuple for the ``func`` invocation.
kwargs (dict): is a dictionary of keyword arguments for the ``func``
invocation.
timeout (float, optional): timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value of 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. If not provided,
the default value set during initialization
or with ``_set_rpc_timeout`` is used.
Returns:
Returns the result of running ``func`` with ``args`` and ``kwargs``.
Example::
Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
on both workers. Refer to :meth:`~torch.distributed.init_process_group`
API for more details. For example,
export MASTER_ADDR=localhost
export MASTER_PORT=5678
Then run the following code in two different processes:
>>> # xdoctest: +SKIP
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Below is an example of running a TorchScript function using RPC.
>>> # On both workers:
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>> return torch.add(tensor, scalar)
>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
torch._C._log_api_usage_once("torch.distributed.rpc_sync")
fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
return fut.wait()
@_require_initialized
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
r"""
Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
messages are sent and received in parallel to execution of Python code. This
method is thread-safe. This method will immediately return a
:class:`~torch.futures.Future` that can be awaited on.
Args:
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (Callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.
args (tuple): the argument tuple for the ``func`` invocation.
kwargs (dict): is a dictionary of keyword arguments for the ``func``
invocation.
timeout (float, optional): timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value of 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. If not provided,
the default value set during initialization
or with ``_set_rpc_timeout`` is used.
Returns:
Returns a :class:`~torch.futures.Future` object that can be waited
on. When completed, the return value of ``func`` on ``args`` and
``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
object.
.. warning ::
Using GPU tensors as arguments or return values of ``func`` is not
supported since we don't support sending GPU tensors over the wire. You
need to explicitly copy GPU tensors to CPU before using them as
arguments or return values of ``func``.
.. warning ::
The ``rpc_async`` API does not copy storages of argument tensors until
sending them over the wire, which could be done by a different thread
depending on the RPC backend type. The caller should make sure that the
contents of those tensors stay intact until the returned
:class:`~torch.futures.Future` completes.
Example::
Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
on both workers. Refer to :meth:`~torch.distributed.init_process_group`
API for more details. For example,
export MASTER_ADDR=localhost
export MASTER_PORT=5678
Then run the following code in two different processes:
>>> # xdoctest: +SKIP
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
Below is an example of running a TorchScript function using RPC.
>>> # On both workers:
>>> @torch.jit.script
>>> def my_script_add(tensor: torch.Tensor, scalar: int):
>>> return torch.add(tensor, scalar)
>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
>>> ret = fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
torch._C._log_api_usage_once("torch.distributed.rpc_async")
fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
if hasattr(_thread_local_var, "future_list"):
_thread_local_var.future_list.append(fut)
return fut
def _get_should_profile():
# The legacy profiler should be enabled. RPC profiling is not supported with
# the Kineto profiler.
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return (
torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type()
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)
def _enable_rpc_profiler(
should_profile, qualified_name, func, rpc_type, dst_worker_info
):
ctx_manager = contextlib.nullcontext()
if should_profile:
# Create appropriate string representation based on type of func
# (builtin, script, python)
if qualified_name is None:
func_name = (
torch._jit_internal._qualified_name(func)
if isinstance(func, torch.jit.ScriptFunction)
else func.__qualname__
)
else:
func_name = qualified_name
# Build RPC profiling key.
rpc_profiling_key = _build_rpc_profiling_key(
rpc_type,
func_name,
get_worker_info().name,
dst_worker_info.name,
)
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
# Mypy doesn't support re-def of a variable not in the same block (#1174)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
return ctx_manager
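For reference, a condensed sketch (not part of this module) of the two-phase control flow that ``_all_gather`` above implements, reusing the private helpers defined in this file and omitting timeouts, locking, and error handling:

def _all_gather_sketch(obj, worker_names, self_name, sequence_id):
    # Phase 1: every worker pushes its object to the leader (smallest name).
    leader_name = min(worker_names)
    if self_name == leader_name:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(leader_name, _gather_to_leader,
                 args=(sequence_id, self_name, obj, worker_names))
    # Block until the broadcast arrives (or, on the leader, until all
    # followers have reported in).
    states = _all_gather_sequence_id_to_states[sequence_id]
    states.proceed_signal.wait()
    # Phase 2: the leader broadcasts the gathered dict back to all followers.
    if self_name == leader_name:
        for follower_name in worker_names - {leader_name}:
            rpc_async(follower_name, _broadcast_to_followers,
                      args=(sequence_id, states.gathered_objects)).wait()
    return states.gathered_objects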

View File

@ -0,0 +1,432 @@
# mypy: allow-untyped-defs
import collections
import enum
from typing import cast, Dict, List, Set, Tuple
import torch
import torch.distributed as dist
from . import api, constants as rpc_constants
from ._utils import _group_membership_management, _update_group_membership
__all__ = [
"backend_registered",
"register_backend",
"construct_rpc_backend_options",
"init_backend",
"BackendValue",
"BackendType",
]
BackendValue = collections.namedtuple(
"BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)
def _backend_type_repr(self):
return "BackendType." + self.name
_backend_type_doc = """
An enum class of available backends.
PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
Additional ones can be registered using the
:func:`~torch.distributed.rpc.backend_registry.register_backend` function.
"""
# Create an enum type, `BackendType`, with empty members.
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc]
# Unable to assign a function a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
if BackendType.__doc__:
BackendType.__doc__ = _backend_type_doc
def backend_registered(backend_name):
"""
Checks if backend_name is registered as an RPC backend.
Args:
backend_name (str): string to identify the RPC backend.
Returns:
True if the backend has been registered with ``register_backend``, else
False.
"""
return backend_name in BackendType.__members__.keys()
def register_backend(
backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
"""Registers a new RPC backend.
Args:
backend_name (str): backend string to identify the handler.
construct_rpc_backend_options_handler (function):
Handler that is invoked when
rpc_backend.construct_rpc_backend_options(**dict) is called.
init_backend_handler (function): Handler that is invoked when the
`_init_rpc_backend()` function is called with a backend.
This returns the agent.
"""
global BackendType
if backend_registered(backend_name):
raise RuntimeError(f"RPC backend {backend_name}: already registered")
# Create a new enum type, `BackendType`, with extended members.
existing_enum_dict = {member.name: member.value for member in BackendType}
extended_enum_dict = dict(
{
backend_name: BackendValue(
construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
init_backend_handler=init_backend_handler,
)
},
**existing_enum_dict,
)
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc]
# Unable to assign a function a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
if BackendType.__doc__:
BackendType.__doc__ = _backend_type_doc
return BackendType[backend_name]
def construct_rpc_backend_options(
backend,
rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
init_method=rpc_constants.DEFAULT_INIT_METHOD,
**kwargs,
):
return backend.value.construct_rpc_backend_options_handler(
rpc_timeout, init_method, **kwargs
)
def init_backend(backend, *args, **kwargs):
return backend.value.init_backend_handler(*args, **kwargs)
def _init_process_group(store, rank, world_size):
# Initialize ProcessGroup.
process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT
# We're using a bunch of private APIs here since `new_group` requires the
# default group to be initialized.
group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)
assert group is not None, "Failed to initialize default ProcessGroup."
if (rank != -1) and (rank != group.rank()):
raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}")
if (world_size != -1) and (world_size != group.size()):
raise RuntimeError(
f"world_size argument {world_size} doesn't match pg size {group.size()}"
)
return group
def _tensorpipe_construct_rpc_backend_options_handler(
rpc_timeout,
init_method,
num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
_transports=None,
_channels=None,
**kwargs,
):
from . import TensorPipeRpcBackendOptions
return TensorPipeRpcBackendOptions(
rpc_timeout=rpc_timeout,
init_method=init_method,
num_worker_threads=num_worker_threads,
_transports=_transports,
_channels=_channels,
)
def _tensorpipe_validate_devices(devices, device_count):
return all(
d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
for d in devices
)
# detect if any worker has invalid device_map configurations, and return
# reverse device maps
def _tensorpipe_exchange_and_check_all_device_maps(
my_name, my_device_count, my_device_maps, my_devices, group
):
gathered: List[
Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]]
] = [("", 0, {}, []) for _ in range(group.size())]
dist.all_gather_object(
gathered, (my_name, my_device_count, my_device_maps, my_devices), group
)
all_names = [name for name, _, _, _ in gathered]
all_device_counts = {name: count for name, count, _, _ in gathered}
all_device_maps = {name: map_ for name, _, map_, _ in gathered}
all_devices = {name: devices for name, _, _, devices in gathered}
_validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)
# Passed all checks; construct reverse mapping and get list of devices handled by this agent
reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
return reverse_device_maps, my_devices
def _validate_device_maps(
all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
for node in all_names:
devices = all_devices[node]
if len(set(devices)) != len(devices):
raise ValueError(
f"Node {node} has duplicated devices\n" f"devices = {devices}"
)
if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
raise ValueError(
f"Node {node} has devices with invalid indices\n"
f"devices = {devices}\n"
f"device count = {all_device_counts[node]}"
)
for source_node in all_names:
# For a dynamic (non-static) group, do not check the target node names since they may not have joined yet
if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
all_names
):
raise ValueError(
f"Node {source_node} has invalid target node names in its device maps\n"
f"device maps = {all_device_maps[source_node].keys()}\n"
f"node names = {all_names}"
)
for target_node, map_ in all_device_maps[source_node].items():
if len(set(map_.values())) != len(map_):
raise ValueError(
f"Node {source_node} has duplicated target devices "
f"in its device map for {target_node}\n"
f"device map = {map_}"
)
if all_devices[source_node]:
if not set(map_.keys()).issubset(all_devices[source_node]):
raise ValueError(
f"Node {source_node} has unexpected source devices "
f"in its device map for {target_node}\n"
f"device map = {map_}\n"
f"devices = {all_devices[source_node]}"
)
elif not _tensorpipe_validate_devices(
map_.keys(), all_device_counts[source_node]
):
raise ValueError(
f"Node {source_node} has source devices with invalid indices "
f"in its device map for {target_node}\n"
f"device map = {map_}\n"
f"device count = {all_device_counts[source_node]}"
)
if all_devices.get(target_node, []):
if not set(map_.values()).issubset(all_devices[target_node]):
raise ValueError(
f"Node {source_node} has unexpected target devices "
f"in its device map for {target_node}\n"
f"device map = {map_}\n"
f"devices = {all_devices[target_node]}"
)
elif target_node in all_device_counts and not _tensorpipe_validate_devices(
map_.values(), all_device_counts[target_node]
):
raise ValueError(
f"Node {source_node} has target devices with invalid indices "
f"in its device map for {target_node}\n"
f"device map = {map_}\n"
f"device count = {all_device_counts[target_node]}"
)
def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
if not my_devices:
devices_set: Set[torch.device] = set()
for map_ in my_device_maps.values():
devices_set.update(map_.keys())
for map_ in reverse_device_maps.values():
devices_set.update(map_.keys())
devices_set.discard(torch.device("cpu"))
my_devices = list(devices_set)
my_devices = sorted(my_devices, key=lambda d: d.index)
return my_devices
def _create_reverse_mapping(my_name, all_names, all_device_maps):
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
for node in all_names:
if my_name in all_device_maps[node]:
reverse_device_maps[node] = {
v: k for k, v in all_device_maps[node][my_name].items()
}
return reverse_device_maps
def _get_device_infos():
from . import TensorPipeAgent
agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
opts = agent._get_backend_options()
device_count = torch.cuda.device_count()
if torch.cuda.is_available() and opts.devices:
torch.cuda.init()
return device_count, opts.device_maps, opts.devices
def _set_devices_and_reverse_device_map(agent):
from . import TensorPipeAgent
agent = cast(TensorPipeAgent, agent)
# Group state is retrieved from local agent
# On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid
my_worker_info = agent.get_worker_info()
my_name = my_worker_info.name
all_worker_infos = agent.get_worker_infos()
# One round to get device_maps of all workers and construct reverse device maps
all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
for worker_info in all_worker_infos:
worker_name = worker_info.name
if worker_name != my_name:
# TODO: make async?
device_count, device_map, devices = api.rpc_sync(
worker_name, _get_device_infos
)
else:
opts = agent._get_backend_options()
device_count, device_map, devices = (
torch.cuda.device_count(),
opts.device_maps,
opts.devices,
)
all_device_counts[worker_name] = device_count
all_device_maps[worker_name] = device_map
all_devices[worker_name] = devices
all_names.append(worker_name)
_validate_device_maps(
all_names,
all_device_counts,
all_device_maps,
all_devices,
is_static_group=False,
)
reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
# Perform RPC call to all workers, including itself, to include newly joined worker information and device maps
for worker_name in all_names:
# Set device list for each worker
all_devices[worker_name] = _create_device_list(
all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
)
api.rpc_sync(
worker_name,
_update_group_membership,
args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
)
def _tensorpipe_init_backend_handler(
store, name, rank, world_size, rpc_backend_options
):
from . import TensorPipeAgent, TensorPipeRpcBackendOptions
if not isinstance(store, dist.Store):
raise TypeError(f"`store` must be a c10d::Store. {store}")
if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
raise TypeError(
f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
)
device_count = torch.cuda.device_count()
is_static_group = bool(world_size)
# world_size is specified so this is a static group (ranks cannot join and leave)
if is_static_group:
# The agent's join method is required to behave like a barrier and perform
# collective operations, for which it relies on a process group, instead of
# re-implementing this on top of RPCs.
group = _init_process_group(store, rank, world_size)
reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
name,
device_count,
rpc_backend_options.device_maps,
rpc_backend_options.devices,
group,
)
if torch.cuda.is_available() and devices:
# It's necessary to initialize PyTorch CUDA states here (e.g.,
# CUDACachingAllocator). If this is missing, we could hit errors like
# "allocator not initialized", because other processes might send
# CUDA-related RPC request to this process before user code in this
# process initializes its PyTorch CUDA states.
torch.cuda.init()
# TODO: add try-except and destroy _agent in all processes if any fails.
agent = TensorPipeAgent(
store,
name,
rank,
world_size,
rpc_backend_options,
reverse_device_maps,
devices,
)
api._init_rpc_states(agent)
# Run one dummy round of RPC to initialize channels/transports. Without
# this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
# on that process before rpc.shutdown(), as the agent initialization can
# take longer than 5s.
api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
# Need a barrier here to make sure no peers leave before the rank0 finishes
# _all_gather
group.barrier().wait()
return agent
# initialization for dynamic rpc (ranks can join and leave)
else:
with _group_membership_management(store, name, True):
# Construct TPAgent with empty reverse_device_map and devices
# these properties will be updated after initialization
agent = TensorPipeAgent(
store,
name,
rank,
world_size,
rpc_backend_options,
{},
[],
)
api._init_rpc_states(agent)
try:
# Notify all workers in the group that this rank has joined, and set devices and reverse_device_map
# This is a synchronous operation that completes once all existing ranks are updated
_set_devices_and_reverse_device_map(agent)
except Exception:
api.shutdown()
raise
return agent
register_backend(
"TENSORPIPE",
_tensorpipe_construct_rpc_backend_options_handler,
_tensorpipe_init_backend_handler,
)
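A hedged sketch of plugging in a custom backend through ``register_backend``; the backend name, handler names, and their bodies are hypothetical, mirroring the TENSORPIPE registration above:

import torch.distributed.rpc as rpc

def _my_construct_options_handler(rpc_timeout, init_method, **kwargs):
    # A real backend would return its own RpcBackendOptions subclass here.
    return rpc.TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout, init_method=init_method
    )

def _my_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
    # A real backend would construct an RpcAgent, call api._init_rpc_states(agent),
    # and return the agent, as the handlers above do.
    raise NotImplementedError("illustrative only")

if not rpc.backend_registry.backend_registered("MY_BACKEND"):
    rpc.backend_registry.register_backend(
        "MY_BACKEND",
        _my_construct_options_handler,
        _my_init_backend_handler,
    )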

View File

@ -0,0 +1,25 @@
from datetime import timedelta
from typing import List
from torch._C._distributed_rpc import (
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_WORKER_THREADS,
_DEFAULT_RPC_TIMEOUT_SEC,
_UNSET_RPC_TIMEOUT,
)
# For any RpcAgent.
DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
DEFAULT_SHUTDOWN_TIMEOUT: float = 0
# For TensorPipeAgent.
DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
# Ensure that we don't time out when there are long periods of time without
# any operations against the underlying ProcessGroup.
DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1)
# Value indicating that timeout is not set for RPC call, and the default should be used.
UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
__all__: List[str] = []
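For context, the ``UNSET_RPC_TIMEOUT`` sentinel defined here is resolved at call time. A simplified sketch of the resolution logic used in ``api._all_gather`` above (it omits the ``DEFAULT_SHUTDOWN_TIMEOUT`` special case):

from torch.distributed.rpc import get_rpc_timeout
from torch.distributed.rpc.constants import UNSET_RPC_TIMEOUT

def _resolve_timeouts(timeout: float = UNSET_RPC_TIMEOUT):
    if timeout == UNSET_RPC_TIMEOUT:
        # Fall back to the agent-level default; wait indefinitely on the signal.
        return get_rpc_timeout(), None
    # An explicit timeout applies to both the RPC and the completion signal.
    return timeout, timeout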

View File

@ -0,0 +1,169 @@
# mypy: allow-untyped-defs
import functools
def async_execution(fn):
r"""
A decorator for a function indicating that the return value of the function
is guaranteed to be a :class:`~torch.futures.Future` object and this
function can run asynchronously on the RPC callee. More specifically, the
callee extracts the :class:`~torch.futures.Future` returned by the wrapped
function and installs subsequent processing steps as a callback to that
:class:`~torch.futures.Future`. The installed callback will read the value
from the :class:`~torch.futures.Future` when completed and send the
value back as the RPC response. That also means the returned
:class:`~torch.futures.Future` only exists on the callee side and is never
sent through RPC. This decorator is useful when the wrapped function's
(``fn``) execution needs to pause and resume due to, e.g., containing
:meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
.. note:: To enable asynchronous execution, applications must pass the
function object returned by this decorator to RPC APIs. If RPC detected
attributes installed by this decorator, it knows that this function
returns a ``Future`` object and will handle that accordingly.
However, this does not mean this decorator has to be the outermost one when
defining a function. For example, when combined with ``@staticmethod``
or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
inner decorator to allow the target function to be recognized as a static
or class function. This target function can still execute asynchronously
because, when accessed, the static or class method preserves attributes
installed by ``@rpc.functions.async_execution``.
Example::
The returned :class:`~torch.futures.Future` object can come from
:meth:`~torch.distributed.rpc.rpc_async`,
:meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
constructor. The example below shows directly using the
:class:`~torch.futures.Future` returned by
:meth:`~torch.futures.Future.then`.
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @rpc.functions.async_execution
>>> def async_add_chained(to, x, y, z):
>>> # This function runs on "worker1" and returns immediately when
>>> # the callback is installed through the `then(cb)` API. In the
>>> # mean time, the `rpc_async` to "worker2" can run concurrently.
>>> # When the return value of that `rpc_async` arrives at
>>> # "worker1", "worker1" will run the lambda function accordingly
>>> # and set the value for the previously returned `Future`, which
>>> # will then trigger RPC to send the result back to "worker0".
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> # On worker0
>>> # xdoctest: +SKIP
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> async_add_chained,
>>> args=("worker2", torch.ones(2), 1, 1)
>>> )
>>> print(ret) # prints tensor([3., 3.])
When combined with TorchScript decorators, this decorator must be the
outermost one.
>>> from torch import Tensor
>>> from torch.futures import Future
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @torch.jit.script
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
>>> return x + y
>>>
>>> @rpc.functions.async_execution
>>> @torch.jit.script
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
>>> return rpc.rpc_async(to, script_add, (x, y))
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> async_add,
>>> args=("worker2", torch.ones(2), 1)
>>> )
>>> print(ret) # prints tensor([2., 2.])
When combined with a static or class method, this decorator must be the
inner one.
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> class AsyncExecutionClass:
>>>
>>> @staticmethod
>>> @rpc.functions.async_execution
>>> def static_async_add(to, x, y, z):
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> @classmethod
>>> @rpc.functions.async_execution
>>> def class_async_add(cls, to, x, y, z):
>>> ret_fut = torch.futures.Future()
>>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: ret_fut.set_result(fut.wait() + z)
>>> )
>>> return ret_fut
>>>
>>> @rpc.functions.async_execution
>>> def bound_async_add(self, to, x, y, z):
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> AsyncExecutionClass.static_async_add,
>>> args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret) # prints tensor([4., 4.])
>>>
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> AsyncExecutionClass.class_async_add,
>>> args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret) # prints tensor([4., 4.])
This decorator also works with RRef helpers, i.e.,
:meth:`torch.distributed.rpc.RRef.rpc_sync`,
:meth:`torch.distributed.rpc.RRef.rpc_async`, and
:meth:`torch.distributed.rpc.RRef.remote`.
>>> from torch.distributed import rpc
>>>
>>> # reuse the AsyncExecutionClass class above
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
>>> print(ret) # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
>>> print(ret) # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
>>> print(ret) # prints tensor([4., 4.])
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
# Can't declare and use attributes of function objects (mypy#2087)
wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
return wrapper
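# A minimal illustrative sketch (not used by this module): the decorator only marks the
# function; RPC internals later check for the marker attribute and, when present, treat
# the returned torch.futures.Future as the eventual RPC response.
def _example_async_execution_marker():
    import torch

    @async_execution
    def _returns_future(x):
        fut = torch.futures.Future()
        fut.set_result(x + 1)
        return fut

    assert hasattr(_returns_future, "_wrapped_async_rpc_function")
    # Called locally (outside RPC), the wrapper simply forwards to the original function,
    # so the caller receives the Future itself.
    return _returns_future(1).wait()  # -> 2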

View File

@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
import collections
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum
import torch
import torch.distributed as dist
from torch._C._distributed_rpc import _get_current_rpc_agent
__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]
# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects
_thread_local_tensor_tables = threading.local()
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
class RPCExecMode(Enum):
SYNC = "sync"
ASYNC = "async"
ASYNC_JIT = "async_jit"
REMOTE = "remote"
class _InternalRPCPickler:
r"""
This class provides serialize() and deserialize() interfaces to serialize
data into a "binary string + tensor table" format.
For an RPC Python UDF function and its args, non-tensor data is serialized
into a regular binary string, while tensor data is put into thread-local tensor
tables. This serialization format is consistent with builtin operators and args
using the JIT pickler. The format makes tensor handling in C++ much easier,
e.g., attaching tensors to the distributed autograd graph in C++.
"""
def __init__(self):
# Ignore type error because dispatch_table is defined in third-party package
self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
self._dispatch_table[torch.Tensor] = self._tensor_reducer
# Used for registering customized picklers.
self._class_reducer_dict = {}
def _register_reducer(self, obj_class, reducer):
# For the same class, only register the reducer once.
if obj_class not in self._class_reducer_dict:
self._class_reducer_dict[obj_class] = reducer
@classmethod
def _tensor_receiver(cls, tensor_index):
global _thread_local_tensor_tables
return _thread_local_tensor_tables.recv_tables[tensor_index]
def _tensor_reducer(self, tensor):
global _thread_local_tensor_tables
_thread_local_tensor_tables.send_tables.append(tensor)
tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
return (_InternalRPCPickler._tensor_receiver, (tensor_index,))
@classmethod
def _py_rref_receiver(cls, rref_fork_data):
return dist.rpc.PyRRef._deserialize(rref_fork_data)
def _py_rref_reducer(self, py_rref):
rref_fork_data = py_rref._serialize()
return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))
def _rref_reducer(self, rref):
return self._py_rref_reducer(rref)
@classmethod
def _script_module_receiver(cls, script_module_serialized):
"""
Given a serialized representation of a ScriptModule created with torch.jit.save,
loads and returns the ScriptModule.
"""
f = io.BytesIO(script_module_serialized)
m = torch.jit.load(f)
return m
def _script_module_reducer(self, script_module):
"""
Serializes a ScriptModule.
"""
f = io.BytesIO()
torch.jit.save(script_module, f)
return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))
def serialize(self, obj):
r"""
Serialize non-tensor data into a binary string and tensor data into a
tensor table.
"""
f = io.BytesIO()
p = _pickler(f)
p.dispatch_table = self._dispatch_table
# The rpc api can accept user picklers inheriting from _InternalRPCPickler to serialize rrefs.
# User picklers may have a different initialization function from _InternalRPCPickler,
# but all user picklers should call serialize() and use _rref_reducer to pickle rrefs
# in Python. Also, when _internal_rpc_pickler is imported into rpc/api.py, rpc.RRef is not
# compiled yet, so the _InternalRPCPickler constructor is not a good place to access rpc.RRef;
# hence the RRef dispatch table entries are set up here.
#
# The return value of a `rpc.remote(..)` call is of type `rpc.PyRRef`.
# The deserialized RRef object on the RPC receiver side is of type `rpc.PyRRef`.
# Ignore type error because dispatch_table is defined in third-party package
p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index]
# An RRef created locally by the RRef Python constructor is of type `rpc.RRef`.
# Ignore type error because dispatch_table is defined in third-party package
p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index]
# Add dispatch pickling for ScriptModule or its subclass.
if isinstance(obj, torch.jit.ScriptModule):
# Ignore type error because dispatch_table is defined in third-party package
p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index]
# Install customized picklers.
for class_name in self._class_reducer_dict.keys():
p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index]
# save _thread_local_tensor_tables.send_tables if it is in nested call
global _thread_local_tensor_tables
if hasattr(_thread_local_tensor_tables, "send_tables"):
old_send_tables = _thread_local_tensor_tables.send_tables
else:
old_send_tables = None
_thread_local_tensor_tables.send_tables = []
p.dump(obj)
# restore _thread_local_tensor_tables.send_tables if return
# from nested call, otherwise clean up the table
tensors = _thread_local_tensor_tables.send_tables
if old_send_tables is not None:
_thread_local_tensor_tables.send_tables = old_send_tables
else:
del _thread_local_tensor_tables.send_tables
return (f.getvalue(), tensors)
def deserialize(self, binary_data, tensor_table):
r"""
Deserialize binary string + tensor table to original obj
"""
# save _thread_local_tensor_tables.recv_tables if it is in nested call
global _thread_local_tensor_tables
if hasattr(_thread_local_tensor_tables, "recv_tables"):
old_recv_tables = _thread_local_tensor_tables.recv_tables
else:
old_recv_tables = None
_thread_local_tensor_tables.recv_tables = tensor_table
try:
unpickler = _unpickler(io.BytesIO(binary_data))
ret = unpickler.load()
except AttributeError as e:
# Occurs when function is not found on module/class during
# unpickling.
except_str = (
str(e)
+ """ Default RPC pickler does not serialize
function code. Ensure that UDFs are defined on both caller and
callee modules."""
)
ret = AttributeError(except_str)
# Ensure the stack trace gets preserved
ret.__cause__ = e
# restore _thread_local_tensor_tables.recv_tables if return
# from nested call, otherwise clean up the table
if old_recv_tables is not None:
_thread_local_tensor_tables.recv_tables = old_recv_tables
else:
del _thread_local_tensor_tables.recv_tables
return ret
# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()
def serialize(obj):
return _internal_rpc_pickler.serialize(obj)
def deserialize(binary_data, tensor_table):
return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
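# A minimal illustrative sketch (not used by this module): serialize() splits an object
# into a pickle byte string plus an out-of-band tensor table, and deserialize()
# reassembles it.
def _example_serialize_roundtrip():
    payload = {"weights": torch.ones(2, 2), "step": 3}
    binary_data, tensor_table = serialize(payload)
    # Tensors travel in the side table, not inside the pickle bytes.
    assert len(tensor_table) == 1
    restored = deserialize(binary_data, tensor_table)
    assert torch.equal(restored["weights"], payload["weights"])
    return restored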
def _run_function(python_udf):
r"""
This function is exclusively called from C++.
See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
Runs a Python UDF and returns its return value.
Wraps any exception in ``RemoteException`` if the function raises.
"""
try:
if isinstance(python_udf, AttributeError):
raise python_udf
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
except Exception as e:
# except str = exception info + traceback string
except_str = (
f"On {_get_current_rpc_agent().get_worker_info()}:\n"
f"{repr(e)}\n{traceback.format_exc()}"
)
print(except_str, file=sys.stderr)
result = RemoteException(except_str, type(e))
return result
def _handle_exception(result):
if isinstance(result, RemoteException):
exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
# We wrap exception re-creation here in case some exception classes
# cannot be constructed directly from a string.
exc = None
try:
exc = result.exception_type(exception_msg)
except BaseException as e:
raise RuntimeError( # noqa: B904
f"Failed to create original exception type. Error msg was {str(e)}"
f" Original exception on remote side was {exception_msg}"
) from e
if exc is not None:
raise exc
def _build_rpc_profiling_key(
exec_type, func_name, current_worker_name, dst_worker_name
):
"""
Builds the key that RPC calls are profiled with using the autograd profiler.
This will be the name of the corresponding Event recorded in the profiler.
Args:
exec_type (RPCExecMode): Type of RPC/RRef call
func_name (str): Name of function being profiled.
current_worker_name (str): Name of current worker.
dst_worker_name (str): Name of the destination worker.
Returns:
String representing profiling key
"""
profile_key = (
f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
)
return profile_key
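# A minimal illustrative sketch (not used by this module): the profiling key is a plain
# formatted string, e.g. "rpc_sync#torch.add(worker0 -> worker1)" for a synchronous call.
def _example_profiling_key():
    key = _build_rpc_profiling_key(RPCExecMode.SYNC, "torch.add", "worker0", "worker1")
    assert key == "rpc_sync#torch.add(worker0 -> worker1)"
    return key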
def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
"""
This function should be called from RPC/RRef functions to create a
RecordFunction object for profiling. This function also runs the before
callbacks that start the profiling, though the user is responsible for
running the appropriate callbacks when the function to be profiled finishes.
Args:
exec_type (RPCExecMode): Type of RPC/RRef call
func_name (str): Name of function being profiled.
current_worker_name (str): Name of current worker.
dest_worker_name (str): Name of the destination worker.
Returns:
An instance of `torch.autograd._RecordFunction`.
"""
assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})"
rf = torch.autograd._RecordFunction() # type: ignore[attr-defined]
torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined]
return rf
PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])

View File

@ -0,0 +1,175 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Union
import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_constants
DeviceType = Union[int, str, torch.device]
__all__ = ["TensorPipeRpcBackendOptions"]
def _to_device(device: DeviceType) -> torch.device:
device = torch.device(device)
if device.type != "cuda":
raise ValueError(
"`set_devices` expect a list of CUDA devices, but got "
f"device type {device.type}."
)
return device
def _to_device_map(
device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
full_device_map: Dict[torch.device, torch.device] = {}
reverse_map: Dict[torch.device, torch.device] = {}
for k, v in device_map.items():
k, v = torch.device(k), torch.device(v)
if v in reverse_map:
raise ValueError(
"`device_map` only supports 1-to-1 mapping, "
f"trying to map {k} and {reverse_map[v]} to {v}"
)
full_device_map[k] = v
reverse_map[v] = k
return full_device_map
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
return list(map(_to_device, devices))
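# A minimal illustrative sketch (not used by this module): _to_device_map normalizes
# keys and values to torch.device and enforces that the mapping is invertible (1-to-1).
def _example_to_device_map():
    ok = _to_device_map({0: 1, "cuda:2": "cuda:3"})
    assert ok[torch.device("cuda:0")] == torch.device("cuda:1")
    try:
        _to_device_map({0: 1, 2: 1})  # two source devices mapped to cuda:1
    except ValueError:
        pass  # rejected: not a 1-to-1 mapping
    return ok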
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
r"""
The backend options for
:class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
:class:`~torch.distributed.rpc.RpcBackendOptions`.
Args:
num_worker_threads (int, optional): The number of threads in the
thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests (default: 16).
rpc_timeout (float, optional): The default timeout, in seconds,
for RPC requests (default: 60 seconds). If the RPC has not
completed in this timeframe, an exception indicating so will
be raised. Callers can override this timeout for individual
RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
:meth:`~torch.distributed.rpc.rpc_async` if necessary.
init_method (str, optional): The URL to initialize the distributed
store used for rendezvous. It takes any value accepted for the
same argument of :meth:`~torch.distributed.init_process_group`
(default: ``env://``).
device_maps (Dict[str, Dict], optional): Device placement mappings from
this worker to the callee. The key is the callee worker name and the value
is a dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``)
that maps this worker's devices to the callee worker's devices.
(default: ``None``)
devices (List[int, str, or ``torch.device``], optional): all local
CUDA devices used by the RPC agent. By default, it will be initialized
to all local devices from its own ``device_maps`` and the corresponding
devices from its peers' ``device_maps``. When processing CUDA RPC
requests, the agent will properly synchronize CUDA streams for
all devices in this ``List``.
"""
def __init__(
self,
*,
num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
devices: Optional[List[DeviceType]] = None,
_transports: Optional[List] = None,
_channels: Optional[List] = None,
):
full_device_maps = (
{}
if device_maps is None
else {k: _to_device_map(v) for k, v in device_maps.items()}
)
full_device_list = [] if devices is None else _to_device_list(devices)
super().__init__(
num_worker_threads,
_transports,
_channels,
rpc_timeout,
init_method,
full_device_maps,
full_device_list,
)
def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
r"""
Set device mapping between each RPC caller and callee pair. This
function can be called multiple times to incrementally add
device placement configurations.
Args:
to (str): Callee name.
device_map (Dict of int, str, or torch.device): Device placement
mappings from this worker to the callee. This map must be
invertible.
Example:
>>> # xdoctest: +SKIP("distributed")
>>> # both workers
>>> def add(x, y):
>>> print(x) # tensor([1., 1.], device='cuda:1')
>>> return x + y, (x + y).to(2)
>>>
>>> # on worker 0
>>> options = TensorPipeRpcBackendOptions(
>>> num_worker_threads=8,
>>> device_maps={"worker1": {0: 1}}
>>> # maps worker0's cuda:0 to worker1's cuda:1
>>> )
>>> options.set_device_map("worker1", {1: 2})
>>> # maps worker0's cuda:1 to worker1's cuda:2
>>>
>>> rpc.init_rpc(
>>> "worker0",
>>> rank=0,
>>> world_size=2,
>>> backend=rpc.BackendType.TENSORPIPE,
>>> rpc_backend_options=options
>>> )
>>>
>>> x = torch.ones(2)
>>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
>>> # The first argument will be moved to cuda:1 on worker1. When
>>> # sending the return value back, it will follow the invert of
>>> # the device map, and hence will be moved back to cuda:0 and
>>> # cuda:1 on worker0
>>> print(rets[0]) # tensor([2., 2.], device='cuda:0')
>>> print(rets[1]) # tensor([2., 2.], device='cuda:1')
"""
full_device_map = _to_device_map(device_map)
curr_device_maps = super().device_maps
if to in curr_device_maps:
for k, v in full_device_map.items():
if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
raise ValueError(
"`set_device_map` only supports 1-to-1 mapping, trying"
f" to map {k} to {v} and {curr_device_maps[to][k]}"
)
super()._set_device_map(to, full_device_map)
def set_devices(self, devices: List[DeviceType]):
r"""
Set local devices used by the TensorPipe RPC agent. When processing
CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
CUDA streams for all devices in this ``List``.
Args:
devices (List of int, str, or torch.device): local devices used by
the TensorPipe RPC agent.
"""
self.devices = _to_device_list(devices)

View File

@ -0,0 +1,80 @@
# mypy: allow-untyped-defs
from functools import partial
import torch
from torch.futures import Future
from . import functions, rpc_async
from .constants import UNSET_RPC_TIMEOUT
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
def _rref_type_cont(rref_fut):
rref_type = rref_fut.value()
_invoke_func = _local_invoke
# Bypass ScriptModules when checking for async function attribute.
bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
rref_type, torch._C.ScriptModule
)
if not bypass_type:
func = getattr(rref_type, func_name)
if hasattr(func, "_wrapped_async_rpc_function"):
_invoke_func = _local_invoke_async_execution
return rpc_api(
rref.owner(),
_invoke_func,
args=(rref, func_name, args, kwargs),
timeout=timeout,
)
rref_fut = rref._get_type(timeout=timeout, blocking=False)
if rpc_api != rpc_async:
rref_fut.wait()
return _rref_type_cont(rref_fut)
else:
# A little explanation on this:
# rpc_async returns a Future pointing to the return value of `func_name`, i.e., a `Future[T]`.
# Calling _rref_type_cont from the `then` lambda causes Future wrapping; in other words, `then` returns a `Future[Future[T]]`.
# To address that, we return a Future that is completed with the result of the async call.
result: Future = Future()
def _wrap_rref_type_cont(fut):
try:
_rref_type_cont(fut).then(_complete_op)
except BaseException as ex:
result.set_exception(ex)
def _complete_op(fut):
try:
result.set_result(fut.value())
except BaseException as ex:
result.set_exception(ex)
rref_fut.then(_wrap_rref_type_cont)
return result
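# A minimal illustrative sketch (not used by this module) of the Future-flattening
# pattern applied in _invoke_rpc above: `then` wraps whatever its callback returns in
# another Future, so forwarding the result into a separately created Future avoids
# ending up with a Future[Future[T]].
def _example_flatten_future():
    outer: Future = Future()
    inner: Future = Future()
    # Forward the inner result instead of returning it from the `then` callback.
    inner.then(lambda fut: outer.set_result(fut.value()))
    inner.set_result(42)
    return outer.wait()  # -> 42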
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
self.rref = rref
self.rpc_api = rpc_api
self.rpc_timeout = timeout
def __getattr__(self, func_name):
return partial(
_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
)

View File

@ -0,0 +1,183 @@
#!/usr/bin/python3
# mypy: allow-untyped-defs
import itertools
from typing import List
import torch
from torch.autograd.profiler_legacy import profile
from . import (
_disable_server_process_global_profiler,
_enable_server_process_global_profiler,
)
__all__: List[str] = []
class _server_process_global_profile(profile):
"""
It has the same API as the ``torch.autograd.profiler.profile`` class,
except that it enables profiling on all threads running RPC server request callbacks.
Context manager that manages autograd profiler state and holds a summary of results.
Under the hood it just records events of functions being executed in C++ and
exposes those events to Python. You can wrap any code into it and it will
only report runtime of PyTorch functions.
Note: the profiler is thread-local and is automatically propagated into async tasks.
Args:
enabled (bool, optional): Setting this to False makes this context manager a no-op.
Default: ``True``.
use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.
Default: ``False``
record_shapes (bool, optional): If shapes recording is set, information
about input dimensions will be collected. This allows one to see which
dimensions have been used under the hood and further group by them
using prof.key_averages(group_by_input_shape=True). Please note that
shape recording might skew your profiling data. It is recommended to
use separate runs with and without shape recording to validate the timing.
Most likely the skew will be negligible for bottom most events (in a case
of nested function calls). But for higher level functions the total
self cpu time might be artificially increased because of the shape
collection.
profile_memory (bool, optional): Whether to report memory usage, default: ``False``
.. warning::
Enabling memory profiling incurs additional profiler overhead
.. warning::
Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
one cannot use the profiler with ``use_cuda = True`` to benchmark
DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
please use ``use_cuda = False`` or ``num_workers = 0``.
Example:
>>> # xdoctest: +SKIP
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> x, y = torch.tensor(1), torch.tensor(2)
>>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> outer_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
>>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
>>> inner_profile_rref.rpc_sync().__enter__()
>>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
>>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
>>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
>>> print(inner_profile_rref.rpc_sync().key_averages())
--------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
--------- --------------- --------------- --------------- --------------- --------------- ---------------
sub 85.06% 76.275us 100.00% 89.667us 89.667us 1
empty 14.94% 13.392us 14.94% 13.392us 13.392us 1
--------- --------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 89.667us
>>> print(outer_profile_rref.rpc_sync().key_averages())
--------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
--------- --------------- --------------- --------------- --------------- --------------- ---------------
sub 35.65% 76.275us 41.91% 89.667us 89.667us 1
empty 12.67% 27.101us 12.67% 27.101us 13.551us 2
add 51.68% 110.550us 58.09% 124.259us 124.259us 1
--------- --------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 213.926us
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __enter__(self):
"""
Turn on server-side process-global profiling.
This enables the thread-local profiler on all RPC threads running server-side request callbacks.
"""
if not self.enabled:
return
if self.entered: # type: ignore[has-type]
raise RuntimeError("autograd profiler traces are not reentrant")
self.entered = True
profiler_kind = (
torch.autograd.ProfilerState.CUDA
if self.use_cuda
else torch.autograd.ProfilerState.CPU
)
profiler_config = torch.autograd.ProfilerConfig(
profiler_kind,
self.record_shapes,
self.profile_memory,
False,
False,
False,
torch.profiler._ExperimentalConfig(),
)
_enable_server_process_global_profiler(profiler_config)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Turn off server-side process-global profiling.
Aggregate all profiling events recorded by RPC threads.
These attributes are assigned on exiting context.
Attributes:
function_events (torch.autograd.profiler.EventList). It's a list with helper
methods, e.g., 1) showing record items in a pretty-print table,
2) doing averaging by grouping on keys, 3) and more.
process_global_function_events (List[torch.autograd.profiler.FunctionEvent]).
It's a list of ``FunctionEvent`` elements. Every element is the profiling result
of one RPC request handled within the profiling range.
"""
if not self.enabled:
return
process_global_events = _disable_server_process_global_profiler()
# Every element in this list is a thread profiling result from an RPC request handling.
process_global_function_events = []
for thread_local_events in process_global_events:
# Parse from ``Event``s to ``FunctionEvent``s.
thread_local_function_events = (
torch.autograd.profiler_legacy._parse_legacy_records(
thread_local_events
)
)
thread_local_function_events.sort(
key=lambda function_event: [
function_event.time_range.start,
-(function_event.time_range.end),
]
)
process_global_function_events.append(thread_local_function_events)
flattened_function_events = list(
itertools.chain.from_iterable(process_global_function_events)
)
self.function_events = torch.autograd.profiler_util.EventList(
flattened_function_events,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
)
self.function_events._build_tree()
self.process_global_function_events = process_global_function_events
return False
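# A minimal illustrative sketch (not used by this module): the sort key in __exit__
# orders events by ascending start time and, on ties, by descending end time, so an
# enclosing event precedes the events nested inside it. Shown on plain stand-in tuples.
def _example_event_ordering():
    from collections import namedtuple

    TimeRange = namedtuple("TimeRange", ["start", "end"])
    Event = namedtuple("Event", ["name", "time_range"])
    events = [
        Event("inner", TimeRange(0, 5)),
        Event("outer", TimeRange(0, 10)),
        Event("later", TimeRange(7, 8)),
    ]
    events.sort(key=lambda e: [e.time_range.start, -e.time_range.end])
    return [e.name for e in events]  # -> ["outer", "inner", "later"]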