I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
import torch
from torch._subclasses.fake_tensor import (
DynamicOutputShapeException,
FakeTensor,
FakeTensorMode,
UnsupportedFakeTensorException,
)
from torch._subclasses.fake_utils import CrossRefFakeMode
__all__ = [
"FakeTensor",
"FakeTensorMode",
"UnsupportedFakeTensorException",
"DynamicOutputShapeException",
"CrossRefFakeMode",
]
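# A minimal usage sketch (not part of this diff): FakeTensorMode, re-exported
# above, runs ops on "fake" tensors that carry metadata (shape, dtype, device)
# but no real storage. The shapes below are illustrative assumptions.
def _fake_tensor_mode_example() -> None:
    with FakeTensorMode():
        x = torch.empty(4, 8)                # created inside the mode -> FakeTensor
        y = torch.nn.functional.relu(x)      # runs against meta kernels only
        assert isinstance(y, FakeTensor)
        assert y.shape == (4, 8) and y.device == x.device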

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
import torch
from torch import SymInt
from torch.fx.experimental.sym_node import SymNode
from torch.types import py_sym_types, PySymType
from torch.utils._backport_slots import dataclass_slots
if TYPE_CHECKING:
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from .fake_tensor import _DispatchCacheKey, _MetadataIntLike
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymNode:
"""
Represents a SymNode without the associated ShapeEnv
"""
# n.b. keep the same protocol as SymNode
_expr: sympy.Expr
pytype: type
_hint: Optional[Union[int, float, bool]]
constant: Optional[Union[int, float, bool]]
fx_node: torch.fx.Node
@staticmethod
def from_node(node: SymNode) -> _DeconstructedSymNode:
return _DeconstructedSymNode(
node._expr, node.pytype, node._hint, node.constant, node.fx_node
)
def extract(self, shape_env: ShapeEnv) -> SymNode:
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
def __str__(self) -> str:
return str(self._expr)
def __repr__(self) -> str:
return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
# _value_eq to match SymNode
def _value_eq(self, other: object) -> bool:
if isinstance(other, (SymNode, _DeconstructedSymNode)):
return (
self._expr == other._expr
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
)
else:
return False
# _value_hash to match SymNode
def _value_hash(self) -> int:
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
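# Hedged sketch (not part of this diff): the intended round trip - snapshot a
# SymNode without its ShapeEnv, then re-attach it to a ShapeEnv later. The
# unbacked SymInt below is an illustrative assumption.
def _deconstructed_sym_node_example() -> None:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    s = shape_env.create_unbacked_symint()            # SymInt backed by shape_env
    stored = _DeconstructedSymNode.from_node(s.node)  # ShapeEnv-free snapshot
    rebuilt = stored.extract(shape_env)               # re-attached SymNode
    assert stored._value_eq(rebuilt)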
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymType:
"""
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
"""
ty: Type[PySymType]
node: _DeconstructedSymNode
@staticmethod
def from_sym_type(value: PySymType) -> _DeconstructedSymType:
return _DeconstructedSymType(type(value), value.node)
def extract(self, shape_env: ShapeEnv) -> PySymType:
return self.ty(self.node.extract(shape_env))
def __str__(self) -> str:
return f"{self.ty}({self.node})"
def __repr__(self) -> str:
return f"_DeconstructedSymType({self.ty}, {self.node!r})"
def __eq__(self, other: object) -> bool:
return NotImplemented
def __hash__(self) -> int:
return NotImplemented
@dataclass_slots
@dataclass(frozen=True)
class _InputBackref:
value: int
@dataclass_slots
@dataclass
class _PySymInputStub:
"""
Represents a SymInt in the cached key. Needed because SymInt doesn't
support __eq__ or __hash__ directly.
"""
# value can be:
# PySymType: This is the 'normal' SymInt value, wrapped so we can use
# hash/eq as value hash/eq (normally SymInt does object
# hash/eq).
# _DeconstructedSymType: This is used when storing the _PySymInputStub in
# the cache to avoid cyclic ShapeEnv references.
# _InputBackref: This is a back-reference to a previous _PySymInputStub in
# the key.
value: Union[PySymType, _DeconstructedSymType, _InputBackref]
def __init__(
self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
) -> None:
# For inputs (values in the `key`) we need to keep the PySymType intact
# - this way if we need to reuse it as an output we can properly copy
# the original value.
self.value = value
def strip_shape_env(self) -> None:
if isinstance(self.value, py_sym_types):
self.value = _DeconstructedSymType.from_sym_type(self.value)
def extract(self, shape_env: ShapeEnv) -> PySymType:
if isinstance(self.value, _DeconstructedSymType):
return self.value.extract(shape_env)
else:
# We should never see an _InputBackref here - anyone extracting a
# value should be pulling from the original entry (the one this
# backref points at).
assert not isinstance(self.value, _InputBackref)
return self.value
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return f"_PySymInputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, _PySymInputStub):
return False
elif isinstance(self.value, _InputBackref) or isinstance(
other.value, _InputBackref
):
return self.value == other.value
else:
return self.value.node._value_eq(other.value.node)
def __hash__(self) -> int:
if isinstance(self.value, _InputBackref):
return hash(self.value)
else:
return self.value.node._value_hash()
@dataclass_slots
@dataclass
class _SymIntOutputStub:
"""
Represents a SymInt in the cached output.
"""
    # This is either an `int` representing the index in the key to copy the
    # SymNode from, or the deconstructed SymNode itself.
value: Union[int, _DeconstructedSymNode]
def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
if key_path is None:
self.value = _DeconstructedSymNode.from_node(value.node)
else:
self.value = key_path
def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
if isinstance(self.value, _DeconstructedSymNode):
return SymInt(self.value.extract(shape_env))
else:
src = key.key[self.value]
assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
return src.value
def __repr__(self) -> str:
return f"_SymIntOutputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
@dataclass_slots
@dataclass
class _CacheKeyState:
"""
State used while building our cache key.
"""
    # We track the SymNodes so that when we get the output we can check whether
    # it exactly matches one of the inputs and uncache it properly.
sym_node_lookup: Dict[int, int] # id(SymNode) -> index
    # There are cases where we're asked to perform an op when we have no
    # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
    # ShapeEnv. So, as we scan, if we see a SymNode (with a ShapeEnv) we record
    # it here.
shape_env: Optional[ShapeEnv]
def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
self.sym_node_lookup = {}
self.shape_env = shape_env
def cache_on_shape_env(self) -> bool:
"""
Returns true if the CacheKey needs to be cached on the ShapeEnv
rather than the global cache.
If our inputs contain a SymNode then we can't cache this operation on
the global cache because the cached output will implicitly depend on
guard values which might not be true on some other ShapeEnv. So unless
we're also going to cache the guards we need to cache this operation on
the ShapeEnv instead of globally.
"""
return bool(self.sym_node_lookup)
def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
node_id = id(arg.node)
if node_id in self.sym_node_lookup:
result.append(_InputBackref(self.sym_node_lookup[node_id]))
else:
self.sym_node_lookup[node_id] = len(result)
if self.shape_env is None:
self.shape_env = arg.node.shape_env
result.append(_PySymInputStub(arg))
def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
if isinstance(arg, SymInt):
return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
else:
return arg
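# Hedged sketch (not part of this diff): convert_sym_int stores the first
# occurrence of a SymNode as a _PySymInputStub and later occurrences as
# _InputBackref entries pointing back at it. The unbacked SymInt is an
# illustrative assumption.
def _cache_key_state_example() -> None:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    s = shape_env.create_unbacked_symint()
    state = _CacheKeyState(shape_env)
    key: List[object] = []
    state.convert_sym_int(key, s)          # first sighting -> stub at index 0
    state.convert_sym_int(key, s)          # repeat -> backref to index 0
    assert isinstance(key[0], _PySymInputStub)
    assert key[1] == _InputBackref(0)
    assert state.cache_on_shape_env()      # SymNode input -> per-ShapeEnv cache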

View File

@@ -0,0 +1,959 @@
# mypy: ignore-errors
import functools
import itertools
import math
import sys
from typing import Callable, Union
import torch
import torch._custom_op
import torch._logging
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_float_dtype,
is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
DataDependentOutputException,
DynamicOutputShapeException,
FakeTensor,
in_kernel_invocation_manager,
run_fallback_kernel,
UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label
pytree = torch.utils._pytree
__all__ = [
"op_implementations_checks",
"get_fast_op_impls",
"stride_incorrect_op",
"has_meta",
]
op_implementations_dict = {}
op_implementations_checks = []
aten = torch._ops.ops.aten
def ordered_set(*items):
return dict.fromkeys(items, True)
# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
return device.type != "hpu"
_like_tensor_constructors = ordered_set(
aten.empty_like.default,
aten.empty_like.out,
aten.full_like.default,
aten.full_like.out,
aten.ones_like.default,
aten.ones_like.out,
aten.rand_like.default,
aten.rand_like.out,
aten.randn_like.default,
aten.randn_like.out,
aten.randint_like.default,
aten.randint_like.out,
aten.randint_like.low_dtype,
aten.randint_like.low_dtype_out,
aten.zeros_like.default,
aten.zeros_like.out,
aten.new_empty.default,
aten.new_empty.out,
aten.new_empty_strided.default,
aten.new_empty_strided.out,
aten.new_full.default,
aten.new_full.out,
aten.new_zeros.default,
aten.new_zeros.out,
aten.new_ones.default,
aten.new_ones.out,
)
_device_not_kwarg_ops = ordered_set(
aten._resize_output_.default,
aten._nested_tensor_from_tensor_list.default,
aten._nested_tensor_from_tensor_list.out,
aten.pin_memory.default,
aten.to.device,
aten.to.prim_Device,
aten.is_pinned.default,
aten._pin_memory.default,
aten._pin_memory.out,
aten._resize_output.default,
aten._resize_output.out,
)
# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)
def contains_tensor_types(type):
tensor_type = torch._C.TensorType.get()
return type.isSubtypeOf(tensor_type) or any(
contains_tensor_types(e) for e in type.containedTypes()
)
@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
assert isinstance(func, OpOverload)
schema = func._schema
if any(contains_tensor_types(arg.type) for arg in schema.arguments):
return False
# TODO: no real reason to restrict multiple outputs
return (
len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
)
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def impl_decorator(op_impl):
if isinstance(run_impl_check, OpOverload):
assert (
run_impl_check not in op_implementations_dict
), f"duplicate registration: {run_impl_check}"
op_implementations_dict[run_impl_check] = op_impl
elif isinstance(run_impl_check, (list, tuple)):
for op in run_impl_check:
register_op_impl(op)(op_impl)
else:
assert callable(run_impl_check)
op_implementations_checks.append((run_impl_check, op_impl))
return op_impl
return impl_decorator
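# Hedged usage sketch (not part of this diff): register_op_impl accepts either a
# concrete OpOverload or a predicate over OpOverloads. The namespace string and
# the helper below are hypothetical, and the predicate never matches, so calling
# this function would not change real dispatch behavior.
def _register_op_impl_example() -> None:
    @register_op_impl(lambda func: func.namespace == "my_hypothetical_namespace")
    def _example_impl(fake_mode, func, *args, **kwargs):
        # A matching op would be handled here; nothing matches in this sketch.
        raise UnsupportedOperatorException(func)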
@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
assert func not in _non_kwarg_device_constructors
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
if "names" in kwargs:
raise UnsupportedOperatorException(
"torch.compile doesn't support named tensors"
)
if func in _like_tensor_constructors:
default_device = new_kwargs["input"].device
# TODO: file issue
args = (new_kwargs.pop("input"),)
else:
# cpu is default device if none is specified
default_device = torch.device("cpu")
args = ()
out_device = new_kwargs.pop("device", None)
out_device = out_device if out_device is not None else default_device
new_kwargs["device"] = torch.device("meta")
# _like constructors have fake tensor inputs (maybe this causes the non-like
# to fail? hmmm)
with in_kernel_invocation_manager(fake_mode):
r = func(*args, **new_kwargs)
return FakeTensor(fake_mode, r, out_device)
@register_op_impl(aten.is_pinned.default)
def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args, kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
# we'll ignore device argument because it is deprecated and not
# actually used by is_pinned.
with in_kernel_invocation_manager(fake_mode):
r = func(inp)
return r
@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args, kwargs, normalize_to_only_use_kwargs=True
)
input_device = new_kwargs["device"]
out_device = input_device if input_device else new_kwargs["input"].device
new_kwargs["device"] = torch.device("meta")
inp = new_kwargs.pop("input")
with in_kernel_invocation_manager(fake_mode):
r = func(inp, **new_kwargs)
# TODO: I think this does the wrong thing if r is inp
return fake_mode.fake_tensor_converter.from_meta_and_device(
fake_mode, r, out_device
)
def stride_incorrect_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
return False
op_name = op.name()
if "fft" in op_name:
return True
return False
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides
def is_symbolic(x):
if isinstance(x, FakeTensor):
return x._has_symbolic_sizes_strides
if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
return True
return False
# For static shapes, we can fall back to eager for the real strides
if fake_mode.allow_fallback_kernels:
require_dynamic = any(
is_symbolic(x) for x in itertools.chain(args, kwargs.values())
)
if not require_dynamic:
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)
raise UnsupportedOperatorException(func)
# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
with in_kernel_invocation_manager(fake_mode):
return func(*args, **kwargs)
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
# TODO: remove me
return constructors(fake_mode, func, *args, **kwargs)
# index.Tensor is data-dependent in only some conditions
@register_op_impl(
lambda func: torch.Tag.dynamic_output_shape in func.tags
and func
not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
raise DynamicOutputShapeException(func)
def _unique(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
# Do not use a memo for unique_dim
if dim is not None or (nnz := arg.unique_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(arg.numel()) and arg.numel() == 0:
# If numel is zero, then the output size must be zero.
# In this case, we must not allocate an unbacked SymInt,
# because if we do, it will immediately get refined to
# zero, but this will be inconsistent with size oblivious
# tests (which will continue to claim that the unbacked
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
maxval = sys.maxsize - 1
numel = arg.numel() if dim is None else arg.size(dim)
if not has_free_symbols(numel):
maxval = int(numel)
_constrain_range_for_size(nnz, max=maxval)
if dim is None:
arg.unique_memo = nnz
if dim is None:
ret = [arg.new_empty((nnz,))]
else:
ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
if return_inverse or return_if_dim_and_cpu:
inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
else:
inverse = arg.new_empty(0)
ret.append(inverse)
if return_counts or return_if_dim_and_cpu:
counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
else:
counts = arg.new_empty(0)
ret.append(counts)
return tuple(ret)
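# Hedged sketch (not part of this diff): with a ShapeEnv attached (and assuming
# its defaults allow dynamic output shape ops), data-dependent ops such as
# unique get an unbacked SymInt output size instead of raising
# DynamicOutputShapeException. The input shape is illustrative.
def _unique_under_fake_mode_example() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    with FakeTensorMode(shape_env=ShapeEnv()):
        x = torch.randint(0, 10, (32,))
        out = torch.unique(x)
        assert isinstance(out.shape[0], torch.SymInt)  # unbacked output size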
@register_op_impl(aten._unique2.default)
def unique2(
fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
@register_op_impl(aten.unique_dim.default)
def unique_dim(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
return _unique(
fake_mode,
func,
arg,
# normalize dim to be non-negative
dim if dim >= 0 else dim % max(arg.ndim, 1),
sorted,
return_inverse,
return_counts,
)
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
if output_size is None:
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
raise DynamicOutputShapeException(func)
output_size = fake_mode.shape_env.create_unbacked_symint()
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
_constrain_range_for_size(output_size)
# TODO: consider a memo
return repeats.new_empty(output_size)
@register_op_impl(torch.ops.aten.item.default)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
if (r := arg.item_memo) is not None:
return r
if fake_mode.shape_env is None or (
not fake_mode.shape_env.allow_scalar_outputs
and not fake_mode.allow_scalar_outputs
):
# Without symints/symfloats, cannot handle this
raise DataDependentOutputException(func)
if is_float_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symfloat()
elif is_integer_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symint()
elif is_boolean_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symbool()
else:
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
arg.item_memo = r
return r
@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
if (nnz := arg.nonzero_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(arg.numel()) and arg.numel() == 0:
# If numel is zero, then the output size must be zero.
# In this case, we must not allocate an unbacked SymInt,
# because if we do, it will immediately get refined to
# zero, but this will be inconsistent with size oblivious
# tests (which will continue to claim that the unbacked
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
maxval = sys.maxsize - 1
if not has_free_symbols(arg.numel()):
maxval = int(arg.numel())
_constrain_range_for_size(nnz, max=maxval)
arg.nonzero_memo = nnz
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
nnz = fake_mode.shape_env.create_unbacked_symint()
# see nonzero for commentary
maxval = sys.maxsize - 1
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
from torch.utils._sympy.numbers import IntInfinity
from torch.utils._sympy.value_ranges import bound_sympy
# If num elements is expressed symbolically, calculate
# the concrete value based on upper bounds. Otherwise,
# we can set max val directly.
if not has_free_symbols(self.numel()):
num_elements = int(self.numel())
else:
prod_node = math.prod(self.shape).node
prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
if isinstance(prod_range.upper, IntInfinity):
num_elements = sys.maxsize - 1
else:
num_elements = prod_range.upper
if num_elements > 2:
maxval = num_elements
_constrain_range_for_size(nnz, max=maxval)
return self.new_empty((nnz,))
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
raise DataDependentOutputException(func)
# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
for index in indices:
if index is not None and index.dtype in (torch.bool, torch.uint8):
raise DynamicOutputShapeException(func)
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
out_device = new_kwargs["input"].device
with in_kernel_invocation_manager(fake_mode):
out = func(*args, **kwargs)
if not is_noncontiguous_supported(out_device):
out = out.new_empty(out.shape)
if out is new_kwargs["input"]:
return out # copy_
return FakeTensor(fake_mode, out, out_device)
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")
def is_builtin(op):
return op.namespace in _is_builtin_namespaces
def has_meta(func):
return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
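# Hedged sketch (not part of this diff): is_builtin checks the namespace and
# has_meta checks for a computed Meta kernel. aten.add.Tensor is used purely as
# an op that is expected to satisfy both.
def _has_meta_example() -> None:
    op = torch.ops.aten.add.Tensor
    assert is_builtin(op)
    assert has_meta(op)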
@register_op_impl(
lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
tensor_lists = []
for arg in itertools.chain(args, kwargs.values()):
if (
isinstance(arg, (list, tuple))
and len(arg)
and isinstance(arg[0], torch.Tensor)
):
tensor_lists.append(arg)
try:
with in_kernel_invocation_manager(fake_mode):
out_meta = func(*args, **kwargs)
except NotImplementedError as not_implemented_error:
return NotImplemented
if not out_meta:
return out_meta
assert tensor_lists
out_fake = []
for i, meta_t in enumerate(out_meta):
device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
out_fake.append(
fake_mode.fake_tensor_converter.from_meta_and_device(
fake_mode, meta_t, device
)
)
return out_fake
# Don't default to default device handling,
# since the op can take in non-zero sized cpu
# index tensors with a cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
from torch._meta_registrations import meta_index_Tensor
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
out_device = new_kwargs["input"].device
# ensure nonzero call goes to fake tensor
with fake_mode:
out = meta_index_Tensor(*args, **kwargs)
return out.to(out_device)
# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
from torch._meta_registrations import meta_embedding_bag
with fake_mode:
return meta_embedding_bag(*args, **kwargs)
# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
with in_kernel_invocation_manager(fake_mode):
out = func(*args, **kwargs)
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
return new_kwargs["input"]
@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
values = new_kwargs["values"]
self_device = new_kwargs["input"].fake_device
torch._check(
self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
)
out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
if func is aten.index_put_.default:
return new_kwargs["input"]
else:
return out
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
raise UnsupportedOperatorException(
"torch.compile does not support strided NestedTensor"
)
@register_op_impl(
[
x
for x in _device_not_kwarg_ops
if x
not in (
# these are already registered elsewhere
aten.is_pinned.default,
aten.to.device,
aten.to.prim_Device,
aten._nested_tensor_from_tensor_list.default,
aten._nested_tensor_from_tensor_list.out,
)
]
)
def nyi(fake_mode, func, *args, **kwargs):
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
device = kwargs["input"].fake_device
# need to re-enable mode so the tensors report fake device
with fake_mode:
        # if the input unsqueezing is done in Convolution.cpp we get a segfault
k = kwargs["weight"].ndim
batch = kwargs["input"].shape[0]
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import has_hint
if not has_hint(batch):
# TODO: We can make this a little more faithful with best effort
# channels last detection (but only if it's statically obvious!)
mem_fmt = None
elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
mem_fmt = None
else:
if func is aten.convolution.default:
conv_backend = torch._C._select_conv_backend(**kwargs)
else:
conv_backend = torch._C._select_conv_backend(
kwargs["input"],
kwargs["weight"],
bias=None,
stride=kwargs["stride"],
padding=kwargs["padding"],
dilation=kwargs["dilation"],
transposed=kwargs["transposed"],
output_padding=kwargs["output_padding"],
groups=kwargs["groups"],
bias_sizes=kwargs["bias_sizes"],
)
mem_fmt = torch._C._conv_determine_backend_memory_format(
kwargs["input"], kwargs["weight"], conv_backend
)
def convert(t, mem_fmt):
if t is None:
return t
if mem_fmt is not None:
t = t.to(memory_format=mem_fmt)
return FakeTensor(fake_mode, t, device)
with in_kernel_invocation_manager(fake_mode):
out = func(**kwargs)
if func is aten.convolution.default:
return convert(out, mem_fmt)
else:
return (
convert(out[0], mem_fmt),
convert(out[1], mem_fmt),
convert(out[2], None),
)
@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
new_batch_size = fake_mode.shape_env.create_unbacked_symint()
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
_constrain_range_for_size(new_batch_size)
if not batch_first:
# Inputs should have shape (batch_size, seq_len, *)
inputs = inputs.transpose(0, 1)
res_size = inputs.shape[1:]
packed_data = inputs.new_empty(res_size)
batch_size = inputs.new_empty((new_batch_size,))
return (packed_data, batch_size)
FAST_OP_IMPLEMENTATIONS = {}
# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
def impl_decorator(op_impl):
FAST_OP_IMPLEMENTATIONS[func] = op_impl
return op_impl
return impl_decorator
# infer_size_impl in ExpandUtils
def infer_size(a, b):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
dimsA = len(a)
dimsB = len(b)
ndim = max(dimsA, dimsB)
expandedSizes = [0] * ndim
for i in range(ndim - 1, -1, -1):
offset = ndim - 1 - i
dimA = dimsA - 1 - offset
dimB = dimsB - 1 - offset
sizeA = a[dimA] if dimA >= 0 else 1
sizeB = b[dimB] if dimB >= 0 else 1
# NB: It is very important to test for broadcasting, before testing
# sizeA == sizeB. This is because the broadcasting tests are likely
# to be statically known (in particular, if sizeA/sizeB is unbacked
# but size-like, we will unsoundly assume they never equal 1), but
# the sizeA == sizeB test may not be statically known. However, once
# we have established that no broadcasting is happening, the
# sizeA == sizeB is now expect_true and we can defer it as a runtime
# assert (this works because Python will return the terminal
# expression of an or statement as-is, without bool()'ing it; if this
# were not the case, we'd need to write this using torch.sym_or() or
# something like that).
torch._check(
guard_size_oblivious(sizeA == 1)
or guard_size_oblivious(sizeB == 1)
or sizeA == sizeB,
lambda: f"The size of tensor a ({sizeA}) "
f"must match the size of tensor b ({sizeB}) "
f"at non-singleton dimension {i})",
)
expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
return tuple(expandedSizes)
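# Hedged sketch (not part of this diff): infer_size follows the usual
# broadcasting rules; the concrete shapes below are illustrative.
def _infer_size_example() -> None:
    assert infer_size((3, 1), (4,)) == (3, 4)
    assert infer_size((2, 1, 5), (1, 4, 5)) == (2, 4, 5)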
def make_fast_binary_impl(slow_ref):
def fast_binary_impl(mode, *args, **kwargs):
def slow(msg):
count_label(f"slow {msg}")
with mode:
return slow_ref(*args, **kwargs)
count_label("attempt fast")
# Fast path (based off of TensorIterator fast path).
# Unfortunately, there is no way to easily deduplicate
# this with either the TensorIterator C++ implementation
# (which we don't want to SymIntify, and also the algorithm
# here is slightly different from TensorIterator to allow
# for broadcasting), nor the PrimTorch implementation
# (which does not actually implement a fast path.)
operands = args
# compute_shape
has_scalars = False
has_tensors = False
final_shape = None
for op in operands:
shape = op.shape if isinstance(op, torch.Tensor) else ()
if len(shape) == 0:
has_scalars = True
else:
has_tensors = True
if final_shape is None:
final_shape = shape
# TODO: Minor optimization: track if the shapes
# were equal so you can skip the equality check
# below if unnecessary
final_shape = infer_size(final_shape, shape)
assert final_shape is not None
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
# Do some extra safety checks to see if the output
# stride is obvious
for op in operands:
if (
isinstance(op, torch.Tensor)
and len(op.shape) == len(final_shape)
and guard_size_oblivious(sym_eq(op.shape, final_shape))
):
break
else:
return slow("both tensors nontrivially broadcast")
# compute_types
cpu = torch.device("cpu")
common_device = cpu
common_dtype = None
output_dtype = None
has_different_input_dtypes = False
for op in operands:
if not isinstance(op, torch.Tensor):
# Use elementwise_dtypes for the tricky case
has_different_input_dtypes = True
continue
if common_device == cpu and not op.device.type == "cpu":
common_device = op.device
# Slightly simplified here as target_dtype cannot vary
if common_dtype is None:
common_dtype = op.dtype
elif common_dtype != op.dtype:
has_different_input_dtypes = True
if has_different_input_dtypes:
# compute promotion
# TODO: we don't need the compute type
_, common_dtype = elementwise_dtypes(
*operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
        # check that all tensors are on the same device
        # cpu scalars are assumed to be allowed
current_cpu_scalars_on_non_cpu = 0
max_cpu_scalars_on_non_cpu = 1 # hard coded atm
for op in operands:
if not isinstance(op, torch.Tensor):
continue
if common_device != cpu and op.dim() == 0 and op.device == cpu:
if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
return slow("error")
current_cpu_scalars_on_non_cpu += 1
elif op.device != common_device:
return slow("error")
# compute_fast_setup_type
is_contiguous = True
is_channels_last = True
        # TODO: is_non_overlapping_and_dense (not bound from Python)
# no inplace, no out, everything defined
if is_noncontiguous_supported(common_device):
for op in operands:
if not isinstance(op, torch.Tensor):
continue
is_contiguous = is_contiguous and op.is_contiguous(
memory_format=torch.contiguous_format
)
is_channels_last = is_channels_last and op.is_contiguous(
memory_format=torch.channels_last
)
if is_contiguous:
# do contiguous
count_label("fast is_contiguous")
return FakeTensor(
mode,
torch.empty(
final_shape,
dtype=common_dtype,
device="meta",
memory_format=torch.contiguous_format,
),
device=common_device,
)
if is_channels_last:
count_label("fast channels_last")
# do channels last
return FakeTensor(
mode,
torch.empty(
final_shape,
dtype=common_dtype,
device="meta",
memory_format=torch.channels_last,
),
device=common_device,
)
return slow("no contiguity match")
return fast_binary_impl
# disable the python dispatcher to avoid decomposing detach() further
# (proxy_mode should still decompose detach() though)
def fast_detach(fake_mode, x):
with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
out = torch.ops.aten.detach.default(x)
return FakeTensor(fake_mode, out, x.device)
@functools.lru_cache(None)
def get_fast_op_impls():
import torch._refs
register_fast_op_impl(torch.ops.aten.add.Tensor)(
make_fast_binary_impl(torch._refs.add)
)
register_fast_op_impl(torch.ops.aten.sub.Tensor)(
make_fast_binary_impl(torch._refs.sub)
)
register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
register_fast_op_impl(torch.ops.aten.div.Tensor)(
make_fast_binary_impl(torch._refs.div)
)
register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
return FAST_OP_IMPLEMENTATIONS
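# Hedged sketch (not part of this diff): entries in the fast table can be looked
# up by OpOverload and called with a FakeTensorMode plus fake inputs. The shapes
# below are illustrative.
def _fast_add_example() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode

    fast_add = get_fast_op_impls()[torch.ops.aten.add.Tensor]
    with FakeTensorMode() as mode:
        x = torch.empty(2, 3)
        y = torch.empty(2, 3)
    out = fast_add(mode, x, y)
    assert isinstance(out, FakeTensor) and out.shape == (2, 3)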

File diff suppressed because it is too large

View File

@@ -0,0 +1,197 @@
# mypy: ignore-errors
import functools
import warnings
from typing import Callable, Union
import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
FakeTensorMode,
tree_flatten_only,
UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch._ops.ops.aten
def outputs_alias_inputs(outputs, inputs):
input_storages = {
inp._typed_storage()._cdata
for inp in tree_flatten_only(torch.Tensor, inputs)
if torch._C._has_storage(inp)
}
return any(
torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
for out in tree_flatten_only(torch.Tensor, outputs)
)
def outputs_are_inputs(outputs, inputs):
input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))
def output_alias_each_other(outputs):
storages = set()
for out in tree_flatten_only(torch.Tensor, outputs):
if not torch._C._has_storage(out):
continue
stor = out._typed_storage()._cdata
if stor in storages:
return True
storages.add(stor)
return False
def is_sdpa_error(func, idx, e):
if (
(
func is aten._scaled_dot_product_flash_attention.default
or func is aten._flash_attention_forward.default
)
and idx in (6, 7)
and "Devices" in repr(e)
):
return True
if (
(
func is aten._scaled_dot_product_efficient_attention.default
or func is aten._efficient_attention_forward.default
)
and idx in (2, 3)
and "Devices" in repr(e)
):
return True
if (
func is aten._scaled_dot_product_cudnn_attention.default
and idx in (6, 7)
and "Devices" in repr(e)
):
return True
return False
class CrossRefFakeMode(TorchDispatchMode):
def __init__(
self,
ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
*,
check_strides=True,
check_aliasing=True,
):
super().__init__()
self.ignore_op_fn = (
ignore_op_fn if ignore_op_fn is not None else lambda fn: False
)
self.check_strides = check_strides
self.check_aliasing = check_aliasing
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
fake_r = None
        # empty_like is excluded for now due to sparse complex
        # aten._to_dense.default: this one is getting called with csc
if (
func
not in (
aten.lift_fresh.default,
aten.lift_fresh_copy.default,
aten.set_.source_Storage_storage_offset,
)
and not self.ignore_op_fn(func)
and torch.Tag.dynamic_output_shape not in func.tags
and torch.Tag.inplace_view not in func.tags
and torch.Tag.data_dependent_output not in func.tags
):
# Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
from torch.fx.experimental.symbolic_shapes import ShapeEnv
try:
# TODO: enable_python_dispatcher() here
with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
fake_args, fake_kwargs = pytree.tree_map_only(
torch.Tensor,
functools.partial(fake_mode.from_tensor, static_shapes=True),
(args, kwargs),
)
with warnings.catch_warnings():
fake_r = func(*fake_args, **fake_kwargs)
except UnsupportedFakeTensorException:
pass
context = (
f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
f"found"
)
r = func(*args, **kwargs)
if fake_r is not None:
r_flat = pytree.tree_leaves(r)
f_flat = pytree.tree_leaves(fake_r)
assert len(f_flat) == len(
r_flat
), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
if self.check_aliasing:
r_aliasing = outputs_alias_inputs(r, (args, kwargs))
f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
assert (
r_aliasing == f_aliasing
), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"
r_identity_eq = outputs_are_inputs(r, (args, kwargs))
f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
assert (
r_identity_eq == f_identity_eq
), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"
r_output_alias_each_other = output_alias_each_other(r)
f_output_alias_each_other = output_alias_each_other(fake_r)
assert r_output_alias_each_other == f_output_alias_each_other, (
f"{context} mismatch in outputs_alias_each_other check "
f"{f_output_alias_each_other} != {r_output_alias_each_other}"
)
for idx, (r_out, fake_out) in enumerate(
zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
):
r_is_ten = isinstance(r_out, torch.Tensor)
assert r_is_ten == isinstance(
fake_out, torch.Tensor
), f"{context} mismatched number of tensor outputs"
if r_is_ten:
assert r_out.requires_grad == fake_out.requires_grad, (
f"{context} mismatched requires_grad-ness of outputs. "
f"This usually means that you have added autograd support "
f"for your operator at a dispatch key other than Autograd, "
f"which will lead to problems"
)
if torch._C._has_storage(r_out):
r_offset = r_out.storage_offset()
f_offset = fake_out.storage_offset()
assert (
r_offset == f_offset
), f"{context} mismatched storage offset"
try:
torch._prims.utils.compare_tensor_meta(
r_out,
fake_out,
check_strides=self.check_strides,
allow_rhs_unbacked=True,
)
except Exception as e:
if is_sdpa_error(func, idx, e):
continue
error_message = (
f"{context} mismatched tensor metadata: {e}"
if len(r_flat) == 1
else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
)
raise RuntimeError(error_message) from e
return r
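# Hedged usage sketch (not part of this diff): CrossRefFakeMode re-runs each
# dispatched op on FakeTensors and asserts that the fake outputs' metadata (and
# optionally strides/aliasing) matches the real outputs. Shapes are illustrative.
def _cross_ref_example() -> None:
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        a = torch.randn(4, 4)
        b = torch.mm(a, a)        # every aten call here is cross-checked
    assert b.shape == (4, 4)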

View File

@@ -0,0 +1,807 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
import torch
import torch._inductor.config as inductor_config
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
return_and_correct_aliasing,
TorchDispatchMode,
)
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
# NOTE Some special handling for tensor conversion during export is needed.
# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
# relationship between input and output tensors will be baked into the graph.
# For example, if we get a tensor on the cpu device and call tensor.to("cpu"),
# it will become a no-op in the graph. For a whole graph capture, this is not
# sound so we need to do something different. Instead, in export we will try to
# preserve the tensor conversion by forcing a non-semantic-breaking aten::_to_copy
# operator to be traced in the graph, and subsequently banning mutations on all
# such converted tensors.
# In addition to patching .to() method call in functionalization, we will have to
# patch other similar methods like float() and cpu(), because they intentionally
# don't fall back to the .to() method, but have the same behavior as .to() according to
# the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.Tensor.float.html
# Thus we simply force them to go through a .to() call.
def _conversion_method_template(**extra_kwargs):
def _(self, *args, **kwargs):
return self.to(*args, **{**kwargs, **extra_kwargs})
return _
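# Hedged sketch (not part of this diff): each generated conversion method simply
# forwards to .to() with a pinned kwarg, e.g. float() behaves like
# to(dtype=torch.float32). _ToRecorder is a hypothetical stand-in that records
# the forwarded kwargs instead of converting anything.
def _conversion_template_example() -> None:
    class _ToRecorder:
        float = _conversion_method_template(dtype=torch.float32)

        def to(self, *args, **kwargs):
            return kwargs

    assert _ToRecorder().float() == {"dtype": torch.float32}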
class FunctionalTensor(torch.Tensor):
"""
Functional tensors represent tensors that will remove mutations
from a program. If you perform a mutable operation on a functional tensor,
it will re-dispatch to the functional variant of that operation.
Historically, functionalization is implemented in C++ in the dispatcher.
This class is a lightweight python shim around the C++ functionalization logic.
FunctionalTensor is required to be used with a corresponding
    FunctionalTensorMode active, because it relies
on using the mode for dispatch (which can properly handle factory functions).
"""
elem: torch.Tensor
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
_mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
# Note: The reason we add these extra keys to our FunctionalTensor subclass
# is to mirror the behavior of C++ functionalization (we can choose to change this
# later, as long as it doesn't break anything).
# FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor
# to the wrapper, excluding functorch and python dispatch keys.
# Here I'm trying to re-use the keyset the functorch wrapper subclasses copy,
# except that they don't include ZeroTensor so I'm manually adding it in.
_extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add(
torch._C.DispatchKey.ZeroTensor
)
# These are all aten ops that correspond to metadata queries.
# We want FunctionalTensor to be able to handle them directly.
metadata_fns = [
torch.ops.aten.is_contiguous.default, # type: ignore[has-type]
torch.ops.aten.is_contiguous.memory_format, # type: ignore[has-type]
torch.ops.aten.is_strides_like_format.default, # type: ignore[has-type]
torch.ops.aten.is_non_overlapping_and_dense.default, # type: ignore[has-type]
torch.ops.aten.size.default, # type: ignore[has-type]
torch.ops.aten.sym_size.default, # type: ignore[has-type]
torch.ops.aten.stride.default, # type: ignore[has-type]
torch.ops.aten.sym_stride.default, # type: ignore[has-type]
torch.ops.aten.storage_offset.default, # type: ignore[has-type]
torch.ops.aten.sym_storage_offset.default, # type: ignore[has-type]
torch.ops.aten.numel.default, # type: ignore[has-type]
torch.ops.aten.sym_numel.default, # type: ignore[has-type]
torch.ops.aten.dim.default, # type: ignore[has-type]
torch.ops.prim.device.default, # type: ignore[has-type]
]
# These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
# TODO (tmanlaibaatar) make it a tag
maybe_aliasing_or_mutating_ops = [
torch.ops.aten.dropout.default, # type: ignore[has-type]
torch.ops.aten.batch_norm.default, # type: ignore[has-type]
torch.ops.aten.native_batch_norm.default, # type: ignore[has-type]
torch.ops.aten._batch_norm_impl_index.default, # type: ignore[has-type]
torch.ops.aten.cudnn_batch_norm.default, # type: ignore[has-type]
torch.ops.aten.miopen_batch_norm.default, # type: ignore[has-type]
torch.ops.aten.atleast_1d.default, # type: ignore[has-type]
torch.ops.aten.atleast_2d.default, # type: ignore[has-type]
torch.ops.aten.atleast_3d.default, # type: ignore[has-type]
torch.ops.aten.cartesian_prod.default, # type: ignore[has-type]
torch.ops.aten.conj_physical.default, # type: ignore[has-type]
torch.ops.aten.alpha_dropout.default, # type: ignore[has-type]
torch.ops.aten.feature_dropout.default, # type: ignore[has-type]
torch.ops.aten.feature_alpha_dropout.default, # type: ignore[has-type]
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
]
# Used by auto_functionalize to determine base of tensors during inference mode.
_inference_mode_base: Optional["FunctionalTensor"] = None
def __new__(cls, elem, mode):
assert torch._is_functional_tensor(elem)
# In general, we'd like our functional tensor subclass to only be in charge of functionalization,
# and defer to the inner subclass for all other functionality.
# Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback
# until after we redispatch to our inner ZeroTensor.
# However, there are a few keys that we need to mirror between the inner and outer tensors.
# Conjugate
# Negative
# Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`.
# We **need** calls to is_conj() to return the same thing on the outer and inner tensors,
# Because user code / framework code that branches like so needs to do the same thing
# when it sees the outer FunctionalTensor:
# if (x.is_conj()) {
# return at::view_as_real(x.resolve_conj());
# } else {
# return at::view_as_real(x);
# }
extra_dispatch_keys = (
FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
)
out = torch.Tensor._make_wrapper_subclass( # type: ignore[arg-type, attr-defined]
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
# Calling the overload that has kwargs causes us to go down the first overload path,
# which will **always** specialize sizes.
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
cls,
elem.shape, # sizes
elem.stride() if not is_sparse_any(elem) else None, # strides
(
elem.storage_offset() if not is_sparse_any(elem) else None
), # storage_offset
None, # memory_format
elem.dtype, # dtype
elem.layout, # layout
elem.device, # device
False, # pin_memory
elem.requires_grad, # requires_grad
None, # dispatch_sizes_strides_policy
False, # dispatch_device
False, # dispatch_layout
extra_dispatch_keys, # _extra_dispatch_keys
)
torch._C._set_throw_on_mutable_data_ptr(out)
out.elem = elem
if (
torch.is_inference_mode_enabled()
and torch._inductor.config.enable_auto_functionalized_v2
):
if out.is_base_tensor():
out._inference_mode_base = None
# This assumes that the FunctionalTensor.elem does not change its storage after this point.
# Otherwise this would be invalid.
mode._storage_to_base[out.elem.untyped_storage()] = out
else:
out._inference_mode_base = mode._storage_to_base[
out.elem.untyped_storage()
]
assert out._inference_mode_base is not None
return out
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
unrecognized_types = [
t
for t in types
if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
]
if unrecognized_types:
not_implemented_log.debug(
"FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
)
return NotImplemented
if kwargs is None:
kwargs = {}
# FunctionalTensor needs to plumb all metadata requests to the inner tensor.
# In theory we don't have to do this - but if we want to service metadata requests here,
# we need to carefully make sure all metadata is accurate (including metadata mutations)
if func in FunctionalTensor.metadata_fns:
# All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry
# about the problem of keeping metadata in sync between the wrapper and inner tensor.
# This also alleviates us from having to manually handle metadata mutations on the wrapper.
assert len(kwargs) == 0
if func in [
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_contiguous.memory_format,
]:
assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
return func(torch._from_functional_tensor(args[0].elem), args[1])
assert len(args) == 1 and isinstance(args[0], FunctionalTensor)
return func(torch._from_functional_tensor(args[0].elem))
# Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
# - _make_wrapper_subclass requires a __torch_dispatch__
# - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
# which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
# - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
# which causes every subclass created above autograd to have autograd view metadata
# (in addition to also being a FunctionalTensorWrapper).
raise RuntimeError(
"Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
)
def __repr__(self):
return f"FunctionalTensor({repr(self.elem)})"
@staticmethod
def to_functional(x):
# We will do the wrapping for the user.
assert not torch._is_functional_tensor(x)
# The only autograd metadata we care about on the FunctionalTensor is:
# - requires_grad (so autograd runs)
# - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
# this is handled by FunctionalTensor.to_functional
x_functional = torch._to_functional_tensor(x)
        # Technically the FunctionalTensorMode here is unnecessary,
# but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing.
# _mirror_autograd_meta_to queries tensor sizes,
# and otherwise the sym_size() call will go to the proxy mode before hitting
# FunctionalTensor.__torch_dispatch__
functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
assert functional_mode is not None
with functional_mode:
torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined]
out = FunctionalTensor(x_functional, functional_mode)
torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined]
return out
def from_functional(self):
torch._sync(self)
return torch._from_functional_tensor(self.elem)
def is_base_tensor(self) -> bool:
return torch._is_functional_tensor_base(self.elem)
def replace_(self, output) -> None:
torch._functionalize_replace(self.elem, output)
def commit_update(self) -> None:
torch._functionalize_commit_update(self.elem)
def sync(self) -> None:
torch._functionalize_sync(self.elem)
def mark_mutation_hidden_from_autograd(self) -> None:
torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)
def tolist(self) -> Any:
if self.elem.dim() == 0:
return self.elem.item()
elif self.elem.dim() == 1:
return [elem.item() for elem in self.elem]
else:
return [elem.tolist() for elem in self.elem]
def to(self, *args, **kwargs):
if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
# If copy is specified as pos arg, it's always the second one.
if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
return super().to(*args, **{**kwargs, "copy": True})
return super().to(*args, **kwargs)
def cuda(self, device=None, *args, **kwargs):
device = device or torch.cuda.current_device()
if len(args) > 0:
return self.to(device, *args, **kwargs)
else:
return self.to(device=device, **kwargs)
char = _conversion_method_template(dtype=torch.int8)
cpu = _conversion_method_template(device=torch.device("cpu"))
bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
byte = _conversion_method_template(dtype=torch.uint8)
double = _conversion_method_template(dtype=torch.float64)
float = _conversion_method_template(dtype=torch.float32)
bool = _conversion_method_template(dtype=torch.bool)
half = _conversion_method_template(dtype=torch.float16)
int = _conversion_method_template(dtype=torch.int32)
long = _conversion_method_template(dtype=torch.int64)
# TODO(sparse-team): fixes #133174 but can we do without the relay?
def to_dense(self):
return self.elem.to_dense()
@property
def layout(self):
return self.elem.layout
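# Hedged usage sketch (not part of this diff): under an active
# FunctionalTensorMode (defined below), wrapping a tensor with to_functional
# makes in-place ops re-dispatch to their functional variants, and
# from_functional returns the current (mutated) value. The tensor is illustrative.
def _functional_tensor_example() -> None:
    with FunctionalTensorMode():
        x = FunctionalTensor.to_functional(torch.ones(3))
        x.add_(1)                                    # handled functionally
        assert torch.equal(x.from_functional(), torch.full((3,), 2.0))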
class FunctionalTensorMode(TorchDispatchMode):
def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
super().__init__()
self.export = export
self.is_on_stack = False
self.enter_stack = []
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
self.pre_dispatch = pre_dispatch
# This will be turned off later for pre-dispatch functionalization
self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined]
# Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
# track of the ordering between side effectful operations.
self._tokens: Dict[Any, torch.Tensor] = {}
# Filled after forward tracing.
self._tokens_forward_output: Dict[Any, torch.Tensor] = {}
# Functionalization runs twice in AOTAutograd, once in
# `run_functionalized_fw_and_collect_metadata` to collect metadata to
# see which tensors need to be functionalized and discover how many
# tokens we need, and another time in `make_fx` which does the actual
        # tracing to replace ops with their functional variants and handle
# side-effectful ops. In the second stage there should be no token
# discovery. This flag distinguishes between the two stages.
self._allow_token_discovery = _allow_token_discovery
self._storage_to_base: weakref.WeakKeyDictionary[
torch.storage.UntypedStorage, Optional[FunctionalTensor]
] = weakref.WeakKeyDictionary()
# No-op if FunctionalTensorMode is already in use
def __enter__(self):
def _get_prev_mode():
if self._dispatch_key == torch._C.DispatchKey.PreDispatch:
return _get_dispatch_mode_pre_dispatch(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
return torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
if _get_prev_mode() is None:
self.enter_stack.append(True)
return super().__enter__()
else:
self.enter_stack.append(False)
return self
def __exit__(self, a, b, c):
is_on_stack = self.enter_stack.pop()
if is_on_stack:
super().__exit__(a, b, c)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if self.export:
# We need to make sure that we don't decompose to() as usual in export mode,
# because it can get optimized away. Instead we always replace it with _to_copy().
if func == torch.ops.aten.to.dtype_layout:
kwargs.pop("copy", None)
return self.__torch_dispatch__(
torch.ops.aten._to_copy.default, types, args, kwargs
)
if func == torch.ops.aten.to.dtype:
schema = tuple(arg.name for arg in func._schema.arguments)
for arg, name in zip(args[1:], schema[1:]):
kwargs[name] = arg
kwargs.pop("copy", None)
return self.__torch_dispatch__(
torch.ops.aten._to_copy.default, types, args[:1], kwargs
)
unrecognized_types = [
t
for t in types
if not issubclass(t, torch._subclasses.FakeTensor)
and t not in [torch.Tensor, FunctionalTensor]
]
if unrecognized_types:
not_implemented_log.debug(
"FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
)
return NotImplemented
def _can_decompose(func):
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832
# Never decompose dropout in export
if self.export and func == torch.ops.aten.dropout.default:
return False
# We unconditionally decompose ops that are maybe aliasing or mutating ops
if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
return True
            # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,
            # because we must know statically whether an op mutates or aliases in order to functionalize it properly
# (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today.
# In theory, we could walk this back and avoid decomposing them later if we need to.
alias_info_present = any(arg.alias_info for arg in func._schema.arguments)
if alias_info_present or func._schema.is_mutable:
return True
# If we get here, we are seeing a functional composite op.
# For pre-dispatch IR or export inference IR, we won't decompose it
if (self.export or self.pre_dispatch) and func._can_decompose():
if func.namespace not in ["aten", "prim"]:
# TODO (tmanlaibaatar) check if the op is PT2 compliant
warnings.warn(
f"At pre-dispatch tracing, we assume that any custom op marked with "
f"CompositeImplicitAutograd and have functional schema are safe to not decompose. "
f"Found {func} to be one such op."
)
return False
# in normal torch.compile IR, we decompose functional composite ops
return True
if (
func not in FunctionalTensor.metadata_fns
and _can_decompose(func)
# Not all funcs from __torch_dispatch__ are actual dispatcher ops,
# e.g. prim.device
and torch._C._dispatch_has_kernel(func.name())
):
with self:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
def wrap(x):
# Only wrap our outputs in subclasses if the inner functionalization call
# also wrapped outputs into FunctionalTensorWrappers.
# When can this happen? e.g. `torch.div(2, 2)`
assert not isinstance(x, FunctionalTensor)
if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
return FunctionalTensor(x, self)
return x
def unwrap(x):
return x.elem
from torch._higher_order_ops.auto_functionalize import (
can_auto_functionalize,
do_auto_functionalize,
do_auto_functionalize_v2,
)
if can_auto_functionalize(
func
) and not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
):
# it doesn't matter what mode we use here because
# the implementation of do_auto_functionalize doesn't
# interact with FunctionalTensorMode at all
if self.export or not inductor_config.enable_auto_functionalized_v2:
return do_auto_functionalize(func, args, kwargs)
else:
return do_auto_functionalize_v2(func, args, kwargs)
from torch._higher_order_ops.effects import handle_effects, has_effects
if has_effects(func, args, kwargs):
assert not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
)
return handle_effects(
self._allow_token_discovery, self._tokens, func, args, kwargs
)
args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
FunctionalTensor, unwrap, (args, kwargs)
)
# Expectation: functionalization should not **already** be enabled above our mode.
# Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization
# to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper.
is_included = torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.Functionalize
)
is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
torch._C.DispatchKey.Functionalize
)
assert is_excluded or not is_included
include_to_set = (
torch._C._dispatch_tls_local_include_set()
| torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
exclude_to_set = (
torch._C._dispatch_tls_local_exclude_set().remove(
torch._C.DispatchKey.Functionalize
)
- FunctionalTensor._extra_dispatch_keys
)
# All we want to do here is re-use the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
# Sometimes these functions cannot be dispatched directly to the Functionalize key,
# because their args are not always functional tensors, so call them through the mode instead.
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
# We don't allow any mutation on result of dropout or _to_copy
if self.export:
if func in (
torch.ops.aten.dropout.default,
torch.ops.aten._to_copy.default,
):
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
is_included = torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.Functionalize
)
is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
torch._C.DispatchKey.Functionalize
)
assert is_excluded or not is_included
if (
# If no outputs are our functional subclass, then don't try to fix up aliasing
not any(
isinstance(x, FunctionalTensor)
for x in pytree.tree_leaves(outs_wrapped)
)
# Since lift_fresh lifts its argument into a functional tensor, we can skip the
# aliasing correction step. Otherwise, we would be setting the storage of a
# lifted tensor to that of an unlifted tensor.
# Ref: https://github.com/pytorch/pytorch/issues/111506
or func == torch.ops.aten.lift_fresh.default
):
return outs_wrapped
# for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper
if (
torch.Tag.inplace_view in func.tags
and func is not torch.ops.aten.set_.source_Tensor
):
with torch.utils._mode_utils.no_dispatch():
func(*args, **kwargs)
# Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
# inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects.
# Use this util to figure out the right thing to return.
# If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for.
return return_and_correct_aliasing(func, args, kwargs, outs_wrapped)
@classmethod
def is_infra_mode(cls) -> bool:
return True
@contextlib.contextmanager
def disable_functional_mode():
with _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) as mode:
    yield mode
# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
# One important advantage to using this mode is that it will let us
# run functionalization underneath __torch_dispatch__,
# which we need in AOTAutograd.
# - Doing so means that it does not automatically compose with other
# functorch transforms, since these transforms always run above __torch_dispatch__.
# That's why this util lives here, and not in functorch.
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
# TODO: pull these from aot autograd
def to_fun(t):
if isinstance(t, torch.Tensor):
return FunctionalTensor.to_functional(t)
return t
def from_fun(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t.elem)
def inner(*args, **kwargs):
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
with disable_above, mode:
func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)
func_outputs = func(*func_args, **func_kwargs)
outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs)
return outputs
return inner
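# Rough usage sketch (illustrative only; the real AOTAutograd tracing setup is more involved):
#
#   def f(a):
#       b = a.clone()
#       b.add_(1)  # in-place mutation is handled functionally under the mode
#       return b
#
#   f_functional = dispatch_functionalize(f)
#   out = f_functional(torch.randn(3))  # outputs are unwrapped back to plain tensors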
class BaseFunctionalizeAPI(ABC):
@abstractmethod
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
pass
@abstractmethod
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
) -> Any:
pass
@abstractmethod
def functionalize(self, inner_f: Callable) -> Callable:
pass
@abstractmethod
def redispatch_to_next(self) -> ContextManager:
pass
@abstractmethod
def replace(self, input_tensor, output_tensor) -> None:
pass
@abstractmethod
def commit_update(self, tensor) -> None:
pass
@abstractmethod
def sync(self, tensor) -> None:
pass
@abstractmethod
def mark_mutation_hidden_from_autograd(self, tensor) -> None:
pass
class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
def __init__(
self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False
) -> None:
super().__init__()
self.mode = mode if mode else FunctionalTensorMode()
self.pre_dispatch = pre_dispatch
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
with self.mode:
return torch.utils._pytree.tree_map_only(
torch.Tensor, FunctionalTensor.to_functional, args
)
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
) -> Any:
return torch.utils._pytree.tree_map_only(
FunctionalTensor, FunctionalTensor.from_functional, args
)
def functionalize(self, inner_f: Callable) -> Callable:
return dispatch_functionalize(inner_f, self.mode)
def redispatch_to_next(self) -> ContextManager:
# [NOTE] We don't do anything here because at the time
# we exercise this path, we would have already popped the
# FunctionalTensorMode from mode stack. Since FunctionalTensorMode
# is now stateful, it is better to explicitly pass in correct mode
# directly instead of globally setting it.
return contextlib.nullcontext()
def replace(self, input_tensor, output_tensor) -> None:
assert isinstance(input_tensor, FunctionalTensor)
assert not isinstance(output_tensor, FunctionalTensor)
input_tensor.replace_(output_tensor)
def commit_update(self, tensor) -> None:
assert isinstance(tensor, FunctionalTensor)
tensor.commit_update()
def sync(self, tensor) -> None:
assert isinstance(tensor, FunctionalTensor)
tensor.sync()
def mark_mutation_hidden_from_autograd(self, tensor) -> None:
assert isinstance(tensor, FunctionalTensor)
tensor.mark_mutation_hidden_from_autograd()
class CppFunctionalizeAPI(BaseFunctionalizeAPI):
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional
return _wrap_all_tensors_to_functional(args, level=0)
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)
return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views())
def functionalize(self, inner_f: Callable) -> Callable:
return torch.func.functionalize(inner_f)
def redispatch_to_next(self) -> ContextManager:
return torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
def replace(self, input_tensor, output_tensor) -> None:
torch._functionalize_replace(input_tensor, output_tensor)
def commit_update(self, tensor) -> None:
torch._functionalize_commit_update(tensor)
def sync(self, tensor) -> None:
torch._functionalize_sync(tensor)
def mark_mutation_hidden_from_autograd(self, tensor) -> None:
torch._functionalize_mark_mutation_hidden_from_autograd(tensor)
class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
def __init__(self, interpreter):
self.interpreter = interpreter
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional
return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())
def unwrap_tensors(
self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
from torch._functorch.eager_transforms import (
_unwrap_all_tensors_from_functional,
)
return _unwrap_all_tensors_from_functional(
args, reapply_views=self.interpreter.functionalize_add_back_views()
)
def functionalize(self, inner_f: Callable) -> Callable:
return torch.func.functionalize(
inner_f,
remove=(
"mutations_and_views"
if self.interpreter.functionalize_add_back_views()
else "mutations"
),
)
def redispatch_to_next(self) -> ContextManager:
return self.interpreter.lower()
def replace(self, input_tensor, output_tensor) -> None:
torch._functionalize_replace(input_tensor, output_tensor)
def commit_update(self, tensor) -> None:
torch._functionalize_commit_update(tensor)
def sync(self, tensor) -> None:
torch._functionalize_sync(tensor)
def mark_mutation_hidden_from_autograd(self, tensor) -> None:
torch._functionalize_mark_mutation_hidden_from_autograd(tensor)
def mb_unwrap_functional_tensor(tensor: torch.Tensor):
if isinstance(tensor, FunctionalTensor):
return torch._from_functional_tensor(tensor.elem)
return tensor
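# Rough usage sketch (illustrative only; exact wrapping requirements may differ):
#
#   x = torch.randn(3)
#   with FunctionalTensorMode():
#       fx = FunctionalTensor.to_functional(x)    # wrap into the python subclass
#       fx.mul_(2)                                # mutation is tracked functionally
#       plain = mb_unwrap_functional_tensor(fx)   # back to a regular torch.Tensor
#   y = mb_unwrap_functional_tensor(x)            # no-op for non-functional tensors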

File diff suppressed because it is too large

View File

@ -0,0 +1,230 @@
# mypy: ignore-errors
from collections import namedtuple
from copy import deepcopy
from itertools import combinations
import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])
# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo
# This TorchDispatchMode Subclass is used to verify op schemas
# This TorchDispatchMode Subclass currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs
# move these 2 functions here to avoid numpy dependency in testing/_internal/common_utils.py
def is_iterable_of_tensors(iterable):
# Tensor itself is iterable so we check this first
if isinstance(iterable, torch.Tensor):
return False
try:
if len(iterable) == 0:
return False
for t in iter(iterable):
if not isinstance(t, torch.Tensor):
return False
except TypeError:
return False
return True
def clone_inputs(args):
inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
inputs.append(arg.detach().clone())
elif is_iterable_of_tensors(arg):
inputs.append([t.detach().clone() for t in arg])
else:
inputs.append(arg)
return inputs
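# Rough example (illustrative only):
#
#   t = torch.ones(2)
#   cloned = clone_inputs((t, [t, t], "flag"))
#   # -> [t.detach().clone(), [t.detach().clone(), t.detach().clone()], "flag"]
#   # Tensors and lists of tensors are detached and cloned; everything else passes through.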
class SchemaCheckMode(TorchDispatchMode):
def __init__(self) -> None:
# Information recorded for testing purposes. For example:
# - incorrect schemas
# - overly conservative schemas
self.ops = []
self.mutated = []
self.aliasing = []
def reset_cache(self):
self.ops.clear()
self.mutated.clear()
self.aliasing.clear()
def display_ops(self):
print(*self.ops, sep=",")
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def bitwise_equal(lhs, rhs):
if lhs.is_quantized:
# TODO: This is only OK if quantized tensors can't contain NaN;
# it is unclear whether that is actually true
return torch.equal(lhs, rhs)
else:
return torch.allclose(lhs, rhs, equal_nan=True)
def has_mutated(before, after, md):
are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
if (
are_tensors
and before.layout != torch.sparse_csr
and after.layout != torch.sparse_csr
):
return not (
before.size() == after.size()
and bitwise_equal(before, after)
and md[0] == after.stride()
and md[1] == after._typed_storage()._cdata
)
return False
def has_aliased(lhs, rhs):
try:
return torch._C._overlaps(lhs, rhs)
except Exception as exception:
if str(exception).startswith("Cannot inspect value of type "):
return False
else:
raise exception
def standardize_name(name):
return name if name != "self" else "input"
def unwrap(e):
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
try:
return e.elem
except AttributeError:
return e
return e
def parse_metadata(e):
if isinstance(e, torch.Tensor):
if not type(e) == torch.Tensor:
try:
current = e.elem
return (
deepcopy(current.stride()),
current._typed_storage()._cdata,
)
except AttributeError:
return None
# Sparse CSR tensors do not have strides or storage
elif e.layout != torch.sparse_csr:
return (deepcopy(e.stride()), e._typed_storage()._cdata)
return None
self.ops.append(func._schema.name)
# Clone and process arguments and outputs
pre_arguments = normalize_function(
func, args, kwargs, normalize_to_only_use_kwargs=True
).kwargs
c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
cloned_arguments = {
name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
}
cloned_metadata = {
name: [
parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
]
for name in pre_arguments
}
out = func(*args, **kwargs)
arguments = {
name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
}
tuple_out = out if isinstance(out, tuple) else (out,)
tuple_out = tree_map(unwrap, tuple_out)
schema_info = SchemaInfo(func._schema)
schema_info.add_argument_values(pre_arguments)
# Process arguments with outputs
for i in range(len(func._schema.arguments)):
arg = func._schema.arguments[i]
name = standardize_name(arg.name)
if arguments.get(name) is not None:
before = cloned_arguments.get(name)
md = cloned_metadata.get(name)
after = arguments.get(name)
for j in range(len(tuple_out)):
# aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
if (
has_aliased(tuple_out[j], after)
and func._schema.name not in unsafe_ops
):
if not schema_info.may_contain_alias(
SchemaArgument(SchemaArgType.output, j),
SchemaArgument(SchemaArgType.input, i),
):
raise RuntimeError(
f"Argument {name} is not defined to alias output but was aliasing"
)
else:
self.aliasing.append(
Aliasing(func._schema.name, name, f"output_{j}")
)
if after is tuple_out[j] and isinstance(after, torch.Tensor):
# Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
if not schema_info.is_mutable(
SchemaArgument(SchemaArgType.input, i)
) and func not in [
torch.ops.aten.lift.default,
torch.ops.aten.lift_fresh.default,
]:
raise RuntimeError(
f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}`"""
)
if any(
has_mutated(a, b, c)
for a, b, c in zip(
pytree.tree_leaves(before), pytree.tree_leaves(after), md
)
):
if not schema_info.is_mutable(
SchemaArgument(SchemaArgType.input, i)
):
raise RuntimeError(
f"Argument {name} is not defined as mutable but was mutated"
)
else:
self.mutated.append(Mutation(func._schema.name, name))
# Aliasing between outputs
for i, j in combinations(range(len(func._schema.returns)), 2):
if has_aliased(tuple_out[i], tuple_out[j]):
if not schema_info.may_contain_alias(
SchemaArgument(SchemaArgType.output, i),
SchemaArgument(SchemaArgType.output, j),
):
raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
return out
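# Rough usage sketch (illustrative only):
#
#   x = torch.rand(2, 2)
#   with SchemaCheckMode() as schema_check:
#       torch.relu(x)
#       x.sin_()
#   print(schema_check.ops)       # schema names of every dispatched op
#   print(schema_check.mutated)   # Mutation(op_name, arg_name) records
#   print(schema_check.aliasing)  # Aliasing(op_name, arg_name, output_number) records
#   # A RuntimeError is raised if an op mutates or aliases outside its declared schema.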