I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,149 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from torchgen import dest
# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target
if TYPE_CHECKING:
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This is used on the
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
if Variant.function not in f.variants:
return None
sig = DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
)
assert sig is not None
if len(f.func.returns) == 0:
ret_name = ""
elif len(f.func.returns) == 1:
if f.func.arguments.out:
ret_name = f.func.arguments.out[0].name
else:
ret_name = next(
(
a.name
for a in f.func.arguments.flat_non_out
if a.type == f.func.returns[0].type
),
"",
)
if not ret_name:
# if return type is tensor
if f.func.returns[0].type == BaseType(BaseTy.Tensor):
# Returns an empty tensor
ret_name = "at::Tensor()"
else:
raise Exception( # noqa: TRY002
f"Can't handle this return type {f.func}"
)
elif len(f.func.arguments.out) == len(f.func.returns):
# Returns a tuple of out arguments
tensor_type = "at::Tensor &"
comma = ", "
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
{comma.join([r.name for r in f.func.arguments.out])}
)"""
else:
assert all(
a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
), f"Only support tensor returns but got {f.func.returns}"
# Returns a tuple of empty tensors
tensor_type = "at::Tensor"
comma = ", "
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
{comma.join(["at::Tensor()" for _ in f.func.returns])}
)"""
ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
return f"""
{sig.defn()} {{
{ret_str}
}}
"""
def gen_custom_ops_registration(
*,
native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
rocm: bool,
) -> tuple[str, str]:
"""
Generate custom ops registration code for dest.RegisterDispatchKey.
:param native_functions: a sequence of `NativeFunction`
:param selector: for selective build.
:param kernel_index: kernels for all the ops.
:param rocm: bool for dest.RegisterDispatchKey.
:return: generated C++ code to register custom operators into PyTorch
"""
# Convert the kernel index to a BackendIndex, because we can't handle ETKernelIndex yet.
# TODO larryliu: evaluate if this code is still needed. If yes, let it handle ETKernelIndex.
dispatch_key = DispatchKey.CPU
backend_index = kernel_index._to_backend_index()
static_init_dispatch_registrations = ""
ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function)
for namespace, functions in ns_grouped_native_functions.items():
if len(functions) == 0:
continue
dispatch_registrations_body = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
functions,
)
)
)
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
anonymous_definition = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
native_functions,
)
)
)
return anonymous_definition, static_init_dispatch_registrations
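# A sketch of the intended call site (variable names assumed): the caller splices
# the two returned strings into its C++ registration template, e.g.
#   anonymous_definition, static_init_dispatch_registrations = gen_custom_ops_registration(
#       native_functions=native_functions,
#       selector=selector,
#       kernel_index=kernel_index,
#       rocm=False,
#   )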


@@ -0,0 +1,370 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api.types import (
ArgName,
BaseCType,
Binding,
ConstRefCType,
CType,
MutRefCType,
NamedCType,
SpecialArgName,
TupleCType,
VectorCType,
voidT,
)
from torchgen.executorch.api.types import (
ArrayRefCType,
BaseTypeToCppMapping,
OptionalCType,
scalarT,
tensorListT,
tensorT,
)
from torchgen.model import (
Argument,
Arguments,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels,
since in Executorch CppSignature is the same as NativeSignature.
Differences between this file and torchgen.api.cpp.py:
- Executorch doesn't support TensorOptions, but we keep the logic here to stay compatible with
torchgen.api.cpp, so that we can support things like ATen mode (running ATen kernels in Executorch).
- Executorch doesn't support Dimname.
- The Executorch runtime doesn't support SymInt, so it is treated as int.
"""
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
t: Type,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
# For SymInt we simply treat it as int.
elif str(t) == "SymInt":
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
"string ref->value conversion: not implemented yet"
)
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if str(t.elem) == "bool":
assert t.size is not None
return NamedCType(
binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]))
)
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t,
binds=binds,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == "Tensor":
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(
binds, MutRefCType(BaseCType(tensorT))
) # TODO: fix this discrepancy
else:
return NamedCType(
binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
)
elif str(t.elem) == "Scalar":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
if str(t.elem) == "Tensor":
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Dimname":
raise NotImplementedError("Executorch doesn't support Dimname")
elif str(t.elem) == "Tensor?":
return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool) -> CType:
# placeholder is ignored
r = valuetype_type(t, binds="__placeholder__")
if r is not None:
return r.type
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
# Note [Tensor Copy Returns]
# Currently, we use "Argument.is_write" to determine
# whether or not Tensor return types should be copies or references.
# If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return) -> CType:
return returntype_type(r.type, mutable=r.is_write)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0])
else:
return TupleCType([return_type(r) for r in rs])
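# Illustrative (a sketch):
#   ()               -> void
#   (Tensor)         -> torch::executor::Tensor
#   (Tensor, Tensor) -> ::std::tuple<torch::executor::Tensor, torch::executor::Tensor>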
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
# TODO: Consider incorporating this into the data model
if f.func.name.name.inplace:
assert i == 0, "illegal inplace function with multiple returns"
name = "self"
# If we are out function, the name is the name of the
# corresponding output function (r.name will get recorded
# in field_name later.)
elif f.func.is_out_fn():
name = f.func.arguments.out[i].name
# If the return argument is explicitly named...
elif r.name:
name_conflict = any(
r.name == a.name for a in f.func.schema_order_arguments()
)
if name_conflict and not f.func.is_out_fn():
name = f"{r.name}_return"
else:
name = r.name
# If there is no explicit name and no fallback name was passed in, we just name the output result,
# unless it's a multi-return, in which case it's result0,
# result1, etc (zero-indexed)
else:
name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
returns.append(name)
return returns
JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "torch::executorch::nullopt", # UGH this one is type directed
"[]": "{}",
"contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
"long": "torch::executorch::kLong",
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
if d == "None" and str(t) == "Tensor?":
return "{}"
if isinstance(t, BaseType) and t.name is BaseTy.str:
# Schema allows single quotes but C++ needs double
if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
s = ""
i = 1
while i + 1 < len(d):
if d[i] != "\\":
if d[i] == '"':
s += '\\"'
else:
s += d[i]
i += 1
else:
if d[i + 1] == "'":
s += "'"
else:
s += d[i : i + 2]
i += 2
return f'"{s}"'
if isinstance(t, OptionalType):
if d == "None":
return "torch::executor::nullopt"
return default_expr(d, t.elem)
if isinstance(t, ListType):
if d.startswith("[") and d.endswith("]"):
return "{" + d[1:-1] + "}"
elif t.size is None:
# NOTE: Sized lists can have scalar defaults
raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
return JIT_TO_CPP_DEFAULT.get(d, d)
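# Illustrative conversions (a sketch; types written as schema strings):
#   default_expr("True", bool)     -> "true"
#   default_expr("None", Tensor?)  -> "{}"
#   default_expr("'nearest'", str) -> "\"nearest\""
#   default_expr("[0, 1]", int[])  -> "{0, 1}"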
# Convert an argument into its C++ API form
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
*,
cpp_no_default_args: set[str],
method: bool,
faithful: bool,
has_tensor_options: bool,
) -> list[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
has_tensor_options=has_tensor_options,
)
if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [
Binding(
nctype=argument_type(a, binds=binds),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, TensorOptionsArguments):
raise NotImplementedError("Need to implement type resolution for TensorOptions")
elif isinstance(a, SelfArgument):
if method:
# Caller is responsible for installing implicit this in context!
return []
else:
return sub_argument(a.argument)
else:
assert_never(a)
def arguments(
arguments: Arguments,
*,
faithful: bool,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
else:
args.extend(arguments.out)
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r
for a in args
for r in argument(
a,
faithful=faithful,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,
)
]
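# Illustrative ordering (a sketch): for an out-variant schema such as
#   add.out(Tensor self, Tensor other, *, Tensor(a!) out)
# faithful=True yields bindings ordered (self, other, out) with defaults stripped,
# while faithful=False yields (out, self, other) with defaults kept.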


@@ -0,0 +1,4 @@
from torchgen.executorch.api.types.types import *
from torchgen.executorch.api.types.signatures import * # usort: skip


@@ -0,0 +1,76 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torchgen.api.cpp as aten_cpp
from torchgen.executorch.api.types.types import contextArg
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType
from torchgen.model import FunctionSchema, NativeFunction
@dataclass(frozen=True)
class ExecutorchCppSignature:
"""
This signature is merely a CppSignature with Executorch types (optionally
contains KernelRuntimeContext as well). The inline definition of
CppSignature is generated in Functions.h and it's used by unboxing
functions.
"""
# The schema this signature is derived from
func: FunctionSchema
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""
def arguments(self, *, include_context: bool = True) -> list[Binding]:
return ([contextArg] if include_context else []) + et_cpp.arguments(
self.func.arguments,
faithful=True, # always faithful, out argument at the end
method=False, # method not supported
cpp_no_default_args=self.cpp_no_default_args,
)
def name(self) -> str:
return self.prefix + aten_cpp.name(
self.func,
faithful_name_for_out_overloads=True,
)
def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
args_str = ", ".join(
a.decl() for a in self.arguments(include_context=include_context)
)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
args = [a.defn() for a in self.arguments()]
args_str = ", ".join(args)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def returns_type(self) -> CType:
return et_cpp.returns_type(self.func.returns)
@staticmethod
def from_native_function(
f: NativeFunction, *, prefix: str = ""
) -> ExecutorchCppSignature:
return ExecutorchCppSignature(
func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
)
# Imported at the bottom to avoid a circular import with et_cpp.
from torchgen.executorch.api import et_cpp
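# Illustrative only (exact rendering may differ): for a schema like
#   add.out(Tensor self, Tensor other, *, Tensor(a!) out)
# decl() would produce roughly
#   torch::executor::Tensor & add_outf(torch::executor::KernelRuntimeContext & context,
#       const torch::executor::Tensor & self, const torch::executor::Tensor & other,
#       torch::executor::Tensor & out)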


@@ -0,0 +1,83 @@
from __future__ import annotations
from dataclasses import dataclass
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
CType,
doubleT,
Expr,
longT,
MutRefCType,
NamedCType,
)
from torchgen.model import BaseTy
halfT = BaseCppType("torch::executor", "Half")
bfloat16T = BaseCppType("torch::executor", "BFloat16")
stringT = BaseCppType("torch::executor", "string_view")
scalarTypeT = BaseCppType("torch::executor", "ScalarType")
tensorT = BaseCppType("torch::executor", "Tensor")
tensorListT = BaseCppType("torch::executor", "TensorList")
scalarT = BaseCppType("torch::executor", "Scalar")
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
optionalT = BaseCppType("torch::executor", "optional")
contextT = BaseCppType("torch::executor", "KernelRuntimeContext")
contextExpr = Expr(
expr="context",
type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
)
contextArg = Binding(
name="context",
nctype=contextExpr.type,
argument=None, # type: ignore[arg-type]
default=None,
)
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
}
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"torch::executor::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref())
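# Illustrative (a sketch):
#   OptionalCType(BaseCType(tensorT)).cpp_type()
#     -> "torch::executor::optional<torch::executor::Tensor>"
#   ArrayRefCType(BaseCType(longT)).cpp_type()
#     -> "torch::executor::ArrayRef<int64_t>"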


@@ -0,0 +1,230 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType, NamedCType
connector = "\n\t"
# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
return f.func.name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
"""
Takes a sequence of Bindings and unboxes EValues into them. Returns generated code that performs the unboxing.
A sample of the generated code:
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
void mul_out(EValue** stack) {
EValue& self = *stack[0];
EValue& other = *stack[1];
EValue& out = *stack[2];
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
torch::executor::mul_outf(self_base, other_base, out_base);
}
"""
# A callable that converts a JIT argument into its C++ type.
# Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
argument_type_gen: Callable[
...,
NamedCType,
]
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(
self, args: Sequence[Binding]
) -> tuple[list[Binding], list[str]]:
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
binding_list = []
for arg in args:
# expecting only Argument
if not isinstance(arg.argument, Argument):
raise Exception( # noqa: TRY002
f"Unexpected argument type, expecting `Argument` but got {arg}"
)
argument: Argument = arg.argument
unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
argument.type, argument.name, mutable=argument.is_write
)
code_list.extend(decl)
code_list.extend(code)
binding_list.append(arg.with_name(unboxed_name))
return binding_list, code_list
def argumenttype_evalue_convert(
self, t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
"""
Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
(1) the name of the unboxed variable
(2) its CType
(3) the C++ code necessary to unbox the argument
(4) any declarations that must be hoisted out of the unboxing scope
:param t: a `Type` of an argument
:param arg_name: argument name
:param mutable: boolean for whether this argument type is mutable
:return: unboxed result
"""
ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
if isinstance(t, BaseType):
out_name = f"{arg_name}_base"
code, decl = self._gen_code_base_type(
arg_name=arg_name, out_name=out_name, ctype=ctype
)
elif isinstance(t, OptionalType):
out_name = f"{arg_name}_opt_out"
code, decl = self._gen_code_optional_type(
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
)
elif isinstance(t, ListType):
out_name = f"{arg_name}_list_out"
code, decl = self._gen_code_list_type(
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
)
else:
raise Exception( # noqa: TRY002
f"Cannot handle type {t}. arg_name: {arg_name}"
)
return out_name, ctype, code, decl
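# Illustrative (a sketch, assuming the Executorch type translation above):
#   for `Tensor self`:  out_name = "self_base"
#     code: const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
#   for `int[] sizes`:  out_name = "sizes_list_out"
#     code: auto sizes_list_out = sizes.toIntList();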
def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
def _gen_code_optional_type(
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in"
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name
)
return (
f"""
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split(
"\n"
),
decl,
)
def _gen_code_list_type(
self, arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = []
res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
t.elem, elem_name
)
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
code.extend(
f"""
auto {out_name} = {arg_name}.toTensorList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and (
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
):
code.extend(
f"""
auto {out_name} = {arg_name}.toIntList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
code.extend(
f"""
auto {out_name} = {arg_name}.toDoubleList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
# handle list type with size, e.g., bool[4]
code.extend(
f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
{out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
""".split(
"\n"
)
)
# pytorch codegen:
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif (
isinstance(t.elem, OptionalType)
and isinstance(t.elem.elem, BaseType)
and t.elem.elem.name == BaseTy.Tensor
):
code.extend(
f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
""".split(
"\n"
)
)
else:
# use ArrayRef as default.
vec_name = arg_name + "_vec"
# need to bring vector instantiation out of scope so that ArrayRef has valid data
decl.append(
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
)
code.extend(
f"""
for (EValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
)
return code, decl


@@ -0,0 +1,220 @@
# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
from __future__ import annotations
import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
)
from torchgen.utils import assert_never
KERNEL_KEY_VERSION = 1
# TODO: duplicated subset of codegen.tool.gen_oplist; remove the declaration in codegen
class ScalarType(IntEnum):
Byte = 0
Char = 1
Short = 2
Int = 3
Long = 4
Float = 6
Double = 7
Bool = 11
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
arg_name: str
dtype: str
# The order of the dimensions if entry is a Tensor
dim_order: tuple[int, ...]
def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value
dim_str = str(self.dim_order)[1:-1].replace(" ", "")
return f"{dtype_str};{dim_str}"
@dataclass(frozen=True)
class ETKernelKey:
# Field undefined is default = True
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
# Indicator for this kernel being used as a catch-all
default: bool = False
version: int = KERNEL_KEY_VERSION
@staticmethod
def gen_from_yaml(
args: dict[str, tuple[str, str]],
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: dict[str, list[int]],
) -> list[ETKernelKey]:
"""Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned due to dtype permutations from utilizing
type_alias_map (actualizing each potential type permutation as a KernelKey)
Args:
args: Mapping from argument name to kernel specs
Kernel specs are a tuple of (dtype, dim_order).
Currently tuple entries must be aliased via the alias map arguments
type_alias_map: Mapping from type alias to potential type enums
i.e { T0 : [Double, Int] } means T0 can be either Double or Int
Used for lookup by args
dim_order_alias_map: Mapping from alias to a list of dimension orders
Used for lookup by args
"""
# Cast dim order aliases to int
dim_order_alias_map = {
k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
}
kernel_keys = []
# Get all used Dtype Alias
dtype_alias_used = set()
for type_alias, dim_order in args.values():
# Enforce usage of alias initially
# TODO: Support inlined arguments
assert type_alias in type_alias_map, "Undefined type alias: " + str(
type_alias
)
assert (
dim_order in dim_order_alias_map
), "Undefined dim_order alias: " + str(dim_order)
dtype_alias_used.add(type_alias)
# Generate all permutations of dtype alias values
alias_dtypes = [
[(alias, dtype) for dtype in type_alias_map[alias]]
for alias in dtype_alias_used
]
alias_permutations = [
dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
]
# Using each alias value permutation, generate kernel keys
op_arg_cache = {}
for permutation in alias_permutations:
arg_list = []
for arg_name, arg_spec in args.items():
dtype = permutation[arg_spec[0]]
dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
if (
cache_key := (arg_name, dtype, tuple(dim_order))
) not in op_arg_cache:
op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
arg_list.append(op_arg_cache[cache_key])
kernel_keys.append(ETKernelKey(tuple(arg_list)))
return kernel_keys
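# Illustrative (a sketch): with
#   args = {"self": ("T0", "D0")}
#   type_alias_map = {"T0": ["Double", "Int"]}
#   dim_order_alias_map = {"D0": [0, 1]}
# this returns two keys whose to_native_string() values are roughly
#   "v1/7;0,1" and "v1/3;0,1"  (Double = 7, Int = 3 in ScalarType above)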
def to_native_string(self) -> str:
if self.default:
return "default"
return (
"v"
+ str(KERNEL_KEY_VERSION)
+ "/"
+ "|".join([arg.to_native_string() for arg in self.arg_meta])
)
@dataclass(frozen=True)
class ETKernelIndex:
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
m = self.get_kernels(g)
return m is not None
def get_kernels(
self, g: NativeFunction | NativeFunctionsGroup
) -> dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
f = g.functional
else:
assert_never(g)
if f.func.name not in self.index:
return {}
return self.index[f.func.name]
@staticmethod
def grow_from_backend_indices(
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
for dk in backend_indices:
index = backend_indices[dk]
for op, backend_metadata in index.items():
if op in kernel_index:
kernel_index[op][ETKernelKey(default=True)] = backend_metadata
else:
kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
kernel_index: dict[
OperatorName, dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)
def grow(
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self
def _to_backend_index(self) -> BackendIndex:
"""
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
"""
index: dict[OperatorName, BackendMetadata] = {}
for op in self.index:
kernel_dict = self.index[op]
assert (
len(kernel_dict.values()) == 1
), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernel. Got {kernel_dict}"
index[op] = kernel_dict.get(
ETKernelKey(default=True),
BackendMetadata(kernel="", structured=False, cpp_namespace=""),
)
return BackendIndex(
dispatch_key=DispatchKey.CPU,
use_out_as_primary=False,
device_guard=False,
external=False,
index=index,
)
# Note: a duplicate ETKernelKey from index_b will clobber the metadata from index_a.
@staticmethod
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
combined = defaultdict(dict, index_a.index.copy())
for op, entry in index_b.index.items():
for key, metadata in entry.items():
combined[op][key] = metadata
return ETKernelIndex(combined)


@@ -0,0 +1,153 @@
from __future__ import annotations
from collections import defaultdict, namedtuple
from typing import Any
import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
BackendMetadata,
DispatchKey,
FunctionSchema,
NativeFunction,
OperatorName,
)
from torchgen.utils import NamespaceHelper
# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
# Fields in native_functions.yaml used to determine which kernels should be used
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
"""Given a loaded yaml representing kernel assignment information, extract the
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
Args:
ei: Dict keys {kernels, type_alias, dim_order_alias}
See ETKernelKey for description of arguments
"""
e = ei.copy()
if (kernels := e.pop("kernels", None)) is None:
return {}
type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
dim_order_alias.pop("__line__", None)
kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
for entry in kernels: # type: ignore[attr-defined]
arg_meta = entry.get("arg_meta")
if arg_meta is not None:
arg_meta.pop("__line__")
kernel_name = entry.get("kernel_name")
namespace_helper = NamespaceHelper.from_namespaced_entity(
kernel_name, max_level=3
)
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
backend_metadata = BackendMetadata(
kernel=namespace_helper.entity_name,
structured=False,
cpp_namespace=(kernel_namespace + "::native"),
)
kernel_keys = (
[ETKernelKey((), default=True)]
if arg_meta is None
else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type]
)
for kernel_key in kernel_keys:
assert kernel_key not in kernel_mapping, (
"Duplicate kernel key: " + str(kernel_key) + " " + str(e)
)
kernel_mapping[kernel_key] = backend_metadata
return kernel_mapping
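# Illustrative yaml shape (assumed, not normative) that this function parses:
#   - func: custom::add.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
#     kernels:
#       - arg_meta:
#           self: [T0, D0]
#           other: [T0, D0]
#           out: [T0, D0]
#         kernel_name: custom::add_out
#     type_alias:
#       T0: [Double]
#     dim_order_alias:
#       D0: [0, 1]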
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
"""Given a loaded yaml representing a list of operators, for each op extract the mapping
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
that should be used by the kernel key).
"""
indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
for ei in es: # type: ignore[attr-defined]
e = ei.copy()
funcs = e.pop("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
namespace_helper = NamespaceHelper.from_namespaced_entity(
namespaced_entity=funcs, max_level=1
)
opname = FunctionSchema.parse(namespace_helper.entity_name).name
assert opname not in indices, f"Duplicate func found in yaml: {opname}"
if len(index := parse_from_yaml(e)) != 0:
indices[opname] = index
return ETKernelIndex(indices)
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
"""Given a loaded yaml representing a list of operators, extract the
kernel key related fields indexed by the operator name.
"""
fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
for ei in es: # type: ignore[attr-defined]
funcs = ei.get("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
namespace_helper = NamespaceHelper.from_namespaced_entity(
namespaced_entity=funcs, max_level=1
)
opname = FunctionSchema.parse(namespace_helper.entity_name).name
for field in ET_FIELDS:
if (value := ei.get(field)) is not None:
fields[opname][field] = value
return fields
def parse_et_yaml(
path: str,
tags_yaml_path: str,
ignore_keys: set[DispatchKey] | None = None,
skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml
"""
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)
et_kernel = extract_kernel_fields(es)
# Remove ET-specific fields from entries for backward compatibility
strip_et_fields(es)
native_yaml = parse_native_yaml(
path,
tags_yaml_path,
ignore_keys,
skip_native_fns_gen=skip_native_fns_gen,
loaded_yaml=es,
)
return native_yaml.native_functions, et_kernel
def strip_et_fields(es: object) -> None:
"""Given a loaded yaml representing a list of operators,
remove ET specific fields from every entries for BC compatibility
"""
for entry in es: # type: ignore[attr-defined]
for field in ET_FIELDS:
entry.pop(field, None)