I am done

commit 40e2a747cf
parent 720dc28c09
Date: 2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large.


@@ -0,0 +1,146 @@
# mypy: allow-untyped-defs
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
import torch
import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims_common import torch_function_passthrough

@functools.lru_cache(None)
def torch_to_refs_map():
"""
Mapping of torch API functions to torch._refs functions.
E.g. torch_to_refs_map()[torch.add] == torch._refs.add
"""
modules = [
(torch, torch._refs),
(torch.nn, torch._refs.nn),
(torch.nn.functional, torch._refs.nn.functional),
(torch.special, torch._refs.special),
(torch.fft, torch._refs.fft),
(torch.linalg, torch._refs.linalg),
]
r: Dict[Any, Any] = {
torch.Tensor.__invert__: torch._refs.bitwise_not,
torch.Tensor.__xor__: torch._refs.bitwise_xor,
torch.Tensor.__and__: torch._refs.bitwise_and,
torch.Tensor.__or__: torch._refs.bitwise_or,
torch.Tensor.__eq__: torch._refs.eq,
torch.Tensor.__rsub__: torch._refs.rsub,
torch.Tensor.__rtruediv__: torch._refs.rtruediv,
torch.Tensor.__floordiv__: torch._refs.floor_divide,
torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
torch.Tensor.__pow__: torch._refs.pow,
torch.Tensor.__rpow__: torch._refs.rpow,
torch.Tensor.new_empty: torch._refs.new_empty,
torch.Tensor.new_full: torch._refs.new_full,
torch.Tensor.new_zeros: torch._refs.new_zeros,
torch.Tensor.new_ones: torch._refs.new_ones,
torch.Tensor.fill_: torch._refs.fill_,
torch.Tensor.zero_: torch._refs.zero_,
torch.Tensor.to: torch._refs.to,
torch.Tensor.sum_to_size: torch._refs.sum_to_size,
# TODO: Should these methods be mapped some other way?
torch.Tensor.copy_: torch._prims.copy_to,
torch.Tensor.resize: torch._prims.resize,
}
for mod_torch, mod_refs in modules:
for s in mod_refs.__all__: # type: ignore[attr-defined]
r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
# Support remapping torch.Tensor.foo to _refs.foo
for s in dir(torch.Tensor):
if s in torch._refs.__all__:
r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
# Support conversions
for s in torch._refs._conversions.__all__:
tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
r[tensor_attr] = torch._refs._conversions.__dict__.get(s)
return r
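
# Example (illustrative sketch): querying the mapping. Assumes a standard
# PyTorch build where the corresponding refs exist.
def _example_torch_to_refs_map_lookup():
    mapping = torch_to_refs_map()
    # Plain torch functions map to their _refs counterparts.
    assert mapping[torch.add] is torch._refs.add
    # Tensor dunder methods map to the equivalent refs as well.
    assert mapping[torch.Tensor.__and__] is torch._refs.bitwise_and
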
@functools.lru_cache(None)
def all_prims():
"""
Set of all prim functions, e.g., torch._prims.add in all_prims()
"""
return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}

class TorchRefsMode(torch.overrides.TorchFunctionMode):
"""
Switches the interpretation of torch.* functions and Tensor methods to
use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)
>>> # xdoctest: +SKIP
>>> with TorchRefsMode():
... torch.add(x, y) # calls torch._refs.add(x, y)
    By default, this context manager falls back to the torch.* implementation
    when no ref exists; pass strict=True to raise an error in that case instead.
    Even when a ref does exist, it is sometimes desirable to fall back to the
    torch.* implementation; that behavior can be customized via should_fallback_fn.
"""
def __init__(
self,
strict=False,
should_fallback_fn=lambda *_: False,
prims_mode_cls=nullcontext,
):
self.strict = strict
self.should_fallback_fn = should_fallback_fn
self.prims_mode_cls = prims_mode_cls
def __torch_function__(
self,
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
        # Primitive operations are run as-is, without interception, unless a
        # prims_mode_cls context was supplied, in which case they run under it.
if orig_func in torch_function_passthrough or orig_func in all_prims():
with self.prims_mode_cls():
return orig_func(*args, **kwargs)
mapping = torch_to_refs_map()
func = mapping.get(orig_func, None)
# For torch.ops.aten.*, use registered decompositions from torch._decomp
# torch._decomp.decomposition_table provides a mapping from
# torch.ops.aten.* to torch._refs or torch._decomp.decompositions
# implementations.
        # There are other ways to implement this functionality,
# see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
if func is None and isinstance(orig_func, torch._ops.OpOverload):
func = torch._decomp.decomposition_table.get(orig_func, None)
elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
default = getattr(orig_func, "default", None)
if default is not None:
func = torch._decomp.decomposition_table.get(default, None)
if func is not None:
# If the ref exists query whether we should use it or not
if self.should_fallback_fn(self, orig_func, func, args, kwargs):
return orig_func(*args, **kwargs)
# torch calls inside func should be interpreted as refs calls
with self:
return func(*args, **kwargs)
if self.strict:
raise RuntimeError(
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
)
return orig_func(*args, **kwargs)
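
# Example (illustrative sketch): tracing a function under TorchRefsMode so that
# torch.* calls are recorded through their _refs implementations, mirroring how
# torch._prims.executor uses this mode together with make_fx.
def _example_torch_refs_mode_trace():
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(a, b):
        return torch.add(a, b)

    a, b = torch.randn(3), torch.randn(3)
    with TorchRefsMode():
        gm = make_fx(fn)(a, b)
    # gm now records fn through the _refs implementation of add.
    return gm
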


@@ -0,0 +1,54 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Optional
import torch
from torch.utils._content_store import ContentStoreReader

LOAD_TENSOR_READER: Optional[ContentStoreReader] = None

@contextlib.contextmanager
def load_tensor_reader(loc):
global LOAD_TENSOR_READER
assert LOAD_TENSOR_READER is None
# load_tensor is an "op", and we will play merry hell on
# Inductor's memory planning if we return a tensor that
# aliases another tensor that we previously returned from
# an operator. So unlike standard ContentStoreReader use,
# we disable the cache so that you always get fresh storages
# (no aliasing for you!)
LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
try:
yield
finally:
LOAD_TENSOR_READER = None

def register_debug_prims():
torch.library.define(
"debugprims::load_tensor",
"(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
)
@torch.library.impl("debugprims::load_tensor", "BackendSelect")
def load_tensor_factory(name, size, stride, dtype, device):
if LOAD_TENSOR_READER is None:
from torch._dynamo.testing import rand_strided
return rand_strided(size, stride, dtype, device)
else:
from torch._dynamo.utils import clone_input
# device argument here takes care of coercion
r = LOAD_TENSOR_READER.read_tensor(name, device=device)
assert list(r.size()) == size, f"{r.size()} != {size}"
assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
assert r.device == device, f"{r.device} != {device}"
# Unlike the other properties, we will do coercions for dtype
# mismatch
if r.dtype != dtype:
r = clone_input(r, dtype=dtype)
return r
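
# Example (illustrative sketch): assumes register_debug_prims() has already been
# called, after which the op is reachable as torch.ops.debugprims.load_tensor.
# Outside a load_tensor_reader(...) context it returns random strided data, so
# no ContentStore on disk is needed.
def _example_load_tensor_without_reader():
    t = torch.ops.debugprims.load_tensor.default(
        "dummy/name", [2, 3], [3, 1], dtype=torch.float32, device="cpu"
    )
    assert t.shape == (2, 3)
    return t
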


@@ -0,0 +1,60 @@
# mypy: allow-untyped-defs
from typing import Callable, Optional
from torch._prims.context import TorchRefsMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx

def execute(
gm: GraphModule,
*args,
executor: str = "aten",
executor_parameters: Optional[dict] = None,
):
"""
Prototype ATen executor.
Just executes the context's graph.
"""
if executor == "aten":
return gm.forward(*args)
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
raise ValueError(msg)

def make_traced(fn: Callable):
"""
Returns a function that, when called, will
trace its torch operations to prims and then
execute those prims on the requested trace executor
(possibly lowering them to that trace executor first).
    Only supports the torch operations covered by torch_to_refs_map in
    context.py, and only operations called with positional args. All args must
    be tensors.
In the near future all these restrictions will be lifted.
Example usage:
def foo(a, b):
return torch.add(a, b)
traced_foo = make_traced(foo)
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='aten')
"""
def _traced(*args, executor="aten", **kwargs):
# TODO: caching
wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
with TorchRefsMode():
gm = make_fx(wrapped)(all_args)
return execute(gm, all_args, executor=executor)
return _traced
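
# Example (illustrative sketch): a CPU variant of the docstring example above.
def _example_make_traced_cpu():
    import torch

    def foo(a, b):
        return torch.add(a, b)

    traced_foo = make_traced(foo)
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    # Traces foo to a refs/prims graph and runs it with the eager ATen executor.
    return traced_foo(a, b, executor="aten")
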


@@ -0,0 +1,319 @@
# mypy: allow-untyped-defs
from typing import Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.types import _device, _dtype

def throw_on_non_cuda(device):
raise RuntimeError(
f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
"not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
)

def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
rngprim_def = torch.library.custom_op(
"rngprims::" + name, impl_aten, mutates_args=(), schema=schema
)
rngprim_def.register_fake(impl_meta)
prim_packet = getattr(torch._ops.ops.rngprims, name)
prim = prim_packet.default
if tags:
prim._tags = tags
for p in (prim_packet, prim):
p.__doc__ = doc
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
p.schema = name + schema
p.impl_aten = impl_aten
p.prim_meta_impl = impl_meta

# Philox rand offsets could be shared in future with other philox ops, so
# keeping these functions in global scope.
def philox_rand_offset_meta(
shape: torch.Size,
):
return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))

def philox_rand_offset(
shape: torch.Size,
):
# For impl, look at the function calc_execution_policy in the file
# aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
# commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
numel_scalar = 1
for dim_size in shape:
numel_scalar *= dim_size
numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)
block_size = 256
unroll = 4
curand4_engine_calls = 4
device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
grid_size = (numel + block_size - 1) // block_size
grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
offset = (
(numel - 1) // (block_size * grid_size * unroll) + 1
) * curand4_engine_calls
return offset
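
# Worked example (hypothetical device numbers, for intuition only): with
# numel = 1_000_000, block_size = 256, unroll = 4, curand4_engine_calls = 4,
# and a GPU reporting 108 SMs with 2048 max threads per SM:
#   blocks_per_sm = 2048 // 256                            = 8
#   grid_size     = min((1_000_000 + 255) // 256, 108 * 8) = min(3907, 864) = 864
#   offset        = ((1_000_000 - 1) // (256 * 864 * 4) + 1) * 4
#                 = (1 + 1) * 4                            = 8
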
def register_philox_rand():
name = "philox_rand"
schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950
def _philox_rand_meta(
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
device: _device,
dtype: _dtype,
):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
assert stride is None
stride = make_contiguous_strides_for(shape)
random_values = _prims.TensorMeta(
shape=shape, strides=stride, dtype=dtype, device=device
)
offset = philox_rand_offset_meta(shape)
return (random_values, offset)
def _philox_rand(
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
device: _device,
dtype: _dtype,
):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
assert stride is None
if device.type == "cpu":
devices = []
else:
devices = [device]
if device.type != "cuda":
raise throw_on_non_cuda(device)
with torch.random.fork_rng(devices):
CUDARngStateHelper.set_torch_state_tensor(seed, offset)
random_values = torch.rand(shape, device=device, dtype=dtype)
return random_values, philox_rand_offset(shape)
register_rng_prim(
name=name,
schema=schema,
impl_aten=_philox_rand,
impl_meta=_philox_rand_meta,
doc="Philox based stateless rand operator",
tags=(torch.Tag.nondeterministic_seeded,),
)
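
# Example (illustrative, rough sketch): after register_philox_rand() the op is
# reachable as torch.ops.rngprims.philox_rand. It requires a CUDA device, and
# the seed/offset tensors must follow the CUDARngStateHelper conventions, e.g.:
#
#   seed   = torch.tensor(torch.cuda.initial_seed())
#   offset = torch.tensor(0)
#   values, new_offset = torch.ops.rngprims.philox_rand(
#       [4, 4], seed, offset, None, torch.device("cuda"), torch.float32
#   )
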
def get_device(args, kwargs):
if kwargs.get("device"):
device = kwargs.get("device")
if isinstance(device, str):
device = torch.device(device)
return device.type
devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
if any(dev == "cuda" for dev in devices):
return "cuda"
elif any(dev == "xpu" for dev in devices):
return "xpu"
elif any(dev == "hpu" for dev in devices):
return "hpu"
elif any(dev == "cpu" for dev in devices):
return "cpu"
return None
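
# Example (illustrative sketch): the device is taken from an explicit kwarg when
# present, otherwise inferred from the tensor arguments.
def _example_get_device():
    assert get_device((), {"device": "cuda"}) == "cuda"
    assert get_device((torch.randn(2),), {}) == "cpu"
    assert get_device((), {}) is None
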
def register_run_and_save_rng_state_op():
class RunAndSaveRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_and_save_rng_state")
def __call__(self, op, *args, **kwargs):
return super().__call__(op, *args, **kwargs)
run_and_save_rng_state = RunAndSaveRngState()
run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
)
@run_and_save_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, **kwargs):
return torch.cuda.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(op, *args, **kwargs):
return torch.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.HPU)
def impl_hpu(op, *args, **kwargs):
if hasattr(torch, "hpu"):
return torch.hpu.get_rng_state(), op(*args, **kwargs)
raise RuntimeError("functionalize a hpu RNG operator is not supported.")
@run_and_save_rng_state.py_impl(DispatchKey.XPU)
def impl_xpu(op, *args, **kwargs):
return torch.xpu.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(op, *args, **kwargs):
impl_map = {
"cuda": impl_cuda,
"cpu": impl_cpu,
"hpu": impl_hpu,
"xpu": impl_xpu,
}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(mode, op, *args, **kwargs):
# Check device to call the right impl
with mode:
return impl_backend_select(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
out = impl_backend_select(op, *args, **kwargs)
proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
return run_and_save_rng_state
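
# Illustrative sketch of the op's contract (hedged): run_and_save_rng_state
# returns the backend RNG state captured *before* running `op`, paired with the
# op's output, so the same randomness can later be replayed via
# run_with_rng_state (see the example after both registrations below), e.g.:
#
#   state, out = run_and_save_rng_state(torch.ops.aten.rand_like.default, x)
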
def register_run_with_rng_state_op():
class RunWithRngState(HigherOrderOperator):
def __init__(self):
super().__init__("run_with_rng_state")
def __call__(self, rng_state, op, *args, **kwargs):
return super().__call__(rng_state, op, *args, **kwargs)
run_with_rng_state = RunWithRngState()
run_with_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_with_rng_state, deferred_error=True)
)
@run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(rng_state, op, *args, **kwargs):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
out = op(*args, **kwargs)
torch.cuda.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(rng_state, op, *args, **kwargs):
current_state = torch.get_rng_state()
torch.set_rng_state(rng_state)
out = op(*args, **kwargs)
torch.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(DispatchKey.HPU)
def impl_hpu(rng_state, op, *args, **kwargs):
if hasattr(torch, "hpu"):
current_state = torch.hpu.get_rng_state()
torch.hpu.set_rng_state(rng_state)
out = op(*args, **kwargs)
torch.hpu.set_rng_state(current_state)
return out
raise RuntimeError("functionalize a hpu RNG operator is not supported.")
@run_with_rng_state.py_impl(DispatchKey.XPU)
def impl_xpu(rng_state, op, *args, **kwargs):
current_state = torch.xpu.get_rng_state()
torch.xpu.set_rng_state(rng_state)
out = op(*args, **kwargs)
torch.xpu.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
        # TODO: this is unnecessary; by the time we dispatch here, proxy-mode
        # tracing has already been disabled.
with disable_proxy_modes_tracing():
out = run_with_rng_state(rng_state, op, *args, **kwargs)
proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_with_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
@run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(rng_state, op, *args, **kwargs):
impl_map = {
"cuda": impl_cuda,
"cpu": impl_cpu,
"hpu": impl_hpu,
"xpu": impl_xpu,
}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip setting the RNG state, as it does not work well with fake tensors
        # (and it does not matter under fake tensor mode anyway).
with mode:
return op(*args, **kwargs)
@run_with_rng_state.py_functionalize_impl
def impl_functional(ctx, rng_state, op, *args, **kwargs):
unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
unwrapped_args = ctx.unwrap_tensors(args)
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
out = run_with_rng_state(
unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
)
return ctx.wrap_tensors(out)
return run_with_rng_state

run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
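
# Example (illustrative sketch, requires CUDA): replaying an RNG op with a saved
# state reproduces its random output, which is what lets saved randomness be
# re-run in a later graph.
def _example_rng_state_replay():
    x = torch.randn(8, device="cuda")
    state, a = run_and_save_rng_state(torch.ops.aten.rand_like.default, x)
    b = run_with_rng_state(state, torch.ops.aten.rand_like.default, x)
    assert torch.equal(a, b)
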
def register_rng_prims():
register_philox_rand()