I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.
It registers custom reducers that use shared memory to provide shared
views on the same data in different processes. Once the tensor/storage is moved
to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
to send it to other processes without making any copies.
The API is 100% compatible with the original module - it's enough to change
``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
tensors sent through the queues or shared via other mechanisms, moved to shared
memory.
Because of the similarity of the APIs, we do not document most of this package's
contents, and we recommend referring to the very good docs of the original module.
"""
import multiprocessing
import sys
import torch
from .reductions import init_reductions
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
from multiprocessing import * # noqa: F403
__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined]
# This call adds a Linux specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()
"""Add helper function to spawn N processes and wait for completion of any of
them. This depends on `mp.get_context`, which was added in Python 3.4."""
from .spawn import (
ENV_VAR_PARALLEL_START,
ProcessContext,
ProcessExitedException,
ProcessRaisedException,
spawn,
SpawnContext,
start_processes,
)
if sys.platform == "darwin" or sys.platform == "win32":
_sharing_strategy = "file_system"
_all_sharing_strategies = {"file_system"}
else:
_sharing_strategy = "file_descriptor"
_all_sharing_strategies = {"file_descriptor", "file_system"}
def set_sharing_strategy(new_strategy):
"""Set the strategy for sharing CPU tensors.
Args:
new_strategy (str): Name of the selected strategy. Should be one of
the values returned by :func:`get_all_sharing_strategies()`.
"""
global _sharing_strategy
assert new_strategy in _all_sharing_strategies
_sharing_strategy = new_strategy
def get_sharing_strategy():
"""Return the current strategy for sharing CPU tensors."""
return _sharing_strategy
def get_all_sharing_strategies():
"""Return a set of sharing strategies supported on a current system."""
return _all_sharing_strategies
def _set_thread_name(name: str) -> None:
"""Set the name of the current thread.
Args:
name (str): Name of the current thread.
"""
torch._C._set_thread_name(name)
def _get_thread_name() -> str:
"""Get the name of the current thread.
Returns:
str: Name of the current thread.
"""
return torch._C._get_thread_name()
init_reductions()
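
The module docstring above promises a drop-in replacement; a minimal usage sketch (not part of this diff) is below. The worker name `fill`, the tensor shape, and the choice of "file_system" are illustrative assumptions; the APIs used (share_memory_, Queue, set_sharing_strategy) are the ones defined or re-exported here.

import torch
import torch.multiprocessing as mp


def fill(q):
    # The tensor arrives backed by shared memory, so in-place writes are visible to the parent.
    t = q.get()
    t.fill_(1.0)


if __name__ == "__main__":
    mp.set_sharing_strategy("file_system")  # available on every platform; Linux defaults to "file_descriptor"
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    t = torch.zeros(4).share_memory_()
    p = ctx.Process(target=fill, args=(q,))
    p.start()
    q.put(t)
    p.join()
    print(t)  # tensor([1., 1., 1., 1.])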


@@ -0,0 +1,35 @@
# mypy: allow-untyped-defs
import sys
__all__ = ["register_after_fork"]
if sys.platform == "win32":
import multiprocessing.util as _util
def _register(func):
def wrapper(arg):
func()
_util.register_after_fork(_register, wrapper)
else:
import os
def _register(func):
os.register_at_fork(after_in_child=func)
def register_after_fork(func):
"""Register a callable to be executed in the child process after a fork.
Note:
In Python < 3.7 this will only work with processes created using the
``multiprocessing`` module. In Python >= 3.7 it also works with
``os.fork()``.
Args:
func (function): Function taking no arguments to be called in the child after fork
"""
_register(func)
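
An illustrative sketch (not part of this diff) of the hook defined above. The import path torch.multiprocessing._atfork is assumed from the package layout, and `reinit_in_child` is a made-up callback; on POSIX the hook fires even for a plain os.fork().

import os

from torch.multiprocessing._atfork import register_after_fork


def reinit_in_child():
    # e.g. reseed RNGs or drop inherited handles that must not be shared with the parent
    print(f"child {os.getpid()} re-initializing", flush=True)


register_after_fork(reinit_in_child)

if hasattr(os, "fork"):   # POSIX only; on Windows the hook applies to multiprocessing workers
    pid = os.fork()
    if pid == 0:          # the callback has already run by the time fork() returns here
        os._exit(0)
    os.waitpid(pid, 0)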


@@ -0,0 +1,52 @@
import multiprocessing.pool
import multiprocessing.util as util
from .queue import SimpleQueue
def clean_worker(*args, **kwargs):
import gc
multiprocessing.pool.worker(*args, **kwargs)
# Regular multiprocessing workers don't fully clean up after themselves,
# so we have to explicitly trigger garbage collection to make sure that all
# destructors are called...
gc.collect()
class Pool(multiprocessing.pool.Pool):
"""Pool implementation which uses our version of SimpleQueue.
This lets us pass tensors in shared memory across processes instead of
serializing the underlying data.
"""
def _setup_queues(self):
self._inqueue = SimpleQueue()
self._outqueue = SimpleQueue()
self._quick_put = self._inqueue._writer.send
self._quick_get = self._outqueue._reader.recv
def _repopulate_pool(self):
"""Increase the number of pool processes to the specified number.
Bring the number of pool processes up to the specified number, for use after
reaping workers which have exited.
"""
for i in range(self._processes - len(self._pool)):
# changed worker -> clean_worker
args = (
self._inqueue,
self._outqueue,
self._initializer,
self._initargs,
self._maxtasksperchild,
)
if hasattr(self, "_wrap_exception"):
args += (self._wrap_exception,)
w = self.Process(target=clean_worker, args=args)
self._pool.append(w)
w.name = w.name.replace("Process", "PoolWorker")
w.daemon = True
w.start()
util.debug("added worker")


@@ -0,0 +1,43 @@
# mypy: allow-untyped-defs
import io
import multiprocessing.queues
import pickle
from multiprocessing.reduction import ForkingPickler
class ConnectionWrapper:
"""Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""
def __init__(self, conn):
self.conn = conn
def send(self, obj):
buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
self.send_bytes(buf.getvalue())
def recv(self):
buf = self.recv_bytes()
return pickle.loads(buf)
def __getattr__(self, name):
if "conn" in self.__dict__:
return getattr(self.conn, name)
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")
class Queue(multiprocessing.queues.Queue):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
self._send = self._writer.send
self._recv = self._reader.recv
class SimpleQueue(multiprocessing.queues.SimpleQueue):
def _make_methods(self):
if not isinstance(self._reader, ConnectionWrapper):
self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
super()._make_methods() # type: ignore[misc]
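
What ConnectionWrapper does per message, shown in-process as a hedged sketch (not part of this diff): serialize with ForkingPickler, so the tensor/storage reducers registered by init_reductions() apply, then deserialize with plain pickle.loads. Exact sharing behavior depends on the platform's sharing strategy.

import io
import pickle

import torch
import torch.multiprocessing  # importing this registers the custom reducers
from multiprocessing.reduction import ForkingPickler

t = torch.arange(4.0).share_memory_()

buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(t)  # same call ConnectionWrapper.send makes
t2 = pickle.loads(buf.getvalue())                     # same call ConnectionWrapper.recv makes

print(torch.equal(t, t2), t2.is_shared())  # True True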


@@ -0,0 +1,647 @@
# mypy: allow-untyped-defs
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing.util import register_after_fork
from typing import Union
import torch
from torch._namedtensor_internals import check_serializing_named_tensor
try:
# Early load resource_sharer to prevent a partially initialized instance
# from being inherited in a forked child process. The reduce_storage method
# requires this module indirectly through DupFd(). The built-in mp.Queue
# class pickles arguments in a background thread which may overlap with the
# fork.
import multiprocessing.resource_sharer
except ImportError:
pass
class StorageWeakRef:
r"""A weak reference to a Storage.
The cdata member is a Python number containing the integer representation of
the Storage pointer.
"""
__slots__ = ["cdata", "_free_weak_ref"]
def __init__(self, storage):
self.cdata = storage._weak_ref()
# Save a direct reference to _free_weak_ref because the `torch` module
# might be cleared during Python shutdown before this module is cleared.
self._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
@classmethod
def from_weakref(cls, cdata):
instance = cls.__new__(cls)
instance.cdata = cdata
instance._free_weak_ref = torch.Storage._free_weak_ref # type: ignore[attr-defined]
return instance
def expired(self):
return torch.Storage._expired(self.cdata) # type: ignore[attr-defined]
def __del__(self):
self._free_weak_ref(self.cdata)
def __hash__(self):
return self.cdata
def __eq__(self, other):
if id(self) == id(other):
return True
return self.cdata == other.cdata
class SharedCache(dict):
"""Dictionary from multiprocessing handles to StorageWeakRef."""
def __init__(self) -> None:
# free_dead_references() is called if the len exceeds the current
# limit. The limit scales with the number of remaining live objects.
self.limit = 128
# `fork` inherits lock state, so in case we fork when the lock is held,
# we register a function to reset the lock to a new object to avoid
# possible deadlocks, following the Python multiprocessing library's design.
self._after_fork()
register_after_fork(self, SharedCache._after_fork)
def _after_fork(self):
self.lock = threading.Lock()
def get(self, key):
with self.lock:
return dict.get(self, key)
def __setitem__(self, key, storage_ref):
with self.lock:
dict.__setitem__(self, key, storage_ref)
if len(self) > self.limit:
self.free_dead_references()
def free_dead_references(self):
live = 0
for key, storage_ref in list(self.items()):
if storage_ref.expired():
del self[key]
else:
live += 1
self.limit = max(128, live * 2)
# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()
def rebuild_event(device, handle):
return torch.cuda.Event.from_ipc_handle(device, handle)
def reduce_event(event):
handle = event.ipc_handle()
return (rebuild_event, (event.device, handle))
def rebuild_tensor(cls, storage, metadata):
storage_offset, size, stride, requires_grad = metadata
t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
if cls == torch.nn.parameter.Parameter:
# we have to pass requires_grad into the constructor, rather than set it as an
# attribute later, because integer tensors must be constructed with
# requires_grad=False (or else they raise an error)
t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
else:
t.requires_grad = requires_grad
return t
def rebuild_meta_tensor(
tensor_cls,
tensor_size,
tensor_stride,
tensor_offset,
dtype,
storage_size_bytes,
requires_grad,
):
untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")
typed_storage = torch.TypedStorage(
wrap_storage=untyped_storage, dtype=dtype, _internal=True
)
t = torch._utils._rebuild_tensor(
typed_storage,
tensor_offset,
tensor_size,
tensor_stride,
)
if tensor_cls == torch.nn.parameter.Parameter:
# It is crucial for integer tensors to receive
# requires_grad=False as an argument in the constructor
t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
else:
t.requires_grad = requires_grad
return t
def rebuild_cuda_tensor(
tensor_cls,
tensor_size,
tensor_stride,
tensor_offset,
storage_cls,
dtype,
storage_device,
storage_handle,
storage_size_bytes,
storage_offset_bytes,
requires_grad,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
):
# If storage_handle is None, storage points to nullptr.
if storage_handle is None or storage_size_bytes == 0:
storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
else:
storage = storage_from_cache(
storage_cls, (storage_handle, storage_offset_bytes)
)
if storage is None:
torch.cuda._lazy_init()
storage = storage_cls._new_shared_cuda(
storage_device,
storage_handle,
storage_size_bytes,
storage_offset_bytes,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
)
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
storage
)
else:
# We are already ref-counting this Storage, but the producer needs the new ref-counters to be released.
storage_cls._release_ipc_counter(
ref_counter_handle, ref_counter_offset, device=storage_device
)
_storage = (
storage
if isinstance(storage, torch.UntypedStorage)
else storage._untyped_storage
)
t = torch._utils._rebuild_tensor(
torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
tensor_offset,
tensor_size,
tensor_stride,
)
if tensor_cls == torch.nn.parameter.Parameter:
# It is crucial for integer tensors to receive
# requires_grad=False as an argument in the constructor
t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
else:
t.requires_grad = requires_grad
return t
def reduce_tensor(tensor):
if tensor.requires_grad and not tensor.is_leaf:
raise RuntimeError(
"Cowardly refusing to serialize non-leaf tensor which requires_grad, "
"since autograd does not support crossing process boundaries. "
"If you just want to transfer the data, call detach() on the tensor "
"before serializing (e.g., putting it on the queue)."
)
check_serializing_named_tensor(tensor)
torch.utils.hooks.warn_if_has_hooks(tensor)
# Note [CUDA IPC and the caching allocator]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# When you send a CUDA tensor over IPC, you might expect that you will
# get out the same storage from the other end. However, the CUDA caching
# allocator makes it difficult to preserve this invariant. Consider
# the following situation: a tensor of size 0x100 points to offset 0x20 of
# a storage at 0xA100 of size 0x100. (For simplicity, all of these
# sizes are given in bytes). HOWEVER, with the caching allocator, this storage
# might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
#
# When we want to send this CUDA tensor over IPC, we must send the
# *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
# the storage 0xA100 (because that is what CUDA supports). So, on the
# other end, there simply isn't any way to say, "Wait, you gave me
# a bigger region (0xA000) than the one I wanted (0xA100)".
#
# OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
# one storage itself? No, because this cudaMalloc allocation might contain
# storages of mixed types: float, bytes, double... If you make the entire
# allocation a single storage of a type A, we'll hit an error when constructing
# a tensor of type B on the storage.
#
# cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
# receiver side. However, cudaIpcMemHandles from each device in a given process may
# only be opened by one context per device per other process.
# If we open and close a memory handle multiple times in a process, CUDA is allowed
# to give it a different address; similarly, once we close the memory, we're not
# allowed to access it (or the storage/tensor built on top of it), even if it is
# still live in the original process. As we cannot map a cudaMalloc allocation
# to a single storage in one go, this requires us to cache the device pointer for
# each cudaIpcMemHandle on the C++ side to reconstruct storages of different types,
# while keeping the old ones alive.
# See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
#
# This is fine, because all we need to do is to save our position in the allocation,
# and reconstruct storage and tensor from it.
# 0xA000 -> -------CUDA Allocation------
# | |
# | |
# | |
# | |
# 0xA100 -> --------storage1 begin------
# | |
# 0xA120 -> --------tensor1 begin ------
# | |
# | |
# | |
# | |
# | |
# 0xA160 -> --------tensor1 end---------
# | |
# | |
# | |
# 0xA200 -> --------storage1 end--------
# | |
# 0xE000 -> --------CUDA allocation-----
#
# To send tensor1, the following info is required from the sender to the receiver for
# storage reconstruction.
# 1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in the receiver process).
# basePtr may not be exactly 0xA000 since it's a different process.
# 2. offset (0xA100) of storage1 in the CUDA allocation.
# 3. size of storage1 (0x100).
#
# On receiver side:
# 1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
# of the same type using (basePtr, offset, size).
# 2. we can reconstruct the tensor on top of the reconstructed storage
# Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
#
# This strategy has a few implications:
#
# 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
# go (non-compositionally), and this requires us to have a global map
# memHandle -> devPtr for each process.
#
# 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
# of the storage beyond 0x100 would merely have caused us to do a
# reallocation. You don't really want to do this, but if you did,
# all that would happen is that you would lose IPC sharing. But if
# you do this in the new world, we will happily let you write out of
# bounds of your "allocation", clobbering unrelated data in the cached
# allocator block. BAD!
#
# By the way, in old versions of PyTorch, we supported this situation
# natively using a "storage view", which permitted multiple storages to be
# views on each other. But this was the *only* use of storage views, so we
# eliminated it so that we could just use tensor views to implement the same
# thing.
#
# TODO: Handle distinguishing between subclass and non-subclass versions of NT better
# https://github.com/pytorch/pytorch/issues/110543
from torch.nested._internal.nested_tensor import NestedTensor
if tensor.is_nested and not isinstance(tensor, NestedTensor):
return reduce_nested_tensor(tensor)
if tensor.layout in {
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_bsr,
torch.sparse_csc,
torch.sparse_bsc,
}:
return reduce_sparse_tensor(tensor)
storage = tensor._typed_storage()
if storage._untyped_storage.device.type == "cuda":
(
device,
handle,
storage_size_bytes,
storage_offset_bytes,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
) = storage._share_cuda_()
tensor_offset = tensor.storage_offset()
shared_cache[handle] = StorageWeakRef(storage)
# _backward_hooks purposely omitted here, see
# Note [Don't serialize hooks]
return (
rebuild_cuda_tensor,
(
type(tensor),
tensor.size(),
tensor.stride(),
tensor_offset, # tensor offset in its storage
type(storage),
tensor.dtype,
device,
handle,  # identifier of the CUDA allocation the storage lives in
storage_size_bytes,  # size (in bytes) of the storage
storage_offset_bytes,  # offset (in bytes) of the storage in the CUDA allocation
tensor.requires_grad,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
),
)
elif storage._untyped_storage.device.type == "meta":
return (
rebuild_meta_tensor,
(
type(tensor),
tensor.size(),
tensor.stride(),
tensor.storage_offset(),
tensor.dtype,
tensor.untyped_storage().size(),
tensor.requires_grad,
),
)
# _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
metadata = (
tensor.storage_offset(),
tensor.size(),
tensor.stride(),
tensor.requires_grad,
)
return (rebuild_tensor, (type(tensor), storage, metadata))
def rebuild_nested_tensor(
rebuild_buffer_func,
rebuild_buffer_args,
rebuild_sizes_func,
rebuild_sizes_args,
rebuild_strides_func,
rebuild_strides_args,
rebuild_offsets_func,
rebuild_offsets_args,
):
buffer = rebuild_buffer_func(*rebuild_buffer_args)
sizes = rebuild_sizes_func(*rebuild_sizes_args)
strides = rebuild_strides_func(*rebuild_strides_args)
offsets = rebuild_offsets_func(*rebuild_offsets_args)
return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
def reduce_nested_tensor(nt):
rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
rebuild_strides_func, rebuild_strides_args = reduce_tensor(
nt._nested_tensor_strides()
)
rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
nt._nested_tensor_storage_offsets()
)
return (
rebuild_nested_tensor,
(
rebuild_buffer_func,
rebuild_buffer_args,
rebuild_sizes_func,
rebuild_sizes_args,
rebuild_strides_func,
rebuild_strides_args,
rebuild_offsets_func,
rebuild_offsets_args,
),
)
def rebuild_sparse_coo_tensor(
rebuild_indices_func,
rebuild_indices_args,
rebuild_values_func,
rebuild_values_args,
shape,
is_coalesced,
):
indices = rebuild_indices_func(*rebuild_indices_args)
values = rebuild_values_func(*rebuild_values_args)
return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
def rebuild_sparse_compressed_tensor(
rebuild_compressed_indices_func,
rebuild_compressed_indices_args,
rebuild_plain_indices_func,
rebuild_plain_indices_args,
rebuild_values_func,
rebuild_values_args,
shape,
layout,
):
compressed_indices = rebuild_compressed_indices_func(
*rebuild_compressed_indices_args
)
plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
values = rebuild_values_func(*rebuild_values_args)
return torch.sparse_compressed_tensor(
compressed_indices, plain_indices, values, shape, layout=layout
)
def reduce_sparse_tensor(sparse):
if sparse.layout is torch.sparse_coo:
rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
return (
rebuild_sparse_coo_tensor,
(
rebuild_indices_func,
rebuild_indices_args,
rebuild_values_func,
rebuild_values_args,
sparse.shape,
sparse.is_coalesced(),
),
)
else:
if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
compressed_indices = sparse.crow_indices()
plain_indices = sparse.col_indices()
elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
compressed_indices = sparse.ccol_indices()
plain_indices = sparse.row_indices()
else:
raise NotImplementedError(sparse.layout)
(
rebuild_compressed_indices_func,
rebuild_compressed_indices_args,
) = reduce_tensor(compressed_indices)
rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
plain_indices
)
rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
return (
rebuild_sparse_compressed_tensor,
(
rebuild_compressed_indices_func,
rebuild_compressed_indices_args,
rebuild_plain_indices_func,
rebuild_plain_indices_args,
rebuild_values_func,
rebuild_values_args,
sparse.shape,
sparse.layout,
),
)
def fd_id(fd):
# Returns a tuple which uniquely identifies a file descriptor. On macOS,
# this doesn't work with shared memory handles, which is why we don't
# support the "file_descriptor" sharing strategy on that platform.
stat = os.fstat(fd)
return (stat.st_ino, stat.st_dev)
def storage_from_cache(cls, key):
storage_ref = shared_cache.get(key)
if storage_ref is None:
return None
return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
def rebuild_storage_fd(cls, df, size):
fd = df.detach()
try:
storage = storage_from_cache(cls, fd_id(fd))
if storage is not None:
return storage
storage = cls._new_shared_fd_cpu(fd, size)
shared_cache[fd_id(fd)] = StorageWeakRef(storage)
return storage
finally:
os.close(fd)
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
cls, handle
)
if storage is not None:
return storage._shared_decref()
if dtype is None:
storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
else:
byte_size = size * torch._utils._element_size(dtype)
untyped_storage: torch.UntypedStorage = (
torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
)
storage = torch.TypedStorage(
wrap_storage=untyped_storage, dtype=dtype, _internal=True
)
shared_cache[handle] = StorageWeakRef(storage)
return storage._shared_decref()
def rebuild_storage_empty(cls):
return cls()
def rebuild_typed_storage(storage, dtype):
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)
# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
def rebuild_typed_storage_child(storage, storage_type):
return storage_type(wrap_storage=storage, _internal=True)
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
def reduce_storage(storage):
from . import get_sharing_strategy
if storage.is_cuda:
raise RuntimeError(
"Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
)
elif storage.device.type == "meta":
raise RuntimeError(
"Cannot pickle meta storage; try pickling a meta tensor instead"
)
elif get_sharing_strategy() == "file_system":
metadata = storage._share_filename_cpu_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
if isinstance(storage, torch.TypedStorage):
metadata += (storage.dtype,)
storage._shared_incref()
elif storage.size() == 0:
# This is special-cased because empty tensors
# (with size 0) cannot be mmapped.
return (rebuild_storage_empty, (type(storage),))
else:
fd, size = storage._share_fd_cpu_()
df = multiprocessing.reduction.DupFd(fd)
cache_key = fd_id(fd)
metadata = (df, size)
rebuild = rebuild_storage_fd # type: ignore[assignment]
shared_cache[cache_key] = StorageWeakRef(storage)
return (rebuild, (type(storage),) + metadata)
def init_reductions():
ForkingPickler.register(torch.cuda.Event, reduce_event)
for t in torch._storage_classes:
if t.__name__ == "UntypedStorage":
ForkingPickler.register(t, reduce_storage)
else:
ForkingPickler.register(t, reduce_typed_storage_child)
ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
for t in torch._tensor_classes:
ForkingPickler.register(t, reduce_tensor)
# TODO: Maybe this should be in tensor_classes? :)
ForkingPickler.register(torch.Tensor, reduce_tensor)
from torch.nn.parameter import Parameter
ForkingPickler.register(Parameter, reduce_tensor)
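
A hedged in-process sketch (not part of this diff) of the reduce/rebuild contract for a plain CPU tensor. The import path torch.multiprocessing.reductions is assumed from the package layout; reduce_tensor returns a (rebuild_fn, args) pair, and calling the rebuild function yields a tensor that views the same storage as the original.

import torch
from torch.multiprocessing.reductions import reduce_tensor

t = torch.arange(6.0).reshape(2, 3)
rebuild_fn, args = reduce_tensor(t)   # CPU path: (rebuild_tensor, (cls, storage, metadata))
t2 = rebuild_fn(*args)

t2[0, 0] = 42.0
print(t[0, 0].item())                 # 42.0 -- t and t2 share the underlying storage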


@@ -0,0 +1,328 @@
# mypy: allow-untyped-defs
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
log = logging.getLogger(__name__)
__all__ = [
"ProcessContext",
"ProcessException",
"ProcessExitedException",
"ProcessRaisedException",
"spawn",
"SpawnContext",
"start_processes",
]
class ProcessException(Exception):
__slots__ = ["error_index", "error_pid"]
def __init__(self, msg: str, error_index: int, pid: int):
super().__init__(msg)
self.msg = msg
self.error_index = error_index
self.pid = pid
def __reduce__(self):
return type(self), (self.msg, self.error_index, self.pid)
class ProcessRaisedException(ProcessException):
"""Exception raised when a process failed due to an exception raised by the code."""
def __init__(
self,
msg: str,
error_index: int,
error_pid: int,
):
super().__init__(msg, error_index, error_pid)
class ProcessExitedException(ProcessException):
"""Exception raised when a process failed due to signal or exited with a specific code."""
__slots__ = ["exit_code"]
def __init__(
self,
msg: str,
error_index: int,
error_pid: int,
exit_code: int,
signal_name: Optional[str] = None,
):
super().__init__(msg, error_index, error_pid)
self.exit_code = exit_code
self.signal_name = signal_name
def __reduce__(self):
return (
type(self),
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
)
def _wrap(fn, i, args, error_file):
# prctl(2) is a Linux specific system call.
# On other systems the following function call has no effect.
# This is set to ensure that non-daemonic child processes can
# terminate if their parent terminates before they do.
_prctl_pr_set_pdeathsig(signal.SIGINT)
try:
fn(i, *args)
except KeyboardInterrupt:
pass # SIGINT; Killed by parent, do nothing
except Exception:
# Propagate exception to parent process, keeping original traceback
import traceback
with open(error_file, "wb") as fh:
pickle.dump(traceback.format_exc(), fh)
sys.exit(1)
class ProcessContext:
def __init__(self, processes, error_files):
self.error_files = error_files
self.processes = processes
self.sentinels = {
process.sentinel: index for index, process in enumerate(processes)
}
def pids(self):
return [int(process.pid) for process in self.processes]
def join(self, timeout=None):
r"""Join one or more processes within spawn context.
Attempt to join one or more processes in this spawn context.
If one of them exited with a non-zero exit status, this function
kills the remaining processes and raises an exception with the cause
of the first process exiting.
Returns ``True`` if all processes have been joined successfully,
``False`` if there are more processes that need to be joined.
Args:
timeout (float): Wait this long before giving up on waiting.
"""
# Ensure this function can be called even when we're done.
if len(self.sentinels) == 0:
return True
# Wait for any process to fail or all of them to succeed.
ready = multiprocessing.connection.wait(
self.sentinels.keys(),
timeout=timeout,
)
error_index = None
for sentinel in ready:
index = self.sentinels.pop(sentinel)
process = self.processes[index]
process.join()
if process.exitcode != 0:
error_index = index
break
# Return if there was no error.
if error_index is None:
# Return whether or not all processes have been joined.
return len(self.sentinels) == 0
# Assume failure. Terminate processes that are still alive.
# Try SIGTERM then SIGKILL if the process isn't going down.
# The reason is that Python signal handling is limited to the main
# thread; if that thread is stuck in C/C++ land, it won't be able
# to handle the signal. We have seen processes get stuck and fail to
# handle SIGTERM for this reason.
timeout: int = 30
for process in self.processes:
if process.is_alive():
log.warning("Terminating process %s via signal SIGTERM", process.pid)
process.terminate()
end = time.monotonic() + timeout
for process in self.processes:
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)
for process in self.processes:
if process.is_alive():
log.warning(
"Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
process.pid,
)
process.kill()
process.join()
# The file will only be created if the process crashed.
failed_process = self.processes[error_index]
if not os.access(self.error_files[error_index], os.R_OK):
exitcode = self.processes[error_index].exitcode
if exitcode < 0:
try:
name = signal.Signals(-exitcode).name
except ValueError:
name = f"<Unknown signal {-exitcode}>"
raise ProcessExitedException(
"process %d terminated with signal %s" % (error_index, name),
error_index=error_index,
error_pid=failed_process.pid,
exit_code=exitcode,
signal_name=name,
)
else:
raise ProcessExitedException(
"process %d terminated with exit code %d" % (error_index, exitcode),
error_index=error_index,
error_pid=failed_process.pid,
exit_code=exitcode,
)
with open(self.error_files[error_index], "rb") as fh:
original_trace = pickle.load(fh)
msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
msg += original_trace
raise ProcessRaisedException(msg, error_index, failed_process.pid)
class SpawnContext(ProcessContext):
def __init__(self, processes, error_files):
warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
super().__init__(processes, error_files)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It is meant to be a
# more general API than mp.spawn. Currently we only document mp.spawn, as it is the
# CUDA-compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. The helper functions we created for mp.spawn are general
# enough that backends like XLA can reuse them in Colab notebooks as well.
# For now we only add this API; we can consider adding it to the documentation as
# needed in the future.
def start_processes(
fn,
args=(),
nprocs=1,
join=True,
daemon=False,
start_method="spawn",
):
# To improve performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this function will start processes in parallel if start_method is 'forkserver'.
# Opt in to this optimization by setting the env var TORCH_MP_PARALLEL_START to 1.
# TODO: investigate why spawn does not work with a thread pool and raises SIGINT
if (
start_method == "forkserver"
and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
):
log.info("Starting processes in parallel.")
start_parallel = True
else:
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
mp = multiprocessing.get_context(start_method)
error_files = [None] * nprocs
processes = [None] * nprocs
def start_process(i):
# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
# used a multiprocessing.Queue but that can be prone to
# deadlocks, so we went with a simpler solution for a one-shot
# message between processes.
tf = tempfile.NamedTemporaryFile(
prefix="pytorch-errorfile-", suffix=".pickle", delete=False
)
tf.close()
os.unlink(tf.name)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
process.start()
return i, process, tf.name
if not start_parallel:
for i in range(nprocs):
idx, process, tf_name = start_process(i)
error_files[idx] = tf_name
processes[idx] = process
else:
with ThreadPoolExecutor(max_workers=nprocs) as executor:
futures = [executor.submit(start_process, i) for i in range(nprocs)]
for fut in as_completed(futures):
idx, process, tf_name = fut.result()
# idx and process rank need to be the same.
error_files[idx] = tf_name
processes[idx] = process
context = ProcessContext(processes, error_files)
if not join:
return context
# Loop on join until it returns True or raises an exception.
while not context.join():
pass
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
If one of the processes exits with a non-zero exit status, the
remaining processes are killed and an exception is raised with the
cause of termination. If an exception was caught in a child
process, it is forwarded and its traceback is included in
the exception raised in the parent process.
Args:
fn (function): Function is called as the entrypoint of the
spawned process. This function must be defined at the top
level of a module so it can be pickled and spawned. This
is a requirement imposed by multiprocessing.
The function is called as ``fn(i, *args)``, where ``i`` is
the process index and ``args`` is the passed through tuple
of arguments.
args (tuple): Arguments passed to ``fn``.
nprocs (int): Number of processes to spawn.
join (bool): Perform a blocking join on all processes.
daemon (bool): The spawned processes' daemon flag. If set to True,
daemonic processes will be created.
start_method (str): (deprecated) this method will always use ``spawn``
as the start method. To use a different start method
use ``start_processes()``.
Returns:
None if ``join`` is ``True``,
:class:`~ProcessContext` if ``join`` is ``False``
"""
if start_method != "spawn":
msg = (
f"This method only supports start_method=spawn (got: {start_method}).\n"
"To use a different start_method use:\n\t\t"
" torch.multiprocessing.start_processes(...)"
)
warnings.warn(msg, FutureWarning, stacklevel=2)
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
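
A hedged usage sketch (not part of this diff) for spawn and for the ProcessContext returned when join=False; the worker body and process count are illustrative.

import os

import torch.multiprocessing as mp


def worker(rank, world_size):
    # fn is called as fn(i, *args) and must live at module top level so it can be pickled
    print(f"rank {rank}/{world_size} running in pid {os.getpid()}")


if __name__ == "__main__":
    # Blocking form: returns once all processes exit successfully, raises otherwise.
    mp.spawn(worker, args=(3,), nprocs=3)

    # Non-blocking form: poll the ProcessContext until every process has been joined.
    ctx = mp.spawn(worker, args=(3,), nprocs=3, join=False)
    while not ctx.join(timeout=1.0):
        pass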