I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,6 @@
import torch._library.autograd
import torch._library.fake_impl
import torch._library.simple_registry
import torch._library.utils
from torch._library.fake_class_registry import register_fake_class
from torch._library.triton import capture_triton, triton_op


@@ -0,0 +1,241 @@
# mypy: allow-untyped-defs
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Protocol
from torch import _C, _ops, autograd, Tensor
from torch.utils import _pytree
from . import utils
class InfoProtocol(Protocol):
_backward_fn: Optional[Callable]
_setup_context_fn: Optional[Callable]
@dataclasses.dataclass
class Info:
_backward_fn: Optional[Callable]
_setup_context_fn: Optional[Callable]
def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"
has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)
@dataclass
class Metadata:
keyset: _C.DispatchKeySet
keyword_only_args: Dict[str, Any]
def forward_no_grad(*args):
metadata = args[-1]
args = args[:-1]
with _C._AutoDispatchBelowAutograd():
keyset = metadata.keyset
kwargs = metadata.keyword_only_args
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
return result
def forward(ctx, *args):
metadata = args[-1]
args = args[:-1]
with _C._AutoDispatchBelowAutograd():
keyset = metadata.keyset
kwargs = metadata.keyword_only_args
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
if info._setup_context_fn:
# The Dispatcher will remove args that are equal to their default
# values from (args, kwargs). We're going to add them back so that
# the user can access them.
#
# This is OK to do: The Dispatcher removed the args for serialization
# FC/BC reasons (that is, a graph will not store args that are equal
# to their default values), but that doesn't matter here. If the user
# adds a new default arg, then they must update
# their setup_context (along with the rest of their operator
# registrations)
args, kwargs = utils.fill_defaults(op._schema, args, kwargs)
if has_kwarg_only_args:
info._setup_context_fn(
ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
)
else:
info._setup_context_fn(ctx=ctx, inputs=args, output=result)
return result
def backward(ctx, *grads):
if info._backward_fn:
try:
prev_needs_input_grad = ctx.needs_input_grad
ctx.needs_input_grad = ctx.needs_input_grad[:-1]
result = info._backward_fn(ctx, *grads)
finally:
ctx.needs_input_grad = prev_needs_input_grad
if isinstance(result, tuple):
return (*result, None)
return result, None
raise RuntimeError(
f"Trying to backward through {op} but no autograd "
f"formula was registered. "
f"Please use register_autograd to add one."
)
Generated = type(
name,
(autograd.Function,),
{
"forward": staticmethod(forward),
"backward": staticmethod(backward),
},
)
schema = op._schema
if any(
utils.is_tensorlist_like_type(a.type)
for a in (*schema.arguments, *schema.returns)
):
Generated = supports_tensorlist(Generated)
# The dispatcher passes any keyword-only-args as kwargs and the
# rest of the args (even if specified as kwargs) as args.
def autograd_impl(keyset, *args, **keyword_only_args):
if _C.is_grad_enabled() and _pytree.tree_any_only(
Tensor, lambda x: x.requires_grad, args, not_list_of_tensor
):
result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
else:
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
return result
return autograd_impl
def supports_tensorlist(cls: Any) -> Any:
"""Allows a given autograd.Function class to support List[Tensor] inputs/outputs.
Regular autograd.Function has a constraint that it only directly supports autograd for
Tensors. Applying @supports_tensorlist enables an autograd.Function to support
autograd for List[Tensor] inputs and outputs.
"""
orig_forward = cls.forward
orig_backward = cls.backward
orig_apply = cls.apply
@dataclass
class Metadata:
input_spec: spec_t
output_spec: Optional[spec_t] = None
result_is_tuple: Optional[bool] = None
def new_forward(ctx, *args):
metadata = args[-1]
args = args[:-1]
if not isinstance(metadata, Metadata):
raise NotImplementedError(
"NYI: calling supports_tensorlist autograd.Function.forward directly. "
"You should probably be calling .apply instead. "
"Please file an issue if not."
)
args = unflatten(list(args), metadata.input_spec)
result = orig_forward(ctx, *args)
metadata.result_is_tuple = isinstance(result, tuple)
if not metadata.result_is_tuple:
result = (result,)
flat_result, output_spec = flatten(result, not_list_of_tensor)
metadata.output_spec = output_spec
if hasattr(ctx, "_pt_metadata"):
raise RuntimeError(
"Please don't set ctx._pt_metadata; PyTorch uses it to store info"
)
ctx._pt_metadata = metadata
return tuple(flat_result)
def new_backward(ctx, *grads):
if not hasattr(ctx, "_pt_metadata"):
raise NotImplementedError(
"NYI: calling supports_tensorlist autograd.Function.backward directly. "
"This will automatically get called by PyTorch autograd. "
"Please file an issue if you need this."
)
metadata = ctx._pt_metadata
grads = unflatten(list(grads), metadata.output_spec)
# If the user's input is ([x, y, z], w),
# then needs_input_grad is (bool, bool, bool, bool, bool).
# We need to
# 1. get rid of the additional bool (which comes from the extra
# `metadata` input)
# 2. unflatten to get the right structure.
prev_needs_input_grad = ctx.needs_input_grad
try:
ctx.needs_input_grad = unflatten(
list(ctx.needs_input_grad[:-1]), metadata.input_spec
)
grad_inputs = orig_backward(ctx, *grads)
finally:
ctx.needs_input_grad = prev_needs_input_grad
if not isinstance(grad_inputs, tuple):
grad_inputs = (grad_inputs,)
# Assume that any Nones in the backward are Tensors.
# If the forward has an arg that is [1, 2, 3], the backward should
# return None as the grad.
# If the forward has an arg that is [tensor, tensor], the backward
# may return [None, None], [grad, None], [None, grad], or [grad, grad].
flat_grad_inputs, grad_inputs_spec = flatten(
grad_inputs, not_list_of_optional_tensor
)
if grad_inputs_spec != metadata.input_spec:
raise RuntimeError(
f"Expected the return from backward to be of the same structure "
f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
f"{metadata.input_spec} (inputs)"
)
return tuple(flat_grad_inputs + [None])
def new_apply(*args):
flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
metadata = Metadata(input_spec)
result = orig_apply(*flat_args, metadata) # type: ignore[misc]
assert metadata.output_spec is not None
result = unflatten(list(result), metadata.output_spec)
if not metadata.result_is_tuple:
assert isinstance(result, tuple)
assert len(result) == 1
return result[0]
return result
cls.forward = new_forward
cls.backward = new_backward
cls.apply = new_apply
return cls
def not_list_of_tensor(tree):
if isinstance(tree, tuple):
return False
if isinstance(tree, list):
return any(not isinstance(l, Tensor) for l in tree)
return True
def not_list_of_optional_tensor(tree):
if isinstance(tree, tuple):
return False
if isinstance(tree, list):
return any(l is not None and not isinstance(l, Tensor) for l in tree)
return True
flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec
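# A minimal sketch of how the pytree helpers above are meant to behave: with
# `not_list_of_tensor` as the `is_leaf` predicate, a List[Tensor] argument is
# flattened into its individual Tensors while every other value stays a leaf.
# This is only an illustration of the flattening convention used by
# `supports_tensorlist`, not a public API.
#
# >>> import torch
# >>> from torch.utils import _pytree
# >>> t = torch.randn(2)
# >>> leaves, spec = _pytree.tree_flatten(([t, t], 3), is_leaf=not_list_of_tensor)
# >>> assert len(leaves) == 3  # the two Tensors plus the int 3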


@@ -0,0 +1,835 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import logging
import weakref
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
from torch import _C, _ops, Tensor
from torch.utils._exposed_in import exposed_in
from . import autograd, utils
device_types_t = Optional[Union[str, Sequence[str]]]
log = logging.getLogger(__name__)
@exposed_in("torch.library")
def custom_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
device_types: device_types_t = None,
schema: Optional[str] = None,
) -> Callable:
"""Wraps a function into custom operator.
Reasons why you may want to create a custom op include:
- Wrapping a third-party library or custom kernel to work with PyTorch
subsystems like Autograd.
- Preventing torch.compile/export/FX tracing from peeking inside your function.
This API is used as a decorator around a function (please see examples).
The provided function must have type hints; these are needed to interface
with PyTorch's various subsystems.
Args:
name (str): A name for the custom op that looks like "{namespace}::{name}",
e.g. "mylib::my_linear". The name is used as the op's stable identifier
in PyTorch subsystems (e.g. torch.export, FX graphs).
To avoid name collisions, please use your project name as the namespace;
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
it pessimistically assumes that all inputs to the operator are being mutated.
device_types (None | str | Sequence[str]): The device type(s) the function
is valid for. If no device type is provided, then the function
is used as the default implementation for all device types.
Examples: "cpu", "cuda".
When registering a device-specific implementation for an operator that accepts no Tensors,
we require the operator to have a ``device: torch.device`` argument.
schema (None | str): A schema string for the operator. If None
(recommended) we'll infer a schema for the operator from its type
annotations. We recommend letting us infer a schema unless you
have a specific reason not to.
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
.. note::
We recommend not passing in a ``schema`` arg and instead letting us infer
it from the type annotations. It is error-prone to write your own schema.
You may wish to provide your own schema if our interpretation of
the type annotation is not what you want.
For more info on how to write a schema string, see
`here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_
Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> x_np = x.cpu().numpy()
>>> y_np = np.sin(x_np)
>>> return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>> x_np = x.numpy()
>>> y_np = np.sin(x_np)
>>> return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>> x_np = x.numpy()
>>> np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>> return torch.ones(3)
>>>
>>> bar("cpu")
"""
def inner(fn):
import torch
if schema is None:
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
else:
schema_str = schema
namespace, opname = name.split("::")
result = CustomOpDef(namespace, opname, schema_str, fn)
if schema is not None:
# Check that schema's alias annotations match those of `mutates_args`.
expected = set()
for arg in result._opoverload._schema.arguments:
if arg.alias_info is not None and arg.alias_info.is_write:
expected.add(arg.name)
if expected != set(mutates_args):
raise ValueError(
f"Attempted to create a custom op with `mutates_args={mutates_args}` "
f"and `schema={schema}. The schema suggests that the op mutates {expected}"
f"which is different from what was provided to us in `mutates_args`. "
f"Please make these consistent."
)
result.register_kernel(device_types)(fn)
return result
if fn is None:
return inner
return inner(fn)
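# A hedged sketch of the consistency rule enforced above: when an explicit
# ``schema`` is passed, its alias annotations must agree with ``mutates_args``.
# The operator name below ("mylib::inplace_relu_") is purely illustrative.
#
# >>> # OK: "x" is annotated as mutated via Tensor(a!) and is listed in mutates_args.
# >>> @custom_op(
# >>>     "mylib::inplace_relu_",
# >>>     mutates_args={"x"},
# >>>     schema="(Tensor(a!) x) -> ()",
# >>> )
# >>> def inplace_relu_(x: Tensor) -> None:
# >>>     x.clamp_(min=0)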
class CustomOpDef:
"""CustomOpDef is a wrapper around a function that turns it into a custom op.
It has various methods for registering additional behavior for this
custom op.
You should not instantiate CustomOpDef directly; instead, use the
:func:`torch.library.custom_op` API.
"""
def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
# Fields used to interface with the PyTorch dispatcher
self._namespace = namespace
self._name = name
self._schema = schema
self._init_fn = fn
self._backend_fns: Dict[Union[str, None], Callable] = {}
self._abstract_fn: Optional[Callable] = None
self._setup_context_fn: Optional[Callable] = None
self._backward_fn: Optional[Callable] = None
self._torch_dispatch_fns: Dict[type, Callable] = {}
self._vmap_fn: Optional[Callable] = None
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
self._register_to_dispatcher()
self._disabled_kernel: Set = set()
OPDEFS[self._qualname] = self
@property
def _qualname(self) -> str:
return f"{self._namespace}::{self._name}"
def __repr__(self) -> str:
return f"<CustomOpDef({self._qualname})>"
@contextmanager
def set_kernel_enabled(self, device_type: str, enabled: bool = True):
"""
Disable or re-enable an already registered kernel for this custom operator.
If the kernel is already disabled/enabled, this is a no-op.
Note:
If a kernel is first disabled and then registered, it is disabled until enabled again.
Args:
device_type (str): The device type to disable/enable the kernel for.
enabled (bool): Whether to enable or disable the kernel.
Example:
>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>> return torch.zeros(1)
>>>
>>> print(f(inp)) # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>> return torch.ones(1)
>>>
>>> print(f(inp)) # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled = False):
>>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
"""
action = "enable" if enabled else "disable"
originally_disabled = device_type in self._disabled_kernel
if device_type not in self._backend_fns:
log.warning(
"Attempted to %s kernel for %s but no kernel was registered for this device type.",
action,
device_type,
)
if not enabled:
if originally_disabled:
log.warning(
"Attempted to disable kernel for %s but it was already disabled.",
device_type,
)
else:
self._disabled_kernel.add(device_type)
else: # enable the kernel
if not originally_disabled:
log.warning(
"Attempted to enable kernel for %s but it was already enabled.",
device_type,
)
else:
self._disabled_kernel.remove(device_type)
try:
yield
finally:
# restore original state
if originally_disabled:
self._disabled_kernel.add(device_type)
else:
self._disabled_kernel.discard(device_type)
def register_kernel(
self, device_types: device_types_t, fn: Optional[Callable] = None, /
) -> Callable:
"""Register an implementation for a device type for this operator.
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
This API may be used as a decorator.
Args:
fn (Callable): The function to register as the implementation for
the given device types.
device_types (str | Sequence[str]): The device type(s) to register an impl to.
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> x_np = x.numpy()
>>> y_np = np.sin(x_np)
>>> return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @numpy_sin.register_kernel("cuda")
>>> def _(x):
>>> x_np = x.cpu().numpy()
>>> y_np = np.sin(x_np)
>>> return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
"""
def inner(fn):
if device_types is None or isinstance(device_types, str):
dtypes: List[Union[str, None]] = [device_types]
else:
dtypes = list(device_types)
for device_type in dtypes:
if device_type not in self._backend_fns:
def backend_impl(*args, **kwargs):
# Checks the assumption that outputs cannot alias
# inputs or other outputs.
storages = {
id(tensor.untyped_storage())
for tensor in iter_tensors(args, kwargs)
}
result = self._backend_fns[device_type](*args, **kwargs)
tuple_result = result
if not isinstance(result, tuple):
tuple_result = (result,)
for tensor in iter_tensors(tuple_result, {}):
key = id(tensor.untyped_storage())
if id(tensor.untyped_storage()) in storages:
fn = self._backend_fns[device_type]
module = inspect.getmodule(fn)
raise RuntimeError(
f"{self._name} (with implementation in {module}): "
f"The output of this custom operator (1) must not "
f"also be an input to this custom operator and "
f"(2) may not alias any inputs to this custom operator "
f"or other returns. "
f"The most common way to trigger this error is if "
f"we have y = custom_op(x) and y and x are the same Tensor. "
f"Please instead return a clone of the offending output "
f"tensor(s) (e.g. return x.clone()) or refactor the custom "
f"operator to not return y."
)
storages.add(key)
return result
if device_type is None:
self._lib.impl(
self._name, backend_impl, "CompositeExplicitAutograd"
)
else:
self._lib.impl(
self._name,
backend_impl,
_C._dispatch_key_for_device(device_type),
)
# Wrap function to choose between the default implementation or the device-specific
# implementation depending on if the kernel is disabled.
@torch._disable_dynamo
def wrapped_fn(*args, **kwargs):
if device_type in self._disabled_kernel:
return self._init_fn(*args, **kwargs)
else:
return fn(*args, **kwargs)
self._backend_fns[device_type] = wrapped_fn
return fn
if device_types is not None and not utils.has_tensor_arg(
self._opoverload._schema
):
device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
if device_arg_index is None:
raise ValueError(
"Functions without tensor inputs are required to have a `device: torch.device` argument"
)
self._register_backend_select_dispatcher(device_arg_index)
# See NOTE: [Supporting decorator and non-decorator usage]
if fn is None:
return inner
return inner(fn)
def register_fake(self, fn: Callable, /) -> Callable:
r"""Register a FakeTensor implementation for this custom op.
This is necessary to get the operator to work efficiently with torch.compile.
The Fake impl (sometimes also known as a meta kernel or abstract impl)
specifies the behavior of this operator on Tensors that carry no data.
Given some input Tensors with certain properties
(sizes/strides/storage_offset/device), it specifies what the properties of
the output Tensors are.
Please see :func:`torch.library.impl_abstract` for more details.
Args:
fn (Callable): The function to register as the FakeTensor
implementation.
Examples:
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::linear", mutates_args=())
>>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>> return (x @ weight.t()) + bias
>>>
>>> @linear.register_fake
>>> def _(x, weight, bias):
>>> assert x.dim() == 2
>>> assert weight.dim() == 2
>>> assert bias.dim() == 1
>>> assert x.shape[1] == weight.shape[1]
>>> assert weight.shape[0] == bias.shape[0]
>>> assert x.device == weight.device
>>> return x.new_empty(x.size(0), weight.size(0))
>>>
>>> x = torch.randn(2, 2)
>>> weight = torch.randn(2, 2)
>>> bias = torch.randn(2)
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
>>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
>>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
>>> def nonzero(x: Tensor) -> Tensor:
>>> x_np = x.cpu().numpy()
>>> res = np.stack(np.nonzero(x_np), axis=1)
>>> return torch.tensor(res, device=x.device)
>>>
>>> @nonzero.register_fake
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an abstract impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> ctx = torch.library.get_ctx()
>>> nnz = ctx.new_dynamic_size()
>>> shape = [nnz, x.dim()]
>>> result = x.new_empty(shape, dtype=torch.int64)
>>> return result
>>>
>>> x = torch.tensor([0, 1, 2, 0, 0, 1])
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
>>> out = torch.compile(nonzero, fullgraph=True)(x)
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
>>> assert torch.allclose(out, x.nonzero())
"""
self._abstract_fn = fn
return fn
def register_torch_dispatch(
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
) -> Callable:
r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
This allows for open registration to specify the behavior between the operator
and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
or the operator directly.
Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
"""
def register(fn):
if torch_dispatch_class not in self._torch_dispatch_fns:
def inner(*args, **kwargs):
return self._torch_dispatch_fns[torch_dispatch_class](
*args, **kwargs
)
self._lib._register_torch_dispatch_rule(
self._name, torch_dispatch_class, inner
)
self._torch_dispatch_fns[torch_dispatch_class] = fn
return fn
if fn is None:
return register
else:
return register(fn)
def register_autograd(
self,
backward: Callable,
/,
*,
setup_context: Optional[Callable] = None,
) -> None:
r"""Register a backward formula for this custom op.
In order for an operator to work with autograd, you need to register
a backward formula:
1. You must tell us how to compute gradients during the backward pass
by providing us a "backward" function.
2. If you need any values from the forward to compute gradients, you can
use `setup_context` to save values for backward.
``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
- ``grads`` is one or more gradients. The number of gradients matches
the number of outputs of the operator.
The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
:class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
same as :meth:`torch.autograd.Function.backward`.
``setup_context(ctx, inputs, output)`` runs during the forward pass.
Please save quantities needed for backward onto the ``ctx`` object via
either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
or assigning them as attributes of ``ctx``. If your custom op has
kwarg-only arguments, we expect the signature of ``setup_context``
to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
they may not directly access :meth:`torch.Tensor.data_ptr` and they must
not depend on or mutate global state. If you need a non-traceable backward,
you can make it a separate custom_op that you call inside ``backward_fn``.
Examples:
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> x_np = x.cpu().numpy()
>>> y_np = np.sin(x_np)
>>> return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> Tensor:
>>> x, = inputs
>>> ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>> x, = ctx.saved_tensors
>>> return grad * x.cos()
>>>
>>> numpy_sin.register_autograd(backward, setup_context=setup_context)
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>> x_np = x.cpu().numpy()
>>> y_np = x_np * val
>>> return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
>>> ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>> return grad * ctx.val
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
"""
schema = self._opoverload._schema
if not utils.is_functional_schema(schema):
raise RuntimeError(
f"Cannot register autograd formula for non-functional operator "
f"{self} with schema {schema}. Please create "
f"a functional operator and register an autograd formula for that."
)
self._backward_fn = backward
self._setup_context_fn = setup_context
def _register_to_dispatcher(self) -> None:
lib = self._lib
schema_str = self._name + self._schema
cpp_schema = _C.parse_schema(schema_str)
if utils.has_kwarg_only_tensors(cpp_schema):
# If you want to support this, the progression is:
# - supporting kwarg-only Tensors that are non-differentiable
# - supporting kwarg-only Tensors (regardless of differentiability)
raise NotImplementedError(
f"custom_op with kwarg-only Tensor args. Please make your "
f"tensors not kwarg-only. Got: {schema_str}"
)
lib.define(
schema_str,
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
)
self._opoverload = utils.lookup_op(self._qualname)
def fake_impl(*args, **kwargs):
if self._abstract_fn is None:
if utils.can_generate_trivial_fake_impl(self._opoverload):
return None
raise RuntimeError(
f"There was no fake impl registered for {self}. "
f"This is necessary for torch.compile/export/fx tracing to work. "
f"Please use `{self._init_fn.__name__}.register_fake` to add an "
f"fake impl."
)
return self._abstract_fn(*args, **kwargs)
lib._register_fake(self._name, fake_impl, _stacklevel=4)
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
schema = self._opoverload._schema
if schema.is_mutable:
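# For mutable operators, bump the autograd version counter of every mutated
# Tensor (so in-place modifications are visible to autograd) and then
# redispatch below the ADInplaceOrView key.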
def adinplaceorview_impl(keyset, *args, **kwargs):
for arg, val in utils.zip_schema(schema, args, kwargs):
if not arg.alias_info:
continue
if not arg.alias_info.is_write:
continue
if isinstance(val, Tensor):
torch.autograd.graph.increment_version(val)
elif isinstance(val, (tuple, list)):
for v in val:
if isinstance(v, Tensor):
torch.autograd.graph.increment_version(v)
with _C._AutoDispatchBelowADInplaceOrView():
return self._opoverload.redispatch(
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
)
lib.impl(
self._name,
adinplaceorview_impl,
"ADInplaceOrView",
with_keyset=True,
)
def _register_backend_select_dispatcher(self, device_arg_index: int):
"""
Switch on the device argument to select the correct backend to dispatch to.
"""
def backend_select(keyset, *args, **kwargs):
device = args[device_arg_index].type
if device not in self._backend_fns:
raise RuntimeError(
f"{self._name} does not have a kernel registered for {device}. "
"Please use register_kernel to do so."
)
dispatch_key = _C._dispatch_key_for_device(device)
dispatch_key = getattr(_C.DispatchKey, dispatch_key)
return self._opoverload.redispatch(
_C.DispatchKeySet(dispatch_key), *args, **kwargs
)
self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
def __call__(self, *args, **kwargs):
return self._opoverload(*args, **kwargs)
def register_vmap(
self,
func: Optional[Callable] = None,
):
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
This API may be used as a decorator.
In order for an operator to work with :func:`torch.vmap`, you may need to register a
vmap implementation in the following signature:
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
It specifies how we compute the batched version of ``op`` given inputs with an additional
dimension (specified by ``in_dims``).
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
specifying what dimension of the Tensor is being vmapped over.
``info`` is a collection of additional metadata that may be helpful:
``info.batch_size`` specifies the size of the dimension being vmapped over, while
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
per output that specifies if the output has the vmapped dimension and what index it is in.
Examples:
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>> return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>> x_np = to_numpy(x)
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>> return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>> result = numpy_cube(x)
>>> return result, (in_dims[0], in_dims[0])
>>>
>>> numpy_cube.register_vmap(numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @numpy_mul.register_vmap
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>> x_bdim, y_bdim = in_dims
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>> result = x * y
>>> result = result.movedim(-1, 0)
>>> return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)
"""
from torch._functorch.autograd_function import custom_function_call_vmap_helper
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
def register(func):
need_register = self._vmap_fn is None
self._vmap_fn = func
if need_register:
def wrapped_func(keyset, *args, **kwargs):
interpreter = retrieve_current_functorch_interpreter()
return custom_function_call_vmap_helper(
interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
)
self._lib.impl(
self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
)
if func is None:
return register
else:
return register(func)
# NOTE: [Supporting decorator and non-decorator usage]
#
# Some APIs may be both used as a decorator and not as a decorator.
# For example:
#
# >>> def fn(x):
# >>> return x.sin()
# >>>
# >>> # Usage 1: not as a decorator
# >>> numpy_sin.register_kernel("cuda", fn)
# >>>
# >>> # Usage 2: as a decorator
# >>> @numpy_sin.register_kernel("cuda")
# >>> def fn2(x):
# >>> return x.sin()
#
# The way we support this is that `register_kernel` accepts an optional `fn`.
# If `fn` is provided (Usage 1), then we know that the user is using it not
# as a decorator.
# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
# decorator.
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
def get_library_allowing_overwrite(
namespace: str, name: str
) -> "torch.library.Library":
qualname = f"{namespace}::{name}"
if qualname in OPDEF_TO_LIB:
OPDEF_TO_LIB[qualname]._destroy()
del OPDEF_TO_LIB[qualname]
lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
OPDEF_TO_LIB[qualname] = lib
return lib
def iter_tensors(
args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
def check(arg):
if isinstance(arg, Tensor):
yield arg
elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
for arg in args:
yield from check(arg)
for kwarg in kwargs.values():
yield from check(kwarg)
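# For example, list(iter_tensors((t1, [t2, t3]), {"w": t4})) yields t1, t2, t3
# and t4; containers nested deeper than ``allowed_nesting`` are not traversed.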
def _maybe_get_opdef(
op: Union[CustomOpDef, _ops.OpOverload, str]
) -> Optional[CustomOpDef]:
if isinstance(op, CustomOpDef):
return op
if isinstance(op, _ops.OpOverload):
op = op._name
assert isinstance(op, str)
if op in OPDEFS:
return OPDEFS[op]
return None


@@ -0,0 +1,320 @@
# mypy: allow-untyped-defs
import logging
from typing import Any, Dict, Optional, Protocol, Tuple, Union
import torch
from torch._library.utils import parse_namespace
log = logging.getLogger(__name__)
class FakeScriptObject:
def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
self.wrapped_obj = wrapped_obj
# The fully qualified name of the class of the original script object
self.script_class_name = script_class_name
self.real_obj = x
class FakeScriptMethod:
def __init__(
self,
self_fake_obj: FakeScriptObject,
method_name: str,
schema: Optional[torch.FunctionSchema],
):
self.self_fake_obj = self_fake_obj
self.method_name = method_name
self.schema = schema
def __call__(self, *args, **kwargs):
from torch._higher_order_ops.torchbind import call_torchbind
return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs)
class HasStaticMethodFromReal(Protocol):
@classmethod
def from_real(cls, real_obj: torch.ScriptObject):
pass
class FakeClassRegistry:
def __init__(self) -> None:
self._registered_class: Dict[str, Any] = {}
def has_impl(self, full_qualname: str) -> bool:
return full_qualname in self._registered_class
def get_impl(self, full_qualname: str) -> Any:
self._check_registered(full_qualname)
return self._registered_class[full_qualname]
def register(self, full_qualname: str, fake_class=None) -> None:
if self.has_impl(full_qualname):
log.warning(
"%s is already registered. Previous fake class is overridden with %s.",
full_qualname,
fake_class,
)
self._registered_class[full_qualname] = fake_class
def deregister(self, full_qualname: str) -> Any:
if not self.has_impl(full_qualname):
log.warning(
"Cannot deregister %s. Please use register_fake_class to register it first."
" Or do you dereigster it twice?",
full_qualname,
)
else:
return self._registered_class.pop(full_qualname)
def clear(self) -> None:
self._registered_class.clear()
def _check_registered(self, full_qualname: str) -> None:
if full_qualname not in self._registered_class:
raise RuntimeError(
f"{full_qualname} is not registered. Please use register_fake_class to register it first."
)
global_fake_class_registry = FakeClassRegistry()
# TODO: add this check at compile time for __obj_flatten__.
def _check_valid_flat_script_obj(flat_x):
if not isinstance(flat_x, tuple):
raise RuntimeError("Expect flat x to be a tuple.")
for tp in flat_x:
if not isinstance(tp, tuple):
raise RuntimeError("Expect flat x to be a tuple of tuples.")
if not len(tp) == 2 or not isinstance(tp[0], str):
raise RuntimeError(
"Expect element of flat x to be a tuple of two elements with first element being a string"
)
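# For reference, a flattened script object accepted by the check above is a
# tuple of (attribute_name, value) pairs; for the TensorQueue example used
# elsewhere in this file it would look roughly like:
#
# >>> flat_x = (("queue", [torch.randn(2), torch.randn(2)]),)
# >>> _check_valid_flat_script_obj(flat_x)  # passes silently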
def tracing_with_real(x: torch.ScriptObject) -> bool:
if not hasattr(x, "tracing_mode"):
return False
assert x.tracing_mode() in [
"real",
"fake",
], f"tracing_mode can be either real or fake but got {x.tracing_mode()}"
return x.tracing_mode() == "real"
def maybe_to_fake_obj(
fake_mode, x: torch.ScriptObject
) -> Union[FakeScriptObject, torch.ScriptObject]:
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import _disable_current_modes
# When tracing with real mode, people should implement meta kernels that can
# handle the case of real script object + fake tensor inputs.
if tracing_with_real(x):
return x
# x.__obj_flatten__() could be calling some tensor operations inside but we don't
# want to call these ops in surrounding dispatch modes when executing it.
# Otherwise, for example, fake tensor mode will error out when the tensors inside the
# script object execute operations like clone if the allow_non_fake_input flag is set.
with _disable_current_modes():
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
_check_valid_flat_script_obj(flat_x)
fake_flattened = pytree.tree_map_only(
torch.Tensor,
lambda t: fake_mode.from_tensor(t),
flat_x,
)
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined]
for name in x._method_names(): # type: ignore[attr-defined]
attr = getattr(fake_x, name, None)
if attr:
if not callable(attr):
raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")
real_attr = getattr(x, name) # type: ignore[attr-defined]
# The real attr sometimes is not a torch.ScriptMethod and thus doesn't have a schema, e.g. __init__ or __eq__
method_schema: Optional[torch.FunctionSchema] = None
if isinstance(real_attr, torch.ScriptMethod):
method_schema = real_attr.schema # type: ignore[attr-defined]
setattr(
fake_x_wrapped,
name,
FakeScriptMethod(fake_x_wrapped, name, method_schema),
)
else:
override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
if name not in override_skip_list:
log.warning("fake object of %s doesn't implement method %s.", x, name)
return fake_x_wrapped
def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
r"""Register a fake implementation for this class.
It's in the same spirit as registering a fake implementation for
an operator, but with the difference that it
associates a fake class with the original torch bind class (registered
with torch::class_). In this way, torch.compile can handle them properly
in components such as Dynamo and AOTAutograd.
This API may be used as a decorator (see example). For the fake class, users
are required to provide a from_real classmethod that takes a real object and
returns an instance of the fake class. All tensors in the fake object should also
be properly fakified with to_fake_tensor() in from_real.
Examples:
# For a custom class TensorQueue defined in test_custom_class_registration.cpp:
TORCH_LIBRARY(_TorchScriptTesting, m) {
m.class_<TensorQueue>("_TensorQueue")
.def(torch::init<at::Tensor>())
.def("push", &TensorQueue::push)
.def("pop", &TensorQueue::pop)
.def("top", &TensorQueue::top)
.def("size", &TensorQueue::size)
.def("clone_queue", &TensorQueue::clone_queue)
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self)
-> c10::Dict<std::string, at::Tensor> {
return self->serialize();
},
// __setstate__
[](c10::Dict<std::string, at::Tensor> data)
-> c10::intrusive_ptr<TensorQueue> {
return c10::make_intrusive<TensorQueue>(std::move(data));
});
};
# We could register a fake class FakeTensorQueue in Python as follows:
import torch
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, queue):
self.queue = queue
@classmethod
def __obj_unflatten__(cls, flattened_ctx):
return cls(**dict(flattened_ctx))
def push(self, x):
self.queue.append(x)
def pop(self):
return self.queue.pop(0)
def size(self):
return len(self.queue)
In this example, the original TensorQueue needs to add an __obj_flatten__ method
to the class TensorQueue, and the flattened result is passed into FakeTensorQueue's
__obj_unflatten__ as inputs to create a fake class. This protocol allows PyTorch to look
at the contents of the script object and properly handle them in subsystems
such as Dynamo and AOTAutograd.
"""
def inner(fake_class: HasStaticMethodFromReal):
ns, name = parse_namespace(qualname)
# This also checks whether the referenced torch::class_ exists.
torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name)
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_method:
raise RuntimeError(
f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
)
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
raise RuntimeError(
f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
)
global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
return fake_class
if fake_class is None:
return inner
return inner(fake_class)
def deregister_fake_class(qualname):
return global_fake_class_registry.deregister(_full_qual_class_name(qualname))
def has_fake_class(full_qualname) -> bool:
return global_fake_class_registry.has_impl(full_qualname)
def find_fake_class(full_qualname) -> Optional[Any]:
if not has_fake_class(full_qualname):
return None
return global_fake_class_registry.get_impl(full_qualname)
def _full_qual_class_name(qualname: str) -> str:
ns, name = parse_namespace(qualname)
return "__torch__.torch.classes." + ns + "." + name
# Return the namespace and class name from fully qualified name.
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
splits = full_qualname.split(".")
assert len(splits) == 5
_torch, torch_ns, classes, ns, class_name = splits
return ns, class_name
def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
full_qualname = x._type().qualified_name() # type: ignore[attr-defined]
ns, class_name = _ns_and_class_name(full_qualname)
fake_class = find_fake_class(full_qualname)
if fake_class is None:
raise RuntimeError(
f" ScriptObject's {full_qualname} haven't registered a fake class."
f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj."
f" Specifically, create a python class that implements a fake version for all the methods"
f" that're used in the program and put annotated class in the program e.g. after loading the library."
f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
f" to enable creating a fake obj from a real one."
)
return fake_class
_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
def _fake_obj_from_real(fake_mode, x) -> Any:
fake_class = _find_fake_class_for_script_object(x)
from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_real_method:
raise RuntimeError(
f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
f" that converts the real object to the fake object."
)
# from_real defined by the user needs the ctx to fakify the tensor states.
ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None)
with torch._library.fake_impl.set_ctx_getter(lambda: ctx):
return fake_class.from_real(x)


@@ -0,0 +1,207 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Callable, Optional
from typing_extensions import deprecated
import torch
from torch._library.utils import Kernel, RegistrationHandle
class FakeImplHolder:
"""A holder where one can register an fake impl to."""
def __init__(self, qualname: str):
self.qualname: str = qualname
self.kernel: Optional[Kernel] = None
self.lib: Optional[torch.library.Library] = None
def register(self, func: Callable, source: str) -> RegistrationHandle:
"""Register an fake impl.
Returns a RegistrationHandle that one can use to de-register this
fake impl.
"""
if self.kernel is not None:
raise RuntimeError(
f"register_fake(...): the operator {self.qualname} "
f"already has an fake impl registered at "
f"{self.kernel.source}."
)
if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
raise RuntimeError(
f"register_fake(...): the operator {self.qualname} "
f"already has an DispatchKey::Meta implementation via a "
f"pre-existing torch.library or TORCH_LIBRARY registration. "
f"Please either remove that registration or don't call "
f"register_fake."
)
if torch._C._dispatch_has_kernel_for_dispatch_key(
self.qualname, "CompositeImplicitAutograd"
):
raise RuntimeError(
f"register_fake(...): the operator {self.qualname} "
f"already has an implementation for this device type via a "
f"pre-existing registration to "
f"DispatchKey::CompositeImplicitAutograd."
f"CompositeImplicitAutograd operators do not need an fake "
f"impl; "
f"instead, the operator will decompose into its constituents "
f"and those "
f"can have fake impls defined on them."
)
# Store the kernel in this holder
self.kernel = Kernel(func, source)
# Also register the fake impl to Meta key
if self.lib is None:
ns = self.qualname.split("::")[0]
self.lib = torch.library.Library(ns, "FRAGMENT") # noqa: TOR901
meta_kernel = construct_meta_kernel(self.qualname, self)
self.lib.impl(self.qualname, meta_kernel, "Meta")
def deregister_fake_class():
if self.lib:
self.lib._destroy()
self.lib = None
self.kernel = None
return RegistrationHandle(deregister_fake_class)
def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
assert fake_impl_holder.kernel is not None
@functools.wraps(fake_impl_holder.kernel.func)
def meta_kernel(*args, **kwargs):
assert fake_impl_holder.kernel is not None
source = fake_impl_holder.kernel.source
def error_on_ctx():
raise RuntimeError(
f"Attempted to call get_ctx() for the meta implementation "
f"for {qualname} (implemented at {source})"
f"You have presumably called get_ctx() because the operator "
f"has a data-dependent output shape; if so, there is no "
f"such meta implementation and this error is the correct "
f"behavior."
)
with set_ctx_getter(error_on_ctx):
return fake_impl_holder.kernel(*args, **kwargs)
return meta_kernel
def get_none():
return None
global_ctx_getter: Callable = get_none
@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
global global_ctx_getter
prev = global_ctx_getter
try:
global_ctx_getter = ctx_getter
yield
finally:
global_ctx_getter = prev
class FakeImplCtx:
"""
Context object for writing fake implementations for custom operators.
"""
def __init__(self, _fake_mode, _op):
self._fake_mode = _fake_mode
self._shape_env = _fake_mode.shape_env
self._op = _op
@deprecated(
"`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
category=FutureWarning,
)
def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
return self.new_dynamic_size(min=min, max=max)
def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
"""Constructs a new symint (symbolic int) representing a data-dependent value.
This is useful for writing the fake implementation (which is necessary
for torch.compile) for a CustomOp where an output Tensor has a size
that depends on the data of the input Tensors.
Args:
min (int): A statically known inclusive lower bound for this symint. Default: 0
max (Optional[int]): A statically known inclusive upper bound for this
symint. Default: None
.. warning::
It is important that the ``min`` and ``max`` (if not None) values are set
correctly, otherwise, there will be undefined behavior under
torch.compile. The default value of ``min`` is 0; note that torch.compile
specializes on 0/1 sizes.
You must also verify that your implementation on concrete Tensors
(e.g. CPU/CUDA) only returns Tensors where the size that corresponds
to the symint also respects these constraints.
The easiest way to do this is to add an assertion in the CPU/CUDA/etc
implementation that the size follows these bounds.
Example::
>>> # An operator with data-dependent output shape
>>> lib = torch.library.Library("mymodule", "FRAGMENT")
>>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
>>>
>>> @torch.library.register_fake("mymodule::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> ctx = torch.library.get_ctx()
>>> nnz = ctx.new_dynamic_size()
>>> shape = [nnz, x.dim()]
>>> result = x.new_empty(shape, dtype=torch.int64)
>>> return result
>>>
>>> @torch.library.impl(lib, "custom_nonzero", "CPU")
>>> def _(x):
>>> x_np = x.numpy()
>>> res = np.stack(np.nonzero(x_np), axis=1)
>>> return torch.tensor(res, device=x.device)
"""
if (
self._shape_env is None
or not self._shape_env.allow_dynamic_output_shape_ops
):
raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)
if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
raise ValueError(
f"ctx.new_dynamic_size(min={min}, max={max}): expected "
f"min and max to be statically known ints but got SymInt. "
f"This is not supported."
)
if min < 0:
raise ValueError(
f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
f"greater than or equal to 0: this API can only create "
f"non-negative sizes."
)
result = self._shape_env.create_unbacked_symint()
torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
result, min=min, max=max
)
return result


@@ -0,0 +1,271 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import typing
from typing import List, Optional, Sequence, Union # noqa: F401
import torch
from torch import device, dtype, Tensor, types
from torch.utils._exposed_in import exposed_in
@exposed_in("torch.library")
def infer_schema(
prototype_function: typing.Callable,
/,
*,
mutates_args,
op_name: Optional[str] = None,
) -> str:
r"""Parses the schema of a given function with type hints. The schema is inferred from the
function's type hints, and can be used to define a new operator.
We make the following assumptions:
* None of the outputs alias any of the inputs or each other.
* | String type annotations "device, dtype, Tensor, types" without library specification are
| assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
| without library specification are assumed to be typing.*.
* | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
| it assumes that all inputs to the operator are being mutated.
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
Args:
prototype_function: The function whose type annotations are used to infer a schema.
op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the
name is not included in the inferred schema. Note that the input schema to
``torch.library.Library.define`` requires an operator name.
mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
Returns:
The inferred schema.
Example:
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>> return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
"""
UNKNOWN_MUTATES = "unknown"
sig = inspect.signature(prototype_function)
def error_fn(what):
raise ValueError(
f"infer_schema(func): {what} " f"Got func with signature {sig})"
)
def convert_type_string(annotation_type: str):
try:
return eval(annotation_type)
except Exception as e:
error_fn(
f"Unsupported type annotation {annotation_type}. It is not a type."
)
params = []
seen_args = set()
saw_kwarg_only_arg = False
for idx, (name, param) in enumerate(sig.parameters.items()):
if not supported_param(param):
error_fn("We do not support positional-only args, varargs, or varkwargs.")
if param.kind == inspect.Parameter.KEYWORD_ONLY:
# The first time we see a kwarg-only arg, add "*" to the schema.
if not saw_kwarg_only_arg:
params.append("*")
saw_kwarg_only_arg = True
if param.annotation is inspect.Parameter.empty:
error_fn(f"Parameter {name} must have a type annotation.")
# The annotation might have been converted to a string (e.g. by
# `from __future__ import annotations`); convert it back to the actual type.
annotation_type = param.annotation
if type(annotation_type) == str:
annotation_type = convert_type_string(annotation_type)
if annotation_type not in SUPPORTED_PARAM_TYPES.keys():
if annotation_type.__origin__ is tuple:
list_type = tuple_to_list(annotation_type)
example_type_str = "\n\n"
# Only suggest the list type if this type is supported.
if list_type in SUPPORTED_PARAM_TYPES.keys():
example_type_str = f"For example, {list_type}.\n\n"
error_fn(
f"Parameter {name} has unsupported type {param.annotation}. "
f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
f"{example_type_str}"
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)
else:
error_fn(
f"Parameter {name} has unsupported type {param.annotation}. "
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
if type(mutates_args) == str:
if mutates_args != UNKNOWN_MUTATES:
raise ValueError(
"mutates_args must either be a sequence of the names of "
"the arguments that are mutated or the string 'unknown'. "
)
if schema_type.startswith("Tensor"):
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
elif name in mutates_args:
if not schema_type.startswith("Tensor"):
error_fn(
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
)
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
seen_args.add(name)
if param.default is inspect.Parameter.empty:
params.append(f"{schema_type} {name}")
else:
default_repr = None
if param.default is None or isinstance(param.default, (int, float, bool)):
default_repr = str(param.default)
elif isinstance(param.default, (str, torch.device)):
default_repr = f'"{param.default}"'
elif isinstance(param.default, torch.dtype):
dtype_repr = str(param.default)
torch_dot = "torch."
assert dtype_repr.startswith(torch_dot)
default_repr = dtype_repr[len(torch_dot) :]
else:
error_fn(
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
f"Please file an issue on GitHub so we can prioritize this."
)
params.append(f"{schema_type} {name}={default_repr}")
if mutates_args != UNKNOWN_MUTATES:
mutates_args_not_seen = set(mutates_args) - seen_args
if len(mutates_args_not_seen) > 0:
error_fn(
f"{mutates_args_not_seen} in mutates_args were not found in "
f"the custom op's signature. "
f"mutates_args should contain the names of all args that the "
f"custom op mutates, or just the string 'unknown' if you don't know."
)
return_annotation = sig.return_annotation
if type(return_annotation) == str:
return_annotation = convert_type_string(return_annotation)
ret = parse_return(return_annotation, error_fn)
if op_name is not None:
return f"{op_name}({', '.join(params)}) -> {ret}"
return f"({', '.join(params)}) -> {ret}"
def derived_types(
base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
result = [
(base_type, cpp_type),
(typing.Optional[base_type], f"{cpp_type}?"),
]
def derived_seq_types(typ):
return [
typing.Sequence[typ], # type: ignore[valid-type]
typing.List[typ], # type: ignore[valid-type]
]
if list_base:
for seq_typ in derived_seq_types(base_type):
result.append((seq_typ, f"{cpp_type}[]")) # type: ignore[valid-type]
if optional_base_list:
for seq_typ in derived_seq_types(typing.Optional[base_type]):
result.append((seq_typ, f"{cpp_type}?[]")) # type: ignore[valid-type]
if optional_list_base:
for seq_typ in derived_seq_types(base_type): # type: ignore[valid-type]
result.append((typing.Optional[seq_typ], f"{cpp_type}[]?")) # type: ignore[valid-type]
return result
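# For example, derived_types(Tensor, "Tensor", True, True, False) produces the
# mappings Tensor -> "Tensor", Optional[Tensor] -> "Tensor?",
# List[Tensor]/Sequence[Tensor] -> "Tensor[]", and
# List[Optional[Tensor]]/Sequence[Optional[Tensor]] -> "Tensor?[]".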
def get_supported_param_types():
data = [
# (python type, schema type, type[] variant, type?[] variant, type[]? variant)
(Tensor, "Tensor", True, True, False),
(int, "SymInt", True, False, True),
(float, "float", True, False, True),
(bool, "bool", True, False, True),
(str, "str", False, False, False),
(types.Number, "Scalar", True, False, False),
(dtype, "ScalarType", False, False, False),
(device, "Device", False, False, False),
]
result = []
for line in data:
result.extend(derived_types(*line))
return dict(result)
SUPPORTED_RETURN_TYPES = {
Tensor: "Tensor",
typing.List[Tensor]: "Tensor[]",
int: "SymInt",
float: "float",
bool: "bool",
types.Number: "Scalar",
}
def parse_return(annotation, error_fn):
if annotation is None:
return "()"
if annotation is inspect.Parameter.empty:
error_fn("No return type annotation was provided. Please add one.")
origin = typing.get_origin(annotation)
if origin is not tuple:
if annotation not in SUPPORTED_RETURN_TYPES.keys():
error_fn(
f"Return has unsupported type {annotation}. "
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
)
return SUPPORTED_RETURN_TYPES[annotation]
args = typing.get_args(annotation)
for arg in args:
if arg not in SUPPORTED_RETURN_TYPES:
error_fn(
f"Return has unsupported type {annotation}. "
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
)
return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
SUPPORTED_PARAM_TYPES = get_supported_param_types()
def supported_param(param: inspect.Parameter) -> bool:
return param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.List]:
"""
    Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is a `typing.Tuple` type.
"""
type_args = getattr(tuple_type, "__args__", None)
# Account for different python versions, e.g. python 3.8 would give ()
# but python 3.12 would give None.
if tuple_type is typing.Tuple or type_args == () or type_args is None:
# Handle the case of an empty tuple type
return typing.List
elif len(type_args) == 1:
        # Special case: a single-element tuple becomes a list of that element type
return typing.List[type_args[0]] # type: ignore[valid-type]
elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type]
return typing.List[type_args[0]] # type: ignore[valid-type]
else:
return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc]
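# Illustrative sketch (not part of the original module): expected conversions
# from tuple_to_list, written as a throwaway check. The helper name is
# hypothetical.
def _tuple_to_list_examples() -> None:
    assert tuple_to_list(typing.Tuple) == typing.List
    assert tuple_to_list(typing.Tuple[int]) == typing.List[int]
    assert tuple_to_list(typing.Tuple[int, ...]) == typing.List[int]
    assert tuple_to_list(typing.Tuple[int, str]) == typing.List[typing.Union[int, str]]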

View File

@ -0,0 +1,85 @@
# mypy: allow-untyped-defs
from typing import Callable, Optional
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle
__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"]
class SimpleLibraryRegistry:
"""Registry for the "simple" torch.library APIs
The "simple" torch.library APIs are a higher-level API on top of the
raw PyTorch DispatchKey registration APIs that includes:
- fake impl
Registrations for these APIs do not go into the PyTorch dispatcher's
table because they may not directly involve a DispatchKey. For example,
the fake impl is a Python function that gets invoked by FakeTensor.
Instead, we manage them here.
SimpleLibraryRegistry is a mapping from a fully qualified operator name
(including the overload) to SimpleOperatorEntry.
"""
def __init__(self):
self._data = {}
def find(self, qualname: str) -> "SimpleOperatorEntry":
if qualname not in self._data:
self._data[qualname] = SimpleOperatorEntry(qualname)
return self._data[qualname]
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
class SimpleOperatorEntry:
"""This is 1:1 to an operator overload.
The fields of SimpleOperatorEntry are Holders where kernels can be
registered to.
"""
def __init__(self, qualname: str):
self.qualname: str = qualname
self.fake_impl: FakeImplHolder = FakeImplHolder(qualname)
self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = (
GenericTorchDispatchRuleHolder(qualname)
)
# For compatibility reasons. We can delete this soon.
@property
def abstract_impl(self):
return self.fake_impl
class GenericTorchDispatchRuleHolder:
def __init__(self, qualname):
self._data = {}
self.qualname = qualname
def register(
self, torch_dispatch_class: type, func: Callable
) -> RegistrationHandle:
if self.find(torch_dispatch_class):
raise RuntimeError(
f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
)
self._data[torch_dispatch_class] = func
def deregister():
del self._data[torch_dispatch_class]
return RegistrationHandle(deregister)
def find(self, torch_dispatch_class):
return self._data.get(torch_dispatch_class, None)
def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]:
return singleton.find(op.__qualname__).torch_dispatch_rules.find(
torch_dispatch_class
)
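# Illustrative sketch (not part of the original module): how a registry entry
# is used end to end. The operator name, dispatch class, and rule below are
# hypothetical.
def _simple_registry_example() -> None:
    class _DemoDispatchClass:  # stands in for a TorchDispatchMode subclass
        pass

    def _demo_rule(mode, func, types, args, kwargs):
        return NotImplemented

    entry = singleton.find("mylib::demo_op")
    handle = entry.torch_dispatch_rules.register(_DemoDispatchClass, _demo_rule)
    assert entry.torch_dispatch_rules.find(_DemoDispatchClass) is _demo_rule
    handle.destroy()  # removes the registration again
    assert entry.torch_dispatch_rules.find(_DemoDispatchClass) is None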

View File

@ -0,0 +1,233 @@
import contextlib
import threading
from typing import Callable, Generator, Iterable, Optional, Union
from .custom_ops import custom_op
from .infer_schema import infer_schema
def triton_op(
name: str,
fn: Optional[Callable] = None,
/,
*,
mutates_args: Union[str, Iterable[str]],
schema: Optional[str] = None,
) -> Callable:
"""Create a custom operator whose implementation is backed by 1+ triton kernels.
Use this instead of :func:`torch.library.custom_op` when the implementation
consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
custom operators as opaque (:func:`torch.compile` and
:func:`torch.export.export` will never trace into them), but ``triton_op``
makes the implementation visible to these subsystems, allowing them
to optimize the triton kernel(s).
Note that ``fn`` must only consist of calls to PyTorch-understood
operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.
Args:
name (str): A name for the custom op that looks like "{namespace}::{name}",
e.g. "mylib::my_linear". The name is used as the op's stable identifier
in PyTorch subsystems (e.g. torch.export, FX graphs).
To avoid name collisions, please use your project name as the namespace;
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
it pessimistically assumes that all inputs to the operator are being mutated.
schema (None | str): A schema string for the operator. If None
(recommended) we'll infer a schema for the operator from its type
annotations. We recommend letting us infer a schema unless you
have a specific reason not to.
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch._library import triton_op, capture_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>> in_ptr0,
>>> in_ptr1,
>>> out_ptr,
>>> n_elements,
>>> BLOCK_SIZE: "tl.constexpr",
>>> ):
>>> pid = tl.program_id(axis=0)
>>> block_start = pid * BLOCK_SIZE
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>> mask = offsets < n_elements
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
>>> output = x + y
>>> tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>> output = torch.empty_like(x)
>>> n_elements = output.numel()
>>>
>>> def grid(meta):
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>> # NB: we need to wrap the triton kernel in a call to capture_triton
>>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>> return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>> return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
"""
def dec(fn: Callable) -> Callable:
def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
# Optimization: we're passing regular Tensors into the triton kernel, so
# no need to go through HOP dispatch
with set_capture_triton_enabled(False):
return fn(*args, **kwargs)
result = custom_op(
name,
backend_fn,
mutates_args=mutates_args,
schema=infer_schema(fn, mutates_args=mutates_args),
)
from .._subclasses.functional_tensor import FunctionalTensorMode
# We require that the user pass us a function that is make_fx traceable,
# so we can just register it as the Fake/meta kernel.
result.register_fake(fn)
# We decompose the operator when FunctionalTensorMode is active.
# The goal is to decompose the operator in AOTDispatcher.
# - With torch.compile, this means that the backend (usually Inductor)
# can see a call to the triton kernel(s) and so it can directly optimize
# them by inlining them into the lowering process.
# - With post-dispatch torch.export, this means that there will
# be a call(s) to the triton_kernel_wrapper_functional HOP in the
# graph (that we have yet to figure out how to serialize).
def functional_decomp( # type: ignore[no-untyped-def]
mode, _, types, args, kwargs
):
with mode:
return fn(*args, **kwargs)
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
return result
if fn is None:
return dec
else:
return dec(fn)
capture_triton_enabled = threading.local()
capture_triton_enabled_default = True
@contextlib.contextmanager
def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
"""If triton kernels annotated with @capture_triton should dispatch via HOP
or go straight to the triton kernel execution.
We have this switch because eager-mode performance of HOP dispatch is slow
enough to matter (~1ms) and we know that capture_triton isn't necessary in
    some situations (eager-mode with regular Tensors).
"""
try:
prev = is_capture_triton_enabled()
capture_triton_enabled.value = enabled
yield
finally:
capture_triton_enabled.value = prev
def is_capture_triton_enabled() -> bool:
return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)
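# Illustrative sketch (not part of the original module): bypassing HOP dispatch
# for an eager-only fast path. The kernel, grid, and tensor arguments below are
# hypothetical.
#
#     with set_capture_triton_enabled(False):
#         # capture_triton(add_kernel) returns the raw kernel unchanged here,
#         # so the launch below skips the HOP wrapper entirely.
#         capture_triton(add_kernel)[grid](x, y, out, n_elements, 16)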
def capture_triton(triton_kernel: Callable, /) -> Callable:
"""Allows capture of a triton kernel into a graph via make_fx or
non-strict export (coming soon).
These technologies perform Dispatcher-based tracing (via
``__torch_dispatch__``) and cannot see calls to raw triton kernels.
The ``capture_triton`` API returns a new callable that can actually
be traced into a graph.
Examples:
>>> # xdoctest: +SKIP
>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>> in_ptr0,
>>> in_ptr1,
>>> out_ptr,
>>> n_elements,
>>> BLOCK_SIZE: "tl.constexpr",
>>> ):
>>> pid = tl.program_id(axis=0)
>>> block_start = pid * BLOCK_SIZE
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>> mask = offsets < n_elements
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
>>> output = x + y
>>> tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>> output = torch.empty_like(x)
>>> n_elements = output.numel()
>>>
>>> def grid_fn(meta):
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>> return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> # kernel_idx = 0, constant_args_idx = 0,
>>> # grid = [(1, 1, 1)], kwargs = {
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16
>>> # })
>>> # return empty_like
"""
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
if not isinstance(triton_kernel, (JITFunction, Autotuner)):
raise RuntimeError(
"capture_triton only works on functions annotated with triton.jit or triton.autotune"
)
if not is_capture_triton_enabled():
return triton_kernel
return TraceableTritonKernelWrapper(triton_kernel, None, None)

View File

@ -0,0 +1,318 @@
# mypy: allow-untyped-defs
import dataclasses
import inspect
import sys
from typing import Any, Callable, Dict, Iterable, Tuple, Union
import torch
from torch import _C, _utils_internal
from torch._ops import OpOverload
@dataclasses.dataclass
class Kernel:
"""Models a (function, source location)"""
func: Callable
source: str
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
class RegistrationHandle:
"""Does something when someone calls .destroy() on it"""
def __init__(self, on_destroy: Callable):
self._on_destroy = on_destroy
def destroy(self) -> None:
self._on_destroy()
def get_source(stacklevel: int) -> str:
"""Get a string that represents the caller.
Example: "/path/to/foo.py:42"
Use stacklevel=1 to get the caller's source
Use stacklevel=2 to get the caller's caller's source
etc.
"""
frame = inspect.getframeinfo(sys._getframe(stacklevel))
source = f"{frame.filename}:{frame.lineno}"
return source
def parse_namespace(qualname: str) -> Tuple[str, str]:
splits = qualname.split("::")
if len(splits) != 2:
raise ValueError(
f"Expected `qualname` to be of the form "
f'"namespace::name", but got {qualname}. '
f"The qualname passed to the torch.library APIs must consist "
f"of a namespace and a name, e.g. aten::sin"
)
return splits[0], splits[1]
def lookup_op(qualname: str) -> OpOverload:
namespace, name = parse_namespace(qualname)
if "." in name:
name, overload = name.split(".")
else:
overload = "default"
ns = getattr(torch.ops, namespace)
packet = getattr(ns, name)
return getattr(packet, overload)
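# Illustrative sketch (not part of the original module): qualname resolution,
# written as a throwaway check. The helper name is hypothetical.
def _lookup_op_examples() -> None:
    assert parse_namespace("aten::sin") == ("aten", "sin")
    assert lookup_op("aten::sin") is torch.ops.aten.sin.default
    assert lookup_op("aten::add.Tensor") is torch.ops.aten.add.Tensor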
def is_builtin(op: OpOverload) -> bool:
assert isinstance(op, OpOverload)
return op.namespace in {"aten", "prim", "prims"}
def is_functional_schema(schema: Any) -> bool:
"""Check if the schema is functional.
An operator is functional if:
- it does not mutate any of its inputs
- it does not return a view on any of its inputs
- it has at least one return
"""
def is_functional(schema):
if schema.is_mutable:
return False
rets = schema.returns
is_non_mutating_view = len(rets) > 0 and any(
r.alias_info is not None and not r.alias_info.is_write for r in rets
)
if is_non_mutating_view:
return False
if not schema.returns:
return False
return True
if isinstance(schema, torch._C.FunctionSchema):
return is_functional(schema)
# Lazy import because not all PyTorch builds have torchgen
from torchgen.model import FunctionSchema
if isinstance(schema, str):
schema = FunctionSchema.parse(schema)
assert isinstance(schema, FunctionSchema)
return is_functional(schema)
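# Illustrative sketch (not part of the original module): with the string path
# (which requires torchgen), one would expect, for example,
#     is_functional_schema("foo(Tensor x) -> Tensor")         -> True
#     is_functional_schema("foo(Tensor(a!) x) -> Tensor(a!)")  -> False  (mutable)
#     is_functional_schema("foo(Tensor x) -> ()")              -> False  (no returns)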
# should be torch._C.JitType but that annotation is busted
def is_tensorlist_like_type(typ: Any) -> bool:
return (
typ == _C.ListType(_C.TensorType.get())
or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
)
# should be torch._C.JitType but that annotation is busted
def is_tensor_like_type(typ: Any) -> bool:
return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
def mutates_and_returns_first_arg(op: OpOverload):
"""Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
Figure this out.
    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a!)
"""
if op.namespace != "aten":
return False
schema = op._schema
if not len(schema.returns) == 1:
return False
if schema.returns[0].alias_info is None:
return False
alias_set = schema.returns[0].alias_info.after_set
if len(alias_set) != 1:
return False
loc = next(iter(alias_set))
if len(schema.arguments) < 1:
return False
first_arg = schema.arguments[0]
if first_arg.alias_info is None:
return False
if not first_arg.alias_info.is_write:
return False
alias_set = first_arg.alias_info.after_set
if len(alias_set) != 1:
return False
if loc != next(iter(alias_set)):
return False
for arg in schema.arguments[1:]:
if arg.alias_info is not None:
return False
return True
def fill_defaults(schema, args, kwargs):
new_args = []
new_kwargs = {}
for i in range(len(schema.arguments)):
info = schema.arguments[i]
if info.kwarg_only:
if info.name in kwargs:
new_kwargs[info.name] = kwargs[info.name]
else:
new_kwargs[info.name] = info.default_value
else:
if i < len(args):
new_args.append(args[i])
else:
new_args.append(info.default_value)
return tuple(new_args), new_kwargs
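# Illustrative sketch (not part of the original module): for a hypothetical
# schema "mylib::foo(Tensor x, int y=2, *, bool flag=False) -> Tensor",
#     fill_defaults(schema, (t,), {})              -> ((t, 2), {"flag": False})
#     fill_defaults(schema, (t,), {"flag": True})  -> ((t, 2), {"flag": True})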
def zip_schema(
schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
"""zips schema.arguments and (args, kwargs) together.
Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
that is, kwargs must be keyword-only arguments and default values may be omitted.
"""
assert len(schema.arguments) >= len(args) + len(kwargs)
for i in range(len(schema.arguments)):
info = schema.arguments[i]
if info.kwarg_only:
if info.name in kwargs:
yield info, kwargs[info.name]
continue
if i >= len(args):
# args that are equal to their default values are not populated
# if they are followed by args that are equal to their defaults.
# Skip these.
continue
yield info, args[i]
return
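# Illustrative sketch (not part of the original module): with the same
# hypothetical schema "mylib::foo(Tensor x, int y=2, *, bool flag=False)",
#     zip_schema(schema, (t,), {"flag": True})
# yields (x, t) and (flag, True); y is skipped because the caller omitted it
# (it was equal to its default), which is the dispatcher convention described
# in the docstring above.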
def hop_schema_from_fx_node(node):
from torchgen.gen_schema_utils import FunctionSchemaGen
hop = node.target
if not isinstance(hop, torch._ops.HigherOrderOperator):
raise RuntimeError("fx_node's target must be a hop.")
def _collect_example_val(node):
meta_val = node.meta.get("val", None)
if meta_val is None:
assert node.op == "get_attr"
meta_val = getattr(node.graph.owning_module, node.target)
return meta_val
example_inputs = []
for arg in node.args:
if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
example_inputs.append(_collect_example_val(arg))
elif isinstance(
arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
):
example_inputs.append([_collect_example_val(x) for x in arg])
else:
raise RuntimeError(f"Unsupported arg type {type(arg)}")
# Bound the arguments to make sure number of inputs are correct
bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
*example_inputs
)
# We treat example_output as a single value in return. This is to differentiate 1. return a single val
# vs 2. return a tuple with one element.
example_output = _collect_example_val(node)
return FunctionSchemaGen.from_example(
hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
)
def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
assert isinstance(op, OpOverload)
if is_builtin(op):
# We control the built-ins. These may (in rare cases)
# do input metadata mutation (which we have banned on custom ops)
return False
schema = op._schema
# It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
if not schema.is_mutable:
return False
if len(schema.returns) > 0:
return False
# If the op returns nothing, then it has a trivial fake impl.
return True
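# Illustrative sketch (not part of the original module): a custom op with a
# hypothetical schema like "mylib::inplace_relu_(Tensor(a!) x) -> ()" is
# mutable and returns nothing, so a trivial fake impl (do nothing, return
# None) is sound; an op that returns Tensors needs a real fake impl that
# describes the output metadata.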
def requires_set_python_module() -> bool:
"""If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns whether we require that there has been a
m.set_python_module("mylib.ops") call from C++ that associates
the C++ op with a python module.
"""
return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
overload_types = []
args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
for a in args_flattened:
# TODO: need to double check the semantics of the "types" argument to torch_dispatch.
# It's generated in PyInterpreter.cpp, but seems to be generated in two places,
# where in one case we only include tensors with the python key, and in another
# we include **all** tensors.
if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
torch._C.DispatchKey.Python
):
overload_types.append(type(a))
# TODO: check that I got these args correct (in C++, we pass in "0000"??)
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
def has_kwarg_only_args(schema: _C.FunctionSchema):
return any(a.kwarg_only for a in schema.arguments)
def has_kwarg_only_tensors(schema: _C.FunctionSchema):
for a in schema.arguments:
if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
continue
if not a.kwarg_only:
continue
return True
return False
def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
"""
Given a schema, returns True if the schema has a Tensor arg.
A Tensor arg is any arg with a type annotation that might involve Tensor.
"""
return any(
(is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
for a in schema.arguments
)
def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
"""
Given a schema, returns the id of the `device: torch.device` argument.
If it does not exist, returns None.
"""
for index, arg in enumerate(schema.arguments):
if arg.type is _C.DeviceObjType.get() and arg.name == "device":
return index
return None