# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
from torch._C import _len_torch_function_stack
from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator
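
# CURRENT_DEVICE records the device installed by the innermost active
# DeviceContext (None when no device mode is active); it lives at module
# scope so callers outside this module can query the ambient device.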
CURRENT_DEVICE: Optional[torch.device] = None


@functools.lru_cache(1)
def _device_constructors():
    return {
        # standard ones
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
        torch.arange,
        torch.bartlett_window,
        torch.blackman_window,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.nested.nested_tensor,
        # This function doesn't actually take a device argument
        # torch.normal,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        # weird ones
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
        torch.asarray,
    }
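
# Note: the constructor set above is built once (``functools.lru_cache(1)`` on
# a zero-argument function), and DeviceContext.__torch_function__ below does a
# set-membership test against it on every intercepted call.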


# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __enter__(self):
        global CURRENT_DEVICE
        self.old_device = CURRENT_DEVICE
        CURRENT_DEVICE = self.device
        # We need to put the device mode at the bottom of the stack. If we
        # installed it on top while inside another function mode's context,
        # exiting that context would pop the device mode off the stack
        # incorrectly.
        cur_stack = []
        for _ in range(_len_torch_function_stack()):
            cur_stack.append(_pop_mode())

        _push_mode(self)

        for mode in reversed(cur_stack):
            _push_mode(mode)
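
    # Sketch of the resulting stack after __enter__ (our reading of the
    # invariant, top to bottom): [pre-existing modes, original order] over
    # self, the single DeviceContext at the bottom.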

    def __exit__(self, exc_type, exc_val, exc_tb):
        global CURRENT_DEVICE
        CURRENT_DEVICE = self.old_device
        cur_stack = []
        # Invariant: there should only be one DeviceContext on the stack at
        # any time, and it sits at the bottom. Pop all modes until we hit the
        # bottom, then assert that it is a DeviceContext; otherwise someone
        # else has popped it!
        for _ in range(_len_torch_function_stack() - 1):
            mode = _pop_mode()
            assert not isinstance(mode, DeviceContext)
            cur_stack.append(mode)

        if _len_torch_function_stack() > 0:
            mode = _pop_mode()
            assert isinstance(mode, DeviceContext)

        for mode in reversed(cur_stack):
            _push_mode(mode)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)
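

# Minimal usage sketch (illustrative only; "cuda:0" is an arbitrary example
# device). Constructors called while the mode is active pick up its device
# unless the caller passes an explicit ``device=``:
#
#   with DeviceContext("cuda:0"):
#       x = torch.ones(3)                 # allocated on cuda:0
#       y = torch.ones(3, device="cpu")   # explicit argument wins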


# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    return context_decorator(lambda: device, func)
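
# Hedged note: ``device`` here is expected to be usable as a context manager
# (e.g. a ``torch.device``, whose enter/exit route through DeviceContext
# above); ``context_decorator`` re-enters it around every call of ``func``:
#
#   f = device_decorator(torch.device("cpu"), lambda: torch.empty(2))
#   f().device   # -> device(type='cpu')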


def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it
    with this function.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.
    """
    return lambda func: device_decorator(torch.device(device), func)
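
# Decorator-form sketch (``make_buffer`` and "cpu" are illustrative choices):
#
#   @set_device("cpu")
#   def make_buffer():
#       return torch.zeros(16)   # allocated on the cpu device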