# mypy: allow-untyped-defs
import dataclasses
import inspect
import sys
from typing import Any, Callable, Dict, Iterable, Tuple, Union

import torch
from torch import _C, _utils_internal
from torch._ops import OpOverload


@dataclasses.dataclass
class Kernel:
    """Models a (function, source location)"""

    func: Callable
    source: str

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
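
# Illustrative usage (a sketch; the tensor value is arbitrary):
#   >>> k = Kernel(torch.sin, get_source(1))
#   >>> k(torch.tensor(0.0))  # delegates to k.func
#   tensor(0.)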


class RegistrationHandle:
    """Calls the given ``on_destroy`` callback when someone calls .destroy() on it"""

    def __init__(self, on_destroy: Callable):
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        self._on_destroy()
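
# Illustrative usage (a sketch):
#   >>> handle = RegistrationHandle(lambda: print("deregistered"))
#   >>> handle.destroy()
#   deregistered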


def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.
    """
    frame = inspect.getframeinfo(sys._getframe(stacklevel))
    source = f"{frame.filename}:{frame.lineno}"
    return source


def parse_namespace(qualname: str) -> Tuple[str, str]:
    splits = qualname.split("::")
    if len(splits) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    return splits[0], splits[1]
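
# Illustrative usage (a sketch):
#   >>> parse_namespace("aten::sin")
#   ('aten', 'sin')
#   >>> parse_namespace("sin")  # no namespace -> raises ValueError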


def lookup_op(qualname: str) -> OpOverload:
    namespace, name = parse_namespace(qualname)
    if "." in name:
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    return getattr(packet, overload)
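
# Illustrative usage (a sketch):
#   >>> lookup_op("aten::sin")          # resolves to torch.ops.aten.sin.default
#   >>> lookup_op("aten::add.Tensor")   # resolves to torch.ops.aten.add.Tensor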


def is_builtin(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    return op.namespace in {"aten", "prim", "prims"}
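
# Illustrative usage (a sketch):
#   >>> is_builtin(torch.ops.aten.sin.default)
#   True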


def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return
    """

    def is_functional(schema):
        if schema.is_mutable:
            return False
        rets = schema.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.alias_info is not None and not r.alias_info.is_write for r in rets
        )
        if is_non_mutating_view:
            return False
        if not schema.returns:
            return False
        return True

    if isinstance(schema, torch._C.FunctionSchema):
        return is_functional(schema)

    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema

    if isinstance(schema, str):
        schema = FunctionSchema.parse(schema)
    assert isinstance(schema, FunctionSchema)
    return is_functional(schema)
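
# Illustrative usage (a sketch, using torch._C.FunctionSchema objects):
#   >>> is_functional_schema(torch.ops.aten.sin.default._schema)
#   True
#   >>> is_functional_schema(torch.ops.aten.add_.Tensor._schema)   # mutates self
#   False
#   >>> is_functional_schema(torch.ops.aten.view.default._schema)  # returns a view
#   False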


# should be torch._C.JitType but that annotation is busted
def is_tensorlist_like_type(typ: Any) -> bool:
    return (
        typ == _C.ListType(_C.TensorType.get())
        or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
    )
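
# Illustrative usage (a sketch): matches Tensor[], Tensor?[], Tensor[]?, Tensor?[]?
#   >>> is_tensorlist_like_type(_C.ListType(_C.TensorType.get()))
#   True
#   >>> is_tensorlist_like_type(_C.TensorType.get())
#   False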


# should be torch._C.JitType but that annotation is busted
def is_tensor_like_type(typ: Any) -> bool:
    return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
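
# Illustrative usage (a sketch): matches Tensor and Tensor?
#   >>> is_tensor_like_type(_C.TensorType.get())
#   True
#   >>> is_tensor_like_type(_C.OptionalType(_C.TensorType.get()))
#   True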


def mutates_and_returns_first_arg(op: OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
    """
    if op.namespace != "aten":
        return False
    schema = op._schema
    if not len(schema.returns) == 1:
        return False
    if schema.returns[0].alias_info is None:
        return False
    alias_set = schema.returns[0].alias_info.after_set
    if len(alias_set) != 1:
        return False
    loc = next(iter(alias_set))
    if len(schema.arguments) < 1:
        return False
    first_arg = schema.arguments[0]
    if first_arg.alias_info is None:
        return False
    if not first_arg.alias_info.is_write:
        return False
    alias_set = first_arg.alias_info.after_set
    if len(alias_set) != 1:
        return False
    if loc != next(iter(alias_set)):
        return False
    for arg in schema.arguments[1:]:
        if arg.alias_info is not None:
            return False
    return True
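
# Illustrative usage (a sketch):
#   >>> mutates_and_returns_first_arg(torch.ops.aten.add_.Tensor)   # in-place
#   True
#   >>> mutates_and_returns_first_arg(torch.ops.aten.add.Tensor)    # functional
#   False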


def fill_defaults(schema, args, kwargs):
    """Fills in default values for arguments omitted from (args, kwargs),
    returning a complete (new_args, new_kwargs) per the schema."""
    new_args = []
    new_kwargs = {}
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                new_kwargs[info.name] = kwargs[info.name]
            else:
                new_kwargs[info.name] = info.default_value
        else:
            if i < len(args):
                new_args.append(args[i])
            else:
                new_args.append(info.default_value)
    return tuple(new_args), new_kwargs
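
# Illustrative usage (a sketch; clamp's schema is
# "clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"):
#   >>> x = torch.randn(3)
#   >>> fill_defaults(torch.ops.aten.clamp.default._schema, (x,), {})
#   ((x, None, None), {})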


def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, kwargs must be keyword-only arguments and default values may be omitted.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for i in range(len(schema.arguments)):
        info = schema.arguments[i]
        if info.kwarg_only:
            if info.name in kwargs:
                yield info, kwargs[info.name]
            continue
        if i >= len(args):
            # args that are equal to their default values are not populated
            # if they are followed by args that are equal to their defaults.
            # Skip these.
            continue
        yield info, args[i]
    return
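
# Illustrative usage (a sketch): pairs each provided input with its schema argument.
#   >>> op = torch.ops.aten.add.Tensor
#   >>> x, y = torch.randn(2), torch.randn(2)
#   >>> [(info.name, val) for info, val in zip_schema(op._schema, (x, y), {})]
#   [('self', x), ('other', y)]    # alpha omitted: it kept its default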


def hop_schema_from_fx_node(node):
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bind the arguments to make sure the number of inputs is correct.
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single return value; this differentiates
    # returning a single value from returning a tuple with one element.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    assert isinstance(op, OpOverload)
    if is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we return False
    # out of an abundance of caution
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    # If the op returns nothing, then it has a trivial fake impl.
    return True
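
# Illustrative usage (a sketch; "mylib::inplace_relu_" is a hypothetical custom op
# with schema "inplace_relu_(Tensor(a!) x) -> ()"):
#   >>> can_generate_trivial_fake_impl(lookup_op("mylib::inplace_relu_"))
#   True    # mutates its input and returns nothing, so the fake impl is a no-op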


def requires_set_python_module() -> bool:
    """If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns whether we require that there was a
    m.set_python_module("mylib.ops") call from C++ that associates
    the C++ op with a python module.
    """
    return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)


def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    overload_types = []
    args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
    for a in args_flattened:
        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
        # where in one case we only include tensors with the python key, and in another
        # we include **all** tensors.
        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
            torch._C.DispatchKey.Python
        ):
            overload_types.append(type(a))
    # TODO: check that I got these args correct (in C++, we pass in "0000"??)
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
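
# Illustrative usage (a sketch; LoggingMode is a hypothetical mode defined here,
# not part of torch):
#   >>> class LoggingMode(torch.utils._python_dispatch.TorchDispatchMode):
#   ...     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#   ...         print(func)
#   ...         return func(*args, **(kwargs or {}))
#   >>> handle_dispatch_mode(LoggingMode(), torch.ops.aten.sin.default, torch.randn(3))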


def has_kwarg_only_args(schema: _C.FunctionSchema):
    return any(a.kwarg_only for a in schema.arguments)
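
# Illustrative usage (a sketch; add.Tensor has the kwarg-only `alpha`):
#   >>> has_kwarg_only_args(torch.ops.aten.add.Tensor._schema)
#   True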


def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    for a in schema.arguments:
        if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
            continue
        if not a.kwarg_only:
            continue
        return True
    return False
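
# Illustrative usage (a sketch; add.Tensor's only kwarg-only arg, `alpha`, is a Scalar):
#   >>> has_kwarg_only_tensors(torch.ops.aten.add.Tensor._schema)
#   False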


def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
    """
    Given a schema, returns True if the schema has a Tensor arg.
    A Tensor arg is any arg with a type annotation that might involve Tensor.
    """
    return any(
        (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
        for a in schema.arguments
    )
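
# Illustrative usage (a sketch):
#   >>> has_tensor_arg(torch.ops.aten.sin.default._schema)
#   True
#   >>> has_tensor_arg(torch.ops.aten.arange.default._schema)   # only non-Tensor args
#   False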


def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
    """
    Given a schema, returns the id of the `device: torch.device` argument.
    If it does not exist, returns None.
    """
    for index, arg in enumerate(schema.arguments):
        if arg.type is _C.DeviceObjType.get() and arg.name == "device":
            return index
    return None
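
# Illustrative usage (a sketch; to.device takes a non-optional `device` at index 1):
#   >>> get_device_arg_index(torch.ops.aten.to.device._schema)
#   1
#   >>> get_device_arg_index(torch.ops.aten.sin.default._schema)  # no device arg -> None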