I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,870 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import cast, Sequence
from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsViewGroup,
SchemaKind,
Type,
)
from torchgen.utils import IDENT_REGEX
# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
# The NamedCType holds the updated name and cpp type of the attribute
# for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type`
nctype: NamedCType
# The expression to read the derived property at save time, e.g.:
# `other.scalar_type()`.
expr: str
# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
# The formula string (legit C++ expression).
# Note that expressions against input arguments have been replaced with the
# corresponding saved attributes.
# E.g.:
# raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
# here: `mul_tensor_backward(grad, self, other_scalar_type)`
formula: str
# The formula string before input argument replacement
original_formula: str
# Names of the arguments for which this formula calculates derivatives.
var_names: tuple[str, ...]
# Saved inputs that are referenced by the formula.
saved_inputs: tuple[SavedAttribute, ...]
# Saved outputs that are referenced by the formula.
saved_outputs: tuple[SavedAttribute, ...]
# Gradients that are referenced by name in the formula.
named_gradients: set[str]
# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
# The formula string (legit C++ expression).
# Note that special keywords such as "linear" or "element_wise" have been
# replaced by the automatically generated formula.
formula: str
# Name of the output arguments for which this formula calculates forward
# derivatives
var_names: tuple[str, ...]
# Type of the output arguments for which this formula calculates forward
# derivatives
var_types: tuple[Type, ...]
# Inputs for which the forward derivatives are required for this formula
required_inputs_fw_grad: tuple[str, ...] | None
# Inputs for which the primal is required for this formula
required_inputs_primal: tuple[str, ...] | None
# Flag to specify if this formula requires the original value of self
# This is only used by inplace operations
required_original_self_value: bool
# If this formula is specified in derivatives.yaml or if we are re-using the
# out of place formula for inplace
is_reusing_outplace_formula: bool
# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
# The base name read from derivatives.yaml.
name: str
# The matching native function.
#
# There can be multiple NativeFunction having the same base name:
# - different overloads with different types of input arguments;
# - in-place/out/functional variants of the same function;
#
# We first use the schema string (under the 'name' key) in derivatives.yaml
# to find the NativeFunction having the same schema string.
# Then we find the in-place/out/functional variants of the matching function.
# Among these variants, we choose the one having the same name as the
# derivatives.yaml entry. If there is no exact match, then we choose the
# in-place variant.
# TODO: maybe the logic to search for all variants is no longer necessary?
func: NativeFunction
# The name of the generated autograd function.
# It's set only if we will calculate a derivative, i.e.
# 'args_with_derivatives' is not empty.
op: str | None
# The derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable inputs
derivatives: Sequence[Derivative]
# The forward derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable outputs
forward_derivatives: Sequence[ForwardDerivative]
# The union of 'saved_inputs' of all 'derivatives'.
all_saved_inputs: Sequence[SavedAttribute]
# The union of 'saved_outputs' of all 'derivatives'.
all_saved_outputs: Sequence[SavedAttribute]
# All named gradients that are available for use, in the same
# order as in the grads vector.
available_named_gradients: Sequence[str]
# The named gradients that are used in any of the derivatives.
# Invariant: all(name in available_named_gradients for name in used_named_gradients)
used_named_gradients: set[str]
# The function's input arguments for which it calculates derivatives.
# It's the union of 'var_names' of all 'derivatives', sorted by the
# argument order in the function schema.
args_with_derivatives: Sequence[Binding]
# Names of arguments whose derivative formula is 'non_differentiable'.
non_differentiable_arg_names: Sequence[str]
# Raw data read from derivatives.yaml.
output_differentiability: list[bool] | None
# output_differentiability in derivatives.yaml can be a list of
# conditions that express if the output is differentiable. In this case,
# the number of conditions must match the number of outputs
# (NB: we only support one condition right now).
# output_differentiability gets populated with True for each condition,
# while output_differentiability_conditions gets populated with the conditions
output_differentiability_conditions: list[str] | None
@property
def has_derivatives(self) -> bool:
return len(self.args_with_derivatives) > 0
# Generates a new DifferentiabilityInfo using the exact same set of derivative information,
# but with a new operator name.
# This is used when generating "copy" variants of view ops,
# which are able to use the exact same derivative formula as the original view op
# See Note [Codegen'd {view}_copy Operators]
def create_view_copy_from_view_derivative(
self, g: NativeFunctionsViewGroup
) -> DifferentiabilityInfo | None:
if g.view_copy is None:
return None
f = g.view_copy
name_split_by_period = self.name.split(".", maxsplit=2)
# Append a "_copy" to the base name of the operator (but keep the overload name the same)
view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
name_split_by_period[1:]
)
view_copy_op_name = None if self.op is None else f"{self.op}_copy"
return DifferentiabilityInfo(
# Use the "_copy" version of name/func/op
name=view_copy_name,
func=f,
op=view_copy_op_name,
# But keep all derivative info the same
derivatives=self.derivatives,
forward_derivatives=self.forward_derivatives,
all_saved_inputs=self.all_saved_inputs,
all_saved_outputs=self.all_saved_outputs,
available_named_gradients=self.available_named_gradients,
used_named_gradients=self.used_named_gradients,
args_with_derivatives=self.args_with_derivatives,
non_differentiable_arg_names=self.non_differentiable_arg_names,
output_differentiability=self.output_differentiability,
output_differentiability_conditions=self.output_differentiability_conditions,
)
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
if info is None:
return False
for derivative in info.derivatives:
formula = derivative.formula
if re.search(IDENT_REGEX.format(ident), formula):
return True
return False
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "retain_variables")
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "grad")
# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
# context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
# Represents a differentiable `Return`.
# How it it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
# `cpp.return_names()` method.
# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
# `output_differentiability` field defined in derivatives.yaml (if specified),
# and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
info: dict[str, DifferentiabilityInfo] | None
fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
"""How are we going to call the underlying implementation of a
declaration? There are two strategies:
- use_derived: we want to call the implementation on CPUDoubleType
(or a similar, derived Type instance). Because these derived
instances deal in Tensors, not Variables (it's a completely different
object, so it doesn't dispatch back to VariableType), code on
this dispatch path needs to wrap/unwrap tensors. If the
derived implementation takes and returns tensors, the
implementation is usually differentiable (although we also use
the derived dispatch path for non-differentiable functions
that we still want to dispatch on the derived Type instance;
e.g., size())
- use_type: we want to call the implementation on Type, because
it is implemented concretely, and the functions it invokes will
get dispatched back to VariableType (which will ensure that they
are differentiable.)
"""
# fn is derived as long as any of its per-key differentiability infos
# has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
# and ADInplaceOrViewType. We want to generate these functions as long as a
# derivative is defined for ANY dispatch key.
if fn.func.is_abstract or (
fn.info is not None and any(info.has_derivatives for info in fn.info.values())
):
# If the function is abstract (not implemented on at::Type), we must
# call the implementation on the derived type with unpacked tensors.
# If the function has a derivative specified and is concrete, we could
# call either implementation. We prefer the calling the derived
# type's implementation with unpacked tensors because it is more
# performant in some cases: any internal calls to other ATen functions
# won't have the history tracked.
# If the function has a type dispatched argument (i.e. is a factory),
# we prefer calling the derived type's implementation both because it is
# more performant and to ensure factory functions return tensors with _version
# of 0 (probably not strictly necessary, but nice to have to keeps versions simple
# to understand.
return "use_derived"
else:
# If the function is concrete (we don't have to override it) and we
# didn't declare it in derivatives.yaml, we'll assume that it is
# actually implemented out of differentiable functions. (This
# assumption might not hold, but then you'll see gradcheck fail.)
return "use_type"
def is_foreach_func(f: NativeFunction) -> bool:
return f.func.name.name.base.startswith("_foreach_")
# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`. There however are some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
"_foreach_add.Tensor",
"_foreach_mul.Tensor",
"_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the nonforeach versions support.
_skip_argument_len_check = {
"_foreach_add.Scalar",
"_foreach_add_.Scalar",
"_foreach_add.ScalarList",
"_foreach_add_.ScalarList",
"_foreach_sub.Scalar",
"_foreach_sub_.Scalar",
"_foreach_sub.ScalarList",
"_foreach_sub_.ScalarList",
}
# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
# reference to generate derivatives.
def is_reference_for_foreach(
f: NativeFunction,
function_schema: FunctionSchema,
) -> bool:
return (
f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
and (
not function_schema.name.name.inplace
or str(f.func.name) in _foreach_with_inplace_ref
)
and (
str(f.func.name) in _skip_argument_len_check
or len(f.func.arguments.flat_non_out)
== len(function_schema.arguments.flat_non_out)
)
and all(
ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
for arg, ref_arg in zip(
f.func.arguments.flat_non_out,
function_schema.arguments.flat_non_out,
)
)
)
# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
foreach_function: NativeFunction,
functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
non_functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
The second return value indicates whether the info is generated in this function.
"""
ref_diff_info: DifferentiabilityInfo | None = None
for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
# note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
# while the info of `zero_` is in non_functional_info_by_signature
if (
ref_diff_info is None
and foreach_function.func.kind() == SchemaKind.inplace
and str(foreach_function.func.name) in _foreach_with_inplace_ref
):
for function_schema, diff_info in non_functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
if ref_diff_info is None:
return None, False
# non out-place uses the existing Derivative.
if foreach_function.func.kind() == SchemaKind.inplace:
return ref_diff_info, False
map_refarg2foreacharg, map_name2arg = {}, {}
for i, (arg, ref_arg) in enumerate(
zip(
foreach_function.func.arguments.flat_non_out,
function_schema.arguments.flat_non_out,
)
):
map_refarg2foreacharg[ref_arg.name] = arg.name
map_name2arg[arg.name] = arg
all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
modified_derivative_formulas = []
for i, derivative in enumerate(ref_diff_info.derivatives):
modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
"result", "result[i]"
)
saved_inputs, saved_outputs = [], []
# note(crcrpar): This context seems necessary to call `cpp.argument_type`
with local.parametrize(
use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
):
for ref_input in derivative.saved_inputs:
ref_input_jit_name = ref_input.expr.split(".")[0]
mapped_name = map_refarg2foreacharg[ref_input_jit_name]
if isinstance(map_name2arg[mapped_name].type, ListType):
mapped_expr = mapped_name + "[i]"
else:
mapped_expr = mapped_name
new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
modified_formula = modified_formula.replace(
cast(str, ref_input.nctype.name), new_expr
)
nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
canonical_nctype = NamedCType(
nctype.name, nctype.type.remove_const_ref()
)
saved_inputs.append(
SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
)
for ref_output in derivative.saved_outputs:
if ref_output.nctype.name == "result":
saved_outputs.append(
SavedAttribute(
nctype=NamedCType(
name="result", type=BaseCType(tensorListT)
),
expr="result",
)
)
else:
raise RuntimeError("")
var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
all_var_names.extend(var_names)
all_saved_inputs.extend(saved_inputs)
all_saved_outputs.extend(saved_outputs)
modified_derivative = Derivative(
formula=modified_formula,
original_formula=derivative.formula,
var_names=tuple(var_names),
saved_inputs=tuple(saved_inputs),
saved_outputs=tuple(saved_outputs),
named_gradients=set(),
)
modified_derivative_formulas.append(modified_derivative)
with local.parametrize(
use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
):
args_with_derivatives = [
Binding(
name=arg.name,
nctype=cpp.argument_type(arg, binds=arg.name),
argument=arg,
default=None,
)
for arg in foreach_function.func.arguments.flat_non_out
if arg.name in all_var_names
]
forward_derivatives: list[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: list[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: list[str] = []
required_inputs_primal: list[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
required_inputs_primal = list(fw_derivative.required_inputs_primal)
modified_formula = fw_derivative.formula
# Foreach's result is TensorList
if "result" in modified_formula:
modified_formula = fw_derivative.formula.replace("result", "result[i]")
for foreach_arg, ref_arg in zip(
foreach_function.func.arguments.flat_non_out,
ref_diff_info.func.func.arguments.flat_non_out,
):
# Modify reference forward formula
if (
isinstance(foreach_arg.type, ListType)
and not foreach_arg.type.is_tensor_like()
):
# Assuming ScalarList
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name + "[i]"
)
elif foreach_arg.type.is_tensor_like():
# Assuming TensorList / Tensor
# assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
assert isinstance(foreach_arg.type, ListType) or (
foreach_arg.type == BaseType(BaseTy.Tensor)
and str(foreach_function.func.name) in _foreach_with_tensor_overload
), f"{foreach_function.func.name}, {foreach_arg.type}"
for suffix in ("_p", "_t"):
curr_expr = ref_arg.name + suffix
if curr_expr in modified_formula:
new_expr = foreach_arg.name + suffix
modified_formula = modified_formula.replace(curr_expr, new_expr)
else:
# Assuming Scalar
if foreach_arg.name != ref_arg.name:
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name
)
# note(crcrpar): there should exist a cooler way...
for i, name in enumerate(var_names):
if name == ref_arg.name:
var_names[i] = foreach_arg.name
var_types[i] = foreach_arg.type
for i, name in enumerate(required_inputs_fw_grad):
if name == ref_arg.name:
required_inputs_fw_grad[i] = foreach_arg.name
for i, name in enumerate(required_inputs_primal):
if name == ref_arg.name:
required_inputs_primal[i] = foreach_arg.name
forward_derivatives.append(
ForwardDerivative(
formula=modified_formula,
var_names=tuple(var_names),
var_types=tuple(var_types),
required_inputs_fw_grad=tuple(required_inputs_fw_grad),
required_inputs_primal=tuple(required_inputs_primal),
required_original_self_value=fw_derivative.required_original_self_value,
is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
)
)
return (
DifferentiabilityInfo(
name=foreach_function.func.name.name.base,
func=foreach_function,
op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
derivatives=modified_derivative_formulas,
forward_derivatives=forward_derivatives,
all_saved_inputs=tuple(set(all_saved_inputs)),
all_saved_outputs=tuple(set(all_saved_outputs)),
available_named_gradients=(),
used_named_gradients=set(),
args_with_derivatives=args_with_derivatives,
non_differentiable_arg_names=[],
output_differentiability=None,
output_differentiability_conditions=None,
),
True,
)
def match_differentiability_info(
native_functions: list[NativeFunction],
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
"""
functional_info_by_signature = {
schema.signature(strip_default=True): info_dict
for schema, info_dict in differentiability_infos.items()
if schema.kind() == SchemaKind.functional
}
non_functional_info_by_signature = {
schema.signature(strip_default=True): info_dict
for schema, info_dict in differentiability_infos.items()
if schema.kind() != SchemaKind.functional
}
def find_info(
f: NativeFunction,
) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
# Don't bother matching info to generated out= variants
if "generated" in f.tags and f.func.kind() == SchemaKind.out:
return None, False
# (1) Check for an exact match
if f.func in differentiability_infos:
return differentiability_infos[f.func], True
# (2) If no exact match, check if the out-of-place variant
# of this operator has a match.
# i.e mul() for mul_() or mul_out()
# note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
# native functions instead of the out-place counterparts.
f_sig = f.func.signature(strip_default=True)
if f_sig in functional_info_by_signature and not is_foreach_func(f):
return functional_info_by_signature[f_sig], False
# (3) Some operators have a derivative explicitly defined for the mutable
# variant, but get a code-generated out-of-place variant which does *not*
# come with a derivative formula.
# For the generated out-of-place variant, use the mutable variant's formula
# if it exists.
if "generated" in f.tags and f_sig in non_functional_info_by_signature:
info_dict = non_functional_info_by_signature[f_sig]
# See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
assert not any(
any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
for info in info_dict.values()
), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used by automatically by its functional variant ("{str(f.func)}").
this is not currently supported (we'd need to fix up the formula in the codegen)."""
return info_dict, False
# (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
if is_foreach_func(f):
assert f.func not in differentiability_infos
diff_info, is_generated = gen_foreach_derivativeinfo(
f,
functional_info_by_signature,
non_functional_info_by_signature,
)
if diff_info is None:
return None, False
# TODO(crcrpar): Avoid hard coding "Default" ideally.
diff_info_dict = {"Default": diff_info}
if is_generated:
differentiability_infos[f.func] = diff_info_dict
functional_info_by_signature[f.func] = diff_info_dict
return diff_info_dict, is_generated
return None, False
result: list[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info_dict, is_exact_match = find_info(f)
# Currently, the '.strides()' to 'strides_or_error' replacement does not support
# 'self' derivatives of an inplace function, so we must check for this case.
if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
for info in info_dict.values():
for derivative in info.derivatives:
if "self" in derivative.var_names:
for saved_input in derivative.saved_inputs:
assert "strides_or_error" not in saved_input.expr, (
"Calling '.strides()' in the 'self' derivative formula of an "
f"in-place function is not supported: {f.func}"
)
if not info_dict:
result.append(
NativeFunctionWithDifferentiabilityInfo(
func=f, info=None, fw_derivatives=None
)
)
continue
fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
for key, info in info_dict.items():
if not info.forward_derivatives:
fw_derivative_dict[key] = []
continue
forward_derivatives = info.forward_derivatives
# For functions that have a single def for out-of-place and inplace (like abs())
if f.func.kind() == SchemaKind.inplace:
# For inplace functions there is a little bit of work to do:
# 1) Validate the formula and make sure the input that is modified in not used:
# - If there is a formula for the inplace variant of the function (is_exact_match == True) then
# we make sure that the original value of the input that is being modified inplace (self_p) is
# not used in the formula. Note that the formula can use "original_self_p" here and that would
# trigger a clone of the original input.
# - If we are re-using the out of place formula (is_exact_match == False) then we replace every
# occurrence of self_p and self_t by original_self_p and original_self_t. These will be
# populated by cloned version of the original input (either the clone done by the backward AD
# logic if self is also used in a backward formula or a special clone that we add).
# 2) At this point, there cannot be a self_p in the formula.
# 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
# simply called self (as it is modified inplace).
# 4) Update the required primals data in case it used to contain "result" but should now contain
# "self"
# 5) If it is not an exact match, the user formula is not modifying the existing forward grad
# inplace as it should. So add some code that makes sure that we do so if the forward grad
# already exists.
assert (
len(info.forward_derivatives) == 1
) # Only single output inplace should exist
fw_info = info.forward_derivatives[0]
formula = fw_info.formula
def replace_self_with_original_self(formula: str, postfix: str) -> str:
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}original_self{postfix}{m.group(2)}"
return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
if re.search(IDENT_REGEX.format("self_p"), formula):
if is_exact_match:
# For manually defined formulas, don't allow the original value to be used
raise RuntimeError(
f'The formula for "{f.func.name}" is using the original value of self '
"that is being modified inplace. This would lead to wrong forward gradients. "
'Please use "result" in the formula only.'
)
else:
# When the original formula is out of place, we save a clone of the primal
# value to be able to access this value if needed
# replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
formula = replace_self_with_original_self(formula, "_p")
formula = replace_self_with_original_self(formula, "_t")
# replace "result" from the formula by "self_p"
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}self_p{m.group(2)}"
formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
required_primals = fw_info.required_inputs_primal
if re.search(IDENT_REGEX.format("self_p"), formula):
required_primals = (
required_primals + ("self",) if required_primals else ("self",)
)
if not is_exact_match:
# NOTE [In-place forward AD formula Optimization]
#
# This optimization transforms the formula to directly do inplace, i.e.
# instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
#
# 1) the formula satisfies the pattern: "self_t.op(*args)"
# 2) "op" in (1) needs to be the same as the op the derivative is for
#
# (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
# If there is a need, we can relax (2) to allow any op that has an in-place variant
is_single_method_on_self_t = False
directly_do_inplace = False
op_name: str | None = None
between_parens: str | None = None
match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
if match:
op_name, between_parens = match.group(1), match.group(2)
# We want to...
# Match: self_t.op1(other_p.op2(arg))
# Avoid: self_t.op1(args) + self_t.op2(args)
# Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
def check_parens_nest_level_gt_zero(s: str) -> bool:
level = 1
for ch in s:
if ch == ")":
level -= 1
if level == 0:
return False
if ch == "(":
level += 1
return True
is_single_method_on_self_t = check_parens_nest_level_gt_zero(
between_parens
)
directly_do_inplace = (
is_single_method_on_self_t and op_name == info.name
)
if directly_do_inplace:
assert op_name is not None
assert between_parens is not None
formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
else:
# Make sure that the forward grad is modified inplace when the original formula
# is out of place
formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
required_original_self_value = bool(
re.search(IDENT_REGEX.format("original_self_p"), formula)
) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
forward_derivatives = [
ForwardDerivative(
formula=formula,
var_names=("self",),
var_types=fw_info.var_types,
required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
required_inputs_primal=required_primals,
required_original_self_value=required_original_self_value,
is_reusing_outplace_formula=not is_exact_match,
),
]
fw_derivative_dict[key] = forward_derivatives
result.append(
NativeFunctionWithDifferentiabilityInfo(
func=f, info=info_dict, fw_derivatives=fw_derivative_dict
)
)
return result
def is_differentiable(
name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
return type.is_tensor_like() and (
info is None or name not in info.non_differentiable_arg_names
)
def gen_differentiable_outputs(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
f = fn.func
info = fn.info[key] if fn.info else None
outputs: list[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
)
for name, ret in zip(cpp.return_names(f), f.func.returns)
]
output_differentiability = info.output_differentiability if info else None
if output_differentiability is not None:
if len(output_differentiability) != len(outputs):
raise RuntimeError(
f"The length of output_differentiability ({len(output_differentiability)}), "
f"does not match the number of outputs ({len(outputs)})."
)
differentiable_outputs: list[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError(
"output_differentiability=False for inplace operation (version_counter won't get updated)"
)
for differentiable, output in zip(output_differentiability, outputs):
if differentiable:
differentiable_outputs.append(output)
return differentiable_outputs
candidate_differentiable_outputs = list(
filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
)
if uses_single_grad(info):
return candidate_differentiable_outputs[:1]
else:
return candidate_differentiable_outputs

View File

@ -0,0 +1,472 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api.types import (
ArgName,
ArrayCType,
ArrayRefCType,
BaseCType,
BaseTypeToCppMapping,
Binding,
boolT,
ConstRefCType,
CType,
dimnameListT,
intArrayRefT,
iTensorListRefT,
ListCType,
longT,
MutRefCType,
NamedCType,
OptionalCType,
optionalIntArrayRefT,
optionalSymIntArrayRefT,
scalarT,
SpecialArgName,
symIntArrayRefT,
SymIntT,
tensorListT,
tensorOptionsT,
tensorT,
TupleCType,
VectorCType,
voidT,
)
from torchgen.model import (
Argument,
Arguments,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
OptionalType,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
# a single C++ type TensorOptions (the native functions API
# also has this, but tensor options is really most relevant
# for the C++ API; it makes calling kwarg factory functions
# pleasant)
#
# - defaulting lives here (in fact, the dispatcher is completely
# oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide
def name(
func: FunctionSchema,
*,
faithful_name_for_out_overloads: bool = False,
symint_overload: bool = False,
) -> str:
name = str(func.name.name)
if symint_overload:
name += "_symint"
if func.is_out_fn():
if faithful_name_for_out_overloads:
name += "_outf"
else:
name += "_out"
return name
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
t: Type,
*,
binds: ArgName,
mutable: bool = True,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
elif str(t) == "SymInt":
if symint:
return NamedCType(binds, BaseCType(SymIntT))
else:
return NamedCType(binds, BaseCType(longT))
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
"string ref->value conversion: not implemented yet"
)
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if str(t.elem) == "bool":
assert t.size is not None
return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t,
binds=binds,
mutable=mutable,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == "Tensor":
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(
binds, MutRefCType(BaseCType(tensorT))
) # TODO: fix this discrepancy
else:
return NamedCType(
binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
)
elif str(t.elem) == "Scalar":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
if symint:
return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
else:
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: remove these special cases, ArrayRef fallthrough works fine
if str(t.elem) == "int":
if remove_non_owning_ref_types:
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "SymInt":
if remove_non_owning_ref_types:
if symint:
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
else:
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
if symint:
return NamedCType(binds, BaseCType(symIntArrayRefT))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "Tensor":
if local.use_ilistref_for_tensor_lists():
return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
else:
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Scalar":
return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
elif str(t.elem) == "Dimname":
return NamedCType(binds, BaseCType(dimnameListT))
elif str(t.elem) == "Tensor?":
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
# placeholder is ignored
# NB: symint is ALWAYS respected for return types. So symint argument
# here is IGNORED
r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
if r is not None:
return r.type
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
# Note [Tensor Copy Returns]
# Currently, we use "Argument.is_write" to determine
# whether or not Tensor return types should be copies or references.
# If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
elif isinstance(t, OptionalType):
elem = returntype_type(t.elem, mutable=mutable)
if str(t.elem) == "Tensor":
return OptionalCType(elem)
raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
return returntype_type(r.type, mutable=r.is_write, symint=symint)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0], symint=symint)
else:
return TupleCType([return_type(r, symint=symint) for r in rs])
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
# TODO: Consider incorporating this into the data model
if f.func.name.name.inplace:
assert i == 0, "illegal inplace function with multiple returns"
name = "self"
# If we are out function, the name is the name of the
# corresponding output function (r.name will get recorded
# in field_name later.)
elif f.func.is_out_fn():
name = f.func.arguments.out[i].name
# If the return argument is explicitly named...
elif r.name:
name_conflict = any(
r.name == a.name for a in f.func.schema_order_arguments()
)
if name_conflict and not f.func.is_out_fn():
name = f"{r.name}_return"
else:
name = r.name
# If there is no explicit name and no fallback name was passed in, we just name the output result,
# unless it's a multi-return, in which case it's result0,
# result1, etc (zero-indexed)
else:
name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
returns.append(name)
return returns
JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "::std::nullopt", # UGH this one is type directed
"Mean": "at::Reduction::Mean",
"[]": "{}",
"contiguous_format": "c10::MemoryFormat::Contiguous",
"long": "at::kLong",
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
if d == "None" and str(t) == "Tensor?":
return "{}"
if isinstance(t, BaseType) and t.name is BaseTy.str:
# Schema allows single quotes but C++ needs double
if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
s = ""
i = 1
while i + 1 < len(d):
if d[i] != "\\":
if d[i] == '"':
s += '\\"'
else:
s += d[i]
i += 1
else:
if d[i + 1] == "'":
s += "'"
else:
s += d[i : i + 2]
i += 2
return f'"{s}"'
if isinstance(t, OptionalType):
if d == "None":
return "::std::nullopt"
return default_expr(d, t.elem, symint=symint)
if isinstance(t, ListType):
if d.startswith("[") and d.endswith("]"):
return "{" + d[1:-1] + "}"
elif symint and d.isdigit() and str(t.elem) == "SymInt":
return f"c10::SymInt({d})"
elif t.size is None:
# NOTE: Sized lists can have scalar defaults
raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
return JIT_TO_CPP_DEFAULT.get(d, d)
# Convert an argument into its C++ API form
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
*,
cpp_no_default_args: set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
) -> list[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
symint=symint,
has_tensor_options=has_tensor_options,
)
if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type, symint=symint)
return [
Binding(
nctype=argument_type(a, binds=binds, symint=symint),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, TensorOptionsArguments):
if faithful:
return (
sub_argument(a.dtype)
+ sub_argument(a.layout)
+ sub_argument(a.device)
+ sub_argument(a.pin_memory)
)
else:
default = None
# Enforced by NativeFunction.__post_init__
assert "options" not in cpp_no_default_args
if all(x.default == "None" for x in a.all()):
default = "{}"
elif a.dtype.default == "long":
default = "at::kLong" # TODO: this is wrong
return [
Binding(
nctype=NamedCType("options", BaseCType(tensorOptionsT)),
name="options",
default=default,
argument=a,
)
]
elif isinstance(a, SelfArgument):
if method:
# Caller is responsible for installing implicit this in context!
return []
else:
return sub_argument(a.argument)
else:
assert_never(a)
def arguments(
arguments: Arguments,
*,
faithful: bool,
symint: bool = False,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
else:
args.extend(arguments.out)
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r
for a in args
for r in argument(
a,
faithful=faithful,
symint=symint,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,
)
]

View File

@ -0,0 +1,120 @@
from __future__ import annotations
import itertools
from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
from torchgen.model import (
Argument,
FunctionSchema,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never, concatMap
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API to so that the unboxed API
# better aligns with the boxed API. The dispatcher API hooks heavily
# into our template based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
#
# Prominent characteristics of the dispatcher API:
#
# - dtype, layout, device and pin_memory are represented as separate
# arguments.
#
def name(func: FunctionSchema) -> str:
return cpp.name(func)
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType:
# This is a faux amis. If it makes sense in the future to add
# more special cases here, or invert things so cpp.argument_type
# calls this, or just completely inline the function, please do
# it.
return cpp.argumenttype_type(
t,
mutable=mutable,
binds=binds,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
def argument_type(
a: Argument,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType:
return argumenttype_type(
a.type,
mutable=a.is_write,
binds=binds,
remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
)
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
# At present, there is no difference. But there could be!
return cpp.returns_type(rs, symint=symint)
def jit_arguments(func: FunctionSchema) -> list[Argument]:
def to_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Argument]:
if isinstance(a, Argument):
return [a]
elif isinstance(a, SelfArgument):
return [a.argument]
elif isinstance(a, TensorOptionsArguments):
return [a.dtype, a.layout, a.device, a.pin_memory]
else:
assert_never(a)
return list(
concatMap(
to_argument,
itertools.chain(
func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
),
)
)
def argument(
a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
return Binding(
nctype=argument_type(
a,
binds=a.name,
remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
),
name=a.name,
argument=a,
)
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
return [argument(a, symint=symint) for a in jit_arguments(func)]

View File

@ -0,0 +1,199 @@
from __future__ import annotations
from torchgen.api import dispatcher
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
ConstRefCType,
CType,
longT,
NamedCType,
tensorT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
NativeFunction,
NativeFunctionsViewGroup,
)
# This file describes the translation of JIT schema to API's used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These API's mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
# Define some specific lambda input arguments.
base_binding = Binding(
name="base",
nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
mutated_view_binding = Binding(
name="mutated_view",
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
mutated_view_idx_binding = Binding(
name="mutated_view_idx",
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
reapply_views_binding = Binding(
name="reapply_views",
nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
argument=Argument(
name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
),
default=None,
)
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
inverse_return_mode_binding = Binding(
name="inverse_return_mode",
nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
argument=Argument(
name="inverse_return_mode",
# NB: not actually a bool but it doesn't matter because this isn't used
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
g: NativeFunctionsViewGroup,
*,
is_reverse: bool,
include_namespace: bool,
reapply_views: bool | None = None,
) -> str:
if reapply_views is None:
# reapply_views is only important for the fwd lambda,
# since we always plumb the runtime "reapply_views" argument into the reverse function.
assert is_reverse
if is_reverse:
return reverse_name(g.view, include_namespace)
# in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
assert include_namespace
assert g.view_copy is not None
api_name = (
g.view.func.name.unambiguous_name()
if reapply_views
else g.view_copy.func.name.unambiguous_name()
)
return f"at::_ops::{api_name}::call"
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
# for the reverse: we plumb the "reapply_views" flag into that function and support
# both copy and non-copy variants. (We could avoid doing that, but that would require
# writing out twice as many view inverse functions).
api_name = f.func.name.unambiguous_name()
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
if include_namespace:
return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
else:
return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings
def returns_type(func: FunctionSchema) -> CType:
# Assertion: all view ops return tensor-like outputs
assert len(func.returns) >= 1
for ret in func.returns:
assert ret.type.is_tensor_like()
# However, the return type of the lambda is always an individual tensor.
# For multi-tensor outputs, each tensor needs to be tracked individually.
return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
def inner_call_index(func: FunctionSchema) -> Binding | None:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like()
):
return mutated_view_idx_binding
return None
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
# The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
# Both of these follow the dispatcher API.
non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
if not is_reverse:
# the forward lambda swaps out the original tensor argument with the lambd arg "base"
return [base_binding] + non_self_bindings
else:
# the reverse lambda does the same, but with an additional "mutated_view" arg
# additionally, we have a calling convention: for view ops that return multiple tensor outputs
# their corresponding view_inverse function takes in an additional index argument.
index_binding = inner_call_index(func)
if index_binding is not None:
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
index_binding,
] + non_self_bindings
else:
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
] + non_self_bindings

View File

@ -0,0 +1,467 @@
from __future__ import annotations
from typing import Any
from torchgen.api.types import (
BaseCppType,
BaseCType,
boolT,
CType,
deviceT,
doubleT,
generatorT,
layoutT,
ListCType,
longT,
memoryFormatT,
NamedCType,
OptionalCType,
scalarT,
scalarTypeT,
stringT,
SymIntT,
VectorCType,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
ListType,
OperatorName,
OptionalType,
Return,
TensorOptionsArguments,
Type,
)
_valueT: BaseCppType | None = None
# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
global _valueT
if not _valueT:
raise NotImplementedError(
"The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
)
return _valueT
def setValueT(val: BaseCppType) -> None:
global _valueT
_valueT = val
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
"""
This function takes a type from NativeFunctions and converts it for use with
lazy tensor codegen.
Type conversion for lazy currently consists of
(1) changing at::Tensors into lazy::Values
(2) wrapping everything in a BaseCType
(3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
(1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'
This is incomplete- there are assertions in places that it's expected to need to add
more types as the codegen is used with more operators.
"""
if isinstance(typ, BaseType):
if typ.name == BaseTy.Tensor:
return BaseCType(getValueT())
elif typ.name == BaseTy.Scalar:
if properties.TreatScalarsAsConstants:
return BaseCType(scalarT)
# at::scalar has special handling,
# and is wrapped in an lazy::Value just like at::tensor
return BaseCType(getValueT())
elif typ.name == BaseTy.ScalarType:
return BaseCType(scalarTypeT)
elif typ.name == BaseTy.int:
return BaseCType(longT)
elif typ.name == BaseTy.SymInt:
if symint:
return BaseCType(getValueT())
else:
return BaseCType(longT)
elif typ.name == BaseTy.bool:
return BaseCType(boolT)
elif typ.name == BaseTy.float:
return BaseCType(doubleT)
elif typ.name == BaseTy.str:
return BaseCType(stringT)
elif typ.name == BaseTy.Device:
return BaseCType(deviceT)
elif typ.name == BaseTy.Generator:
return BaseCType(generatorT)
elif typ.name == BaseTy.Layout:
return BaseCType(layoutT)
elif typ.name == BaseTy.MemoryFormat:
return BaseCType(memoryFormatT)
else:
raise AssertionError(f"TODO add support for type {repr(typ)}")
elif isinstance(typ, OptionalType):
return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
elif isinstance(typ, ListType):
if str(typ.elem) == "Tensor?":
# TODO(whc) is this actually correct? or should it use a Vector like above
return ListCType(OptionalCType(BaseCType(getValueT())))
elif str(typ.elem) == "Tensor":
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
elif typ.elem == BaseType(BaseTy.SymInt):
# TODO: return a value type. The problem here is analogous to
# the problem with tensorListValueT: if you have SymInt[] you
# cannot conveniently save the list of Value directly, as nodes
# expect to save values as a vector for ALL arguments. So you
# need a separate IR node that represents all of the size nodes
# assembled into a list. I'm not an LTC dev so I don't want to
# figure it out right now. Y'all figure it out...
return VectorCType(BaseCType(longT))
else:
return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
else:
raise AssertionError(f"unrecognized type {repr(typ)}")
# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
being Tensor-like, but assumes the type has already been transformed.
"""
if isinstance(typ, BaseCType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
return (
typ.type == getValueT()
or (typ.type == scalarT and not treat_scalars_as_constants)
or typ.type == SymIntT
)
elif typ == VectorCType(BaseCType(SymIntT)):
# TODO: report True for this
return False
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem, properties)
return False
def isSymIntType(typ: Type) -> bool:
return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
def isWrappedScalarType(typ: Type) -> bool:
"""
Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
Since we literally change the type from scalarT to valueT, information is lost.
This function helps build a list of wrapped scalars to save that information
"""
if isinstance(typ, BaseType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
return typ.name == BaseTy.Scalar
elif isinstance(typ, (OptionalType, ListType)):
return isWrappedScalarType(typ.elem)
return False
# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
if isinstance(typ, BaseType):
return typ.name == BaseTy.Generator
elif isinstance(typ, (OptionalType)):
return isGeneratorType(typ.elem)
return False
# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
name: str
orig_type: Type
lazy_type_: CType | None
is_wrapped_scalar: bool
is_generator: bool
# TODO: this is lies, it is false for symint list
is_symint_or_list: bool
# Whether or not we are treating this as symint or not
symint: bool
# true if this argument is or contains a lazy IR value
is_lazy_value: bool
def __init__(
self, arg: Argument, properties: LazyIrProperties, *, symint: bool
) -> None:
self.name = arg.name
self.orig_type = arg.type
self.symint = symint
self.is_optional = isinstance(arg.type, OptionalType)
self.is_generator = isGeneratorType(arg.type)
self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
self.is_symint_or_list = symint and (
isSymIntType(arg.type)
or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
# TODO: lists of symints are not currently treated as value types
# or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
)
self.is_lazy_value = isValueType(self.lazy_type, properties)
@property
def lazy_type(self) -> CType:
assert (
self.lazy_type_ is not None
), f"Attempted to access lazy_type for invalid argument {self.name}"
return self.lazy_type_
class LazyIrProperties:
"""Collection of properties for an IR node
The property groups are listed below. Each group is mutually
exclusive, meaning that only one property from each group can be True
at any one time. The properties can be accessed as if they were normal
attributes. The mutual exclusivity is automatically handled.
"""
Properties: tuple[tuple[str, ...], ...] = (
(
"ShapePrecompute", # Assume shape has been precomputed
"ShapeCompute", # Need to compute the shape on construction
"ShapeCache", # Utilize the shape cache to defer computation
),
(
"Lower", # Codegen full lower function
"LowerDeclOnly", # Codegen only lower function declaration
),
(
"CanBeReused", # Codegen full reuse function
"CanBeReusedDeclOnly", # Codegen only reuse function declaration
),
(
"CreateFn", # Codegen full create function
"CreateFnDeclOnly", # Codegen only create function declaration
),
(
"TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values
),
)
def __init__(self, *default_properties: str) -> None:
properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
LazyIrProperties.Properties
)
self.__dict__["properties"] = properties
for p in default_properties:
setattr(self, p, True)
def __getattr__(self, key: str) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
return properties[values] == key
return self.__getattribute__(key)
def __setattr__(self, key: str, value: Any) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
properties[values] = key if value else None
return value
raise KeyError(f"Invalid property: {key}")
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
# The name of the operator this function schema describes.
name: OperatorName
positional_args: tuple[LazyArgument, ...]
keyword_args: tuple[LazyArgument, ...]
# TODO: Need to handle collisions with argument names at some point
returns: tuple[Return, ...]
# if this schema has a Generator arg, list its orig ctype/name but don't
# build a LazyArgument since lazy IR doesn't support it
generator_arg: NamedCType | None = None
# original function schema
func: FunctionSchema
# Whether or not we are code-genning for SymInt or not
symint: bool
properties: LazyIrProperties = LazyIrProperties(
# default properties
"ShapePrecompute",
"Lower",
"CanBeReused",
)
opkind: str | None = None
def __init__(
self,
func: FunctionSchema,
properties: LazyIrProperties | None = None,
*,
symint: bool,
) -> None:
if properties:
self.properties = properties
self.func = func
self.symint = symint
positional_args: list[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = func.arguments.self_arg.argument
positional_args.append(
LazyArgument(arg, self.properties, symint=symint)
)
elif getattr(func.arguments, arg_field) is not None:
positional_args.extend(
LazyArgument(arg, self.properties, symint=symint)
for arg in getattr(func.arguments, arg_field)
)
self.positional_args = tuple(positional_args)
keyword_args: list[LazyArgument] = []
for arg_field in [
"pre_tensor_options_kwarg_only",
"tensor_options",
"post_tensor_options_kwarg_only",
"out",
]:
curr_args = getattr(func.arguments, arg_field)
if curr_args is not None:
if isinstance(curr_args, TensorOptionsArguments):
curr_args = curr_args.all()
for arg in curr_args:
if isGeneratorType(arg.type):
assert (
self.generator_arg is None
), "We expect there is only one generator arg"
self.generator_arg = NamedCType(
arg.name, arg.type # type:ignore[arg-type]
)
keyword_args.extend(
LazyArgument(arg, self.properties, symint=symint)
for arg in curr_args
)
self.keyword_args = tuple(keyword_args)
self.name = func.name
self.returns = func.returns
@property
def node_name(self) -> str:
"""
Return camel-case version of op in node.
Note: This function also appends any `overload_name` in the operation.
For example, if the op is `bitwise_and.Tensor`, the returned name
will be `BitwiseAndTensor`.
"""
op_name = f"{self.name.name}_{self.name.overload_name}".lower()
return "".join(word.capitalize() or "" for word in op_name.split("_"))
@property
def aten_name(self) -> str:
return str(self.name.name)
@property
def base_name(self) -> str:
return f"{self.name.name.base}"
def filtered_args(
self,
positional: bool = True,
keyword: bool = True,
values: bool = True,
scalars: bool = True,
generator: bool = True,
) -> list[LazyArgument]:
# This function maintains the sorted order of arguments but provides different filtered views.
# Some parts of the code care about kwargs vs args (TS lowerings),
# other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
# Generators are special cased, as they are needed for fallback/shape-inference but not supported
# in TS lowerings and therefore also omitted from lazy IR.
args: list[LazyArgument] = []
if positional:
args.extend(self.positional_args)
if keyword:
args.extend(self.keyword_args)
if values and scalars and generator:
return args
elif values and scalars:
return [a for a in args if not a.is_generator]
elif values:
return [a for a in args if a.is_lazy_value]
elif scalars:
return [
a
for a in args
if not a.is_lazy_value and (generator or not a.is_generator)
]
return []
@property
def positional_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=True, scalars=False
)
@property
def positional_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=False, scalars=True
)
@property
def keyword_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=True, scalars=False
)
@property
def keyword_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=False, scalars=True
)

View File

@ -0,0 +1,13 @@
from torchgen.model import NativeFunctionsGroup
# Follows dispatcher calling convention, but:
# - Mutable arguments not allowed. Meta functions are always
# written in functional form. Look at FunctionSchema.signature()
# - No tensor returns; instead we return a TensorMeta describing
# the tensor in question
def name(g: NativeFunctionsGroup) -> str:
# use the overload name from the functional version
return str(g.functional.func.name).replace(".", "_")

View File

@ -0,0 +1,155 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import (
ArgName,
BaseCType,
Binding,
boolT,
ConstRefCType,
CType,
deviceT,
layoutT,
ListCType,
MutRefCType,
NamedCType,
OptionalCType,
scalarT,
scalarTypeT,
tensorT,
)
from torchgen.model import (
Argument,
FunctionSchema,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is symint aware, you will get the non-SymInt variant for some
# dispatch entries and SymInt for others.
def name(func: FunctionSchema) -> str:
name = str(func.name.name)
# TODO: delete this!
if func.is_out_fn():
name += "_out"
if func.name.overload_name:
name += f"_{func.name.overload_name}"
return name
def argumenttype_type(
t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
if str(t) == "Tensor?":
tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(tensor_type))
else:
return NamedCType(binds, ConstRefCType(tensor_type))
elif str(t) == "Tensor?[]":
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elif str(t) == "Scalar":
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
elif str(t) == "Scalar?":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
return cpp.returns_type(rs, symint=symint)
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
def argument(
a: Argument | SelfArgument | TensorOptionsArguments,
*,
is_out: bool,
symint: bool,
) -> list[Binding]:
# Ideally, we NEVER default native functions. However, there are a number
# of functions that call native:: directly and rely on the defaulting
# existing. So for BC, we generate defaults for non-out variants (but not
# for out variants, where it is impossible to generate an appropriate
# default)
should_default = not is_out
if isinstance(a, Argument):
default: str | None = None
if should_default and a.default is not None:
default = cpp.default_expr(a.default, a.type, symint=symint)
return [
Binding(
nctype=argument_type(a, binds=a.name, symint=symint),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, SelfArgument):
# Erase SelfArgument from the distinction
return argument(a.argument, is_out=is_out, symint=symint)
elif isinstance(a, TensorOptionsArguments):
default = None
if should_default:
default = "{}"
# TODO: Not sure why the arguments assigned here are for
# TensorOptionsArguments and not the constituent pieces. It seems
# to matter
return [
Binding(
nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
name="dtype",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
name="layout",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
name="device",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
name="pin_memory",
default=default,
argument=a,
),
]
else:
assert_never(a)
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(func.arguments.non_out)
args.extend(func.arguments.out)
return [
r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,157 @@
from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.types import (
ArgName,
ArrayRefCType,
BaseCType,
Binding,
ConstRefCType,
dimnameListT,
intArrayRefT,
iOptTensorListRefT,
iTensorListRefT,
NamedCType,
OptionalCType,
optionalIntArrayRefT,
optionalScalarRefT,
optionalTensorRefT,
scalarT,
tensorT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunctionsGroup,
OptionalType,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the structured functions API.
# This is similar to native API, but a number of historical problems with native
# API have been fixed.
# Translation of types occurring in JIT arguments to a C++ argument type.
# NB: For now, mutable doesn't do anything; but it could if we make
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# If it's a value type, do the value type translation
# NB: structured kernels ALWAYS have symint off, since they involve actual
# kernels that require real ints. The one exception is the
# CompositeExplicitAutograd and the meta function (which could
# hypothetically be SymInt), but for simplicity we plan for these to just
# be handled in Python
r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if t.elem == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(optionalTensorRefT))
elif t.elem == BaseType(BaseTy.Scalar):
return NamedCType(binds, BaseCType(optionalScalarRefT))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if t.elem == BaseType(BaseTy.Tensor):
return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
return NamedCType(binds, BaseCType(iOptTensorListRefT))
# TODO: delete these special cases; see torchgen.api.cpp--these
# must be changed in tandem, but there are problems; see
# https://github.com/pytorch/pytorch/pull/51485
elif str(t.elem) == "int":
return NamedCType(binds, BaseCType(intArrayRefT))
elif str(t.elem) == "Dimname":
return NamedCType(binds, BaseCType(dimnameListT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
# returns_type intentionally omitted, because structured kernels never "return";
# instead, they always indirectly report their outputs (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).
# Structured kernels are never defaulted
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
if isinstance(a, Argument):
return [
Binding(
nctype=argument_type(a, binds=a.name),
name=a.name,
default=None,
argument=a,
)
]
elif isinstance(a, SelfArgument):
return argument(a.argument)
elif isinstance(a, TensorOptionsArguments):
raise AssertionError("structured kernels don't support TensorOptions yet")
else:
assert_never(a)
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if g.out.precomputed:
# A list of parameters for the impl function with
# certain parameters replaced with precomputed counterparts
# as specified in native_functions.yaml.
non_out_args_replaced: list[
Argument | TensorOptionsArguments | SelfArgument
] = []
for a in g.out.func.arguments.non_out:
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
# If a is in precompute.replace, append the parameters
# that should replace it onto non_out_args_replaced.
non_out_args_replaced.extend(g.out.precomputed.replace[a.name])
else:
# If not, push a as it is.
non_out_args_replaced.append(a)
args.extend(non_out_args_replaced)
# g.out.precomputed.add is the list of parameters that are added
# without replacement after the non out args and just before the out args
args.extend(g.out.precomputed.add)
else:
args.extend(g.out.func.arguments.non_out)
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.functional.func.arguments.non_out)
return [r for arg in args for r in argument(arg)]
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]

View File

@ -0,0 +1,433 @@
from __future__ import annotations
from typing import NoReturn, Sequence
from torchgen.api.types import (
ArrayRefCType,
BaseCType,
Binding,
boolT,
ConstRefCType,
deviceT,
Expr,
intArrayRefT,
iOptTensorListRefT,
layoutT,
ListCType,
longT,
memoryFormatT,
MutRefCType,
NamedCType,
opmath_t,
OptionalCType,
optionalIntArrayRefT,
optionalScalarRefT,
optionalSymIntArrayRefT,
optionalTensorRefT,
scalar_t,
scalarT,
scalarTypeT,
SpecialArgName,
symIntArrayRefT,
SymIntT,
tensorOptionsT,
tensorT,
VectorCType,
)
# This file implements a small program synthesis engine that implements
# conversions between one API to another.
#
# The key data type in this file in NamedCType, short for Named C++ semantic type. A NamedCType
# represents a C++ type, plus semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is
# "bool", but its C++ semantic type also keeps track that this represents a
# "pin_memory"; you can't just use a random other boolean in a context where you
# need a "pin_memory"!
#
# The translator takes a list of needed NamedCTypes, and then figures out how
# to construct expressions with these NamedCTypes from the given bindings. Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
# in the context, instead, "options" is, and you need to extract
# it from there. (Gather)
#
# - Need the "context" binding? Well, maybe "context" isn't available
# in the context, and you need to construct it from "dtype", "device",
# etc. (Scatter)
#
# - Need the "memory_format" binding? Well, actually, it's available
# from both "memory_format" and "options", so you had better make sure
# they are consistent. (Join)
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
class UnsatError(RuntimeError):
pass
# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals. You may want to use this function if
# you're generating code for a function like:
#
# void f({args}) {
# g({exprs}); // g is a different API
# }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically you only need less information:
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
bindings: Sequence[Expr | Binding],
goals: Sequence[NamedCType | Binding],
*,
method: bool = False,
allow_expensive_conversions: bool = False,
) -> list[Expr]:
binding_exprs: list[Expr] = []
for b in bindings:
if isinstance(b, Binding):
binding_exprs.append(
Expr(
expr=b.name,
type=b.nctype,
)
)
else:
binding_exprs.append(b)
goal_ctypes: list[NamedCType] = []
for g in goals:
if isinstance(g, Binding):
goal_ctypes.append(g.nctype)
else:
goal_ctypes.append(g)
# Add all the bindings to the context
ctx: dict[NamedCType, str] = {}
for b in binding_exprs:
ctx[b.type] = b.expr
# While we're at it, do some simple forward inference, looking through
# constructors.
#
# NB: When should you do forward inference versus backward inference?
# The general idea:
#
# - Backward inference WHEN the goal gets smaller
# - Forward inference WHEN the hypothesis gets smaller
#
# This helps ensure termination: backward inference starts with a goal
# and tries to make it simpler and simpler until it's trivial; if the
# goal can grow in size, we blow up to a really huge goal size.
# Similarly, with forward inference we take hypotheses and decompose
# them into simpler hypotheses; if hypotheses could expand in size,
# we also have potential nontermination. (In the code below, forward
# inference is only ever carried out at a single step, but you could
# imagine repeated application of forward inference being profitable.)
#
# A good starting point in the literature for exploring more about proof
# search are these lecture notes
# https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
#
# TODO: My kingdom for a pattern matcher
# https://www.python.org/dev/peps/pep-0634/
#
# TODO: This could get us in recomputation trouble if b.expr is nontrivial.
# Fix this by implementing some sort of sharing so that if multiple
# goals share the same expression, we only compute it once. This seems
# to matter in practice as compiler is often unwilling to CSE nontrivial
# expressions like scalar.to<scalar_t>()
t = b.type
if (
isinstance(t, ConstRefCType)
and isinstance(t.elem, OptionalCType)
and isinstance(t.elem.elem, BaseCType)
and str(t.elem.elem.type) == "at::Tensor"
):
ctx[
NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
ctx[
NamedCType(t.name, BaseCType(optionalTensorRefT))
] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
if t.type == ConstRefCType(BaseCType(scalarT)):
ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
ctx[
NamedCType(t.name, BaseCType(optionalScalarRefT))
] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
if t.type == BaseCType(scalar_t):
ctx[
NamedCType(t.name, BaseCType(opmath_t))
] = f"static_cast<opmath_t>({b.expr})"
# [Note: IOptTensorListRef]
if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
ctx[
NamedCType(t.name, BaseCType(iOptTensorListRefT))
] = f"at::IOptTensorListRef({b.expr})"
# Add implicit bindings if the generated code is inside a Tensor method
if method:
ctx[
NamedCType("self", MutRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
ctx[
NamedCType("self", ConstRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
# This is better! Byte-for-byte compat
# ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
def unsat(goal: NamedCType) -> NoReturn:
ctx_desc = "\n".join(
f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
)
raise UnsatError(
f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:
{ctx_desc}
This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
)
# A shitty backtracking search implementation. It's shitty because it
# does backtracking via stack (bad idea!) and for the most part tries to
# avoid backtracking. In particular, if
# direct=True, we won't try to do any fancy synthesis, just trivial
# conversions (e.g., "T a" is OK for "const T& a"). So all of the
# existing rules in this function simply try to solve immediately,
# and bail if things don't work out.
def solve(goal: NamedCType, *, direct: bool) -> str:
def direct_solve(goal: NamedCType) -> str:
return solve(goal, direct=True)
if goal in ctx:
# Trivial
return ctx[goal]
# const & is satisfied with mutable &
if isinstance(goal.type, ConstRefCType):
try:
# WARNING: not strictly decreasing; be careful not
# to add a direct conversion that goes satisfies
# mutable& with const&
return solve(
NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
)
except UnsatError:
pass
# mutable & is satisfied with value
if isinstance(goal.type, MutRefCType):
try:
return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
except UnsatError:
pass
# TODO: These are referentially equal, shouldn't have to do this;
# ensuring we don't use type synonym IntArrayRef in codegen would
# help
if goal.type == ArrayRefCType(BaseCType(longT)):
return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
if direct:
unsat(goal)
# For now, all of these rules are mutually exclusive.
if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
memory_format = direct_solve(
NamedCType(
SpecialArgName.possibly_redundant_memory_format,
OptionalCType(BaseCType(memoryFormatT)),
)
)
# No need to join "memory_format" and "options" if the target API takes "options" directly.
# Otherwise it will cause the redundant memory_format error.
if options_ctype in goal_ctypes:
return memory_format
try:
options = direct_solve(options_ctype)
return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
except UnsatError:
return memory_format
elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
dtype = direct_solve(
NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
)
pin_memory = direct_solve(
NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
)
device = direct_solve(
NamedCType("device", OptionalCType(BaseCType(deviceT)))
)
layout = direct_solve(
NamedCType("layout", OptionalCType(BaseCType(layoutT)))
)
return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
try:
options = direct_solve(options_ctype)
return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.scalar_type()"
elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
try:
options = direct_solve(options_ctype)
return f"{options}.layout_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.layout()"
elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
try:
options = direct_solve(options_ctype)
return f"{options}.device_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.device()"
elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
try:
options = direct_solve(options_ctype)
return f"{options}.pinned_memory_opt()"
except UnsatError:
# If we're calling a factory op from its out= variant,
# We don't actually care about the value of pin_memory.
out_tensor = direct_solve(out_tensor_ctype)
return "::std::nullopt"
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
elif goal.type == BaseCType(intArrayRefT):
try:
return direct_solve(NamedCType(goal.name, longVec_ctype))
except UnsatError:
# We can also go SymIntArrayRef -> IntArrayRef
symIntArrayRef_type = direct_solve(
NamedCType(goal.name, BaseCType(symIntArrayRefT))
)
return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::fromIntArrayRefSlow({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):
return direct_solve(NamedCType(goal.name, BaseCType(longT)))
elif goal.type == OptionalCType(BaseCType(SymIntT)):
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(longT)))
)
return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
elif goal.type == OptionalCType(BaseCType(longT)):
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
)
return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
elif goal.type == BaseCType(optionalIntArrayRefT):
try:
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
except UnsatError:
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
)
return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalSymIntArrayRefT):
# TODO: You might also want to solve this from longSymVec_ctype or
# an optional version of it
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
)
return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalScalarRefT):
return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
elif goal.type == BaseCType(optionalTensorRefT):
return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
# Note [translation from C++ reference to value types]
# The below cases are all for when we have an argument with a reference type,
# and a corresponding goal with a value type.
# These are needed when we populate the inputs to a lambda capture and we need
# to guarantee the lifetime of each captured argument.
# We guard it with an explicit kwarg because converting to a value type is expensive
# (O(n)) to convert from IntArrayRef to vector<int>),
# so the caller of translate() should be explicit that they need it.
if allow_expensive_conversions:
if goal.type == VectorCType(BaseCType(longT)):
intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
argname = direct_solve(intArrayRef_ctype)
return f"{argname}.vec()"
if goal.type == VectorCType(BaseCType(SymIntT)):
symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
argname = direct_solve(symIntArrayRef_ctype)
return f"{argname}.vec()"
elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
optionalIntArrayRef_ctype = NamedCType(
goal.name, BaseCType(optionalIntArrayRefT)
)
argname = direct_solve(optionalIntArrayRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(scalarT)):
optionalScalarRef_ctype = NamedCType(
goal.name, BaseCType(optionalScalarRefT)
)
argname = direct_solve(optionalScalarRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(scalarT)):
optionalTensorRef_ctype = NamedCType(
goal.name, BaseCType(optionalTensorRefT)
)
argname = direct_solve(optionalTensorRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
# Technically, we also need to handle cases of C++ containers holding reference types.
# But there currently aren't any ops that require lambda capture codegen
# With arguments like ::std::vector<IntArrayRef>.
# If that changes, we'll have to add the translation here.
# We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
# We could probably generalize this to non-tensor types too.
if goal.type == MutRefCType(BaseCType(tensorT)):
const_ref_tensor_ctype = NamedCType(
goal.name, ConstRefCType(BaseCType(tensorT))
)
argname = direct_solve(const_ref_tensor_ctype)
return f"const_cast<Tensor&>({argname})"
unsat(goal)
return [Expr(solve(g, direct=False), g) for g in goal_ctypes]

View File

@ -0,0 +1,5 @@
from torchgen.api.types.types import *
from torchgen.api.types.types_base import *
from torchgen.api.types.signatures import * # usort: skip

View File

@ -0,0 +1,426 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING
from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING:
from torchgen.model import (
BackendIndex,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
@dataclass(frozen=True)
class CppSignature:
"""
A CppSignature represents a single overload in the C++ API. For
any given function schema, there may be multiple CppSignatures
corresponding to it, based on how we desugar to C++. See also
CppSignatureGroup.
"""
# The schema this signature is derived from
func: FunctionSchema
# Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
method: bool
# Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool
# Is this a symint C++ signature. For BC reasons, functions that take
# SymInts still present as int64_t in C++, and the SymInt variant is
# offered at a different overload name
#
# NB: If a function RETURNS a SymInt, this is ALWAYS false
symint: bool
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
# Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that
# lets manual C++ binding implementors access the binding that would
# have been automatically generated
fallback_binding: bool = False
# Return the unpacked argument structure of this signature,
# discarding information about which arguments are semantically
# related to each other.
def arguments(self) -> Sequence[Binding]:
return cpp.arguments(
self.func.arguments,
faithful=self.faithful,
symint=self.symint,
method=self.method,
cpp_no_default_args=self.cpp_no_default_args,
)
def name(self, *, suppress_symint_suffix: bool = False) -> str:
n = cpp.name(
self.func,
faithful_name_for_out_overloads=self.faithful,
symint_overload=False if suppress_symint_suffix else self.symint,
)
if self.fallback_binding:
n = f"__dispatch_{n}"
return n
# Render the C++ declaration for this signature
def decl(
self,
*,
name: str | None = None,
prefix: str = "",
is_redispatching_fn: bool = False,
suppress_symint_suffix: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
cpp_args = [a.decl() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
cpp_args_str = ", ".join(cpp_args)
if name is None:
name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
return f"{returns_type} {name}({cpp_args_str})"
# Render the C++ definition for this signature, not including
# the body (with curly braces)
def defn(
self,
*,
name: str | None = None,
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
cpp_args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
cpp_args_str = ", ".join(cpp_args)
if name is None:
name = prefix + self.name()
return f"{returns_type} {name}({cpp_args_str})"
def ptr_type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
# Represents group of all CppSignatures associated with a
# FunctionSchema. Right now, that's the regular, user-visible
# signature, as well as a "faithful" signature which doesn't
# have grouping.
@dataclass(frozen=True)
class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
faithful_signature: CppSignature | None
symint_signature: CppSignature | None
symint_faithful_signature: CppSignature | None
def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature:
return self.faithful_signature
else:
return self.signature
def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
if symint:
if self.symint_signature:
yield self.symint_signature
if self.symint_faithful_signature:
yield self.symint_faithful_signature
@staticmethod
def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> CppSignatureGroup:
func = f.func
def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
return CppSignature(
func=func,
faithful=faithful,
symint=symint,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
faithful_signature: CppSignature | None = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature
signature, faithful_signature = make_sigs(symint=False)
symint_signature: CppSignature | None = None
symint_faithful_signature: CppSignature | None = None
if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True)
return CppSignatureGroup(
func=func,
signature=signature,
faithful_signature=faithful_signature,
symint_signature=symint_signature,
symint_faithful_signature=symint_faithful_signature,
)
@dataclass(frozen=True)
class DispatcherSignature:
# The schema this signature is derived from
func: FunctionSchema
# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""
symint: bool = True
def arguments(self) -> list[Binding]:
return dispatcher.arguments(self.func, symint=self.symint)
def name(self) -> str:
return self.prefix + dispatcher.name(self.func)
def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(
self, name: str | None = None, *, is_redispatching_fn: bool = False
) -> str:
args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
args = ["c10::DispatchKeySet dispatchKeySet"] + args
args_str = ", ".join(args)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def exprs(self) -> list[Expr]:
return [Expr(a.name, a.nctype) for a in self.arguments()]
def returns_type(self) -> CType:
return dispatcher.returns_type(self.func.returns, symint=self.symint)
def ptr_type(self) -> str:
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})"
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
@staticmethod
def from_schema(
func: FunctionSchema, *, prefix: str = "", symint: bool = True
) -> DispatcherSignature:
return DispatcherSignature(func, prefix, symint)
@dataclass(frozen=True)
class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema
symint: bool
prefix: str = ""
def name(self) -> str:
return self.prefix + native.name(self.func)
def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def ptr_type(self) -> str:
# don't include defaults in type signature!
args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
def arguments(self) -> list[Binding]:
return native.arguments(self.func, symint=self.symint)
def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint)
def dispatcher_exprs(self) -> list[Expr]:
return translate.translate(
self.arguments(), dispatcher.arguments(self.func), method=False
)
@dataclass(frozen=True)
class ViewInverseSignature:
g: NativeFunctionsViewGroup
def name(self) -> str:
return functionalization.reverse_name(self.g.view, include_namespace=False)
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
decls = [
a.decl()
for a in functionalization.inner_arguments(
self.g.view.func, is_reverse=True
)
]
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
@dataclass(frozen=True)
class FunctionalizationLambda:
g: NativeFunctionsViewGroup
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
# allow_expensive_conversions is set because we want to convert
# some reference types (IntArrayRef) to value types (vector<int64_t>).
capture_exprs = translate.translate(
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
)
return capture_exprs
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
capture_str = ", ".join(
f"{val.type.name} = {val.expr}" for val in self.captures()
)
decls = [
a.decl()
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
capture_ctx = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
full_ctx = arg_ctx + capture_ctx
assert self.g.view_copy is not None
call_bindings = functionalization.inner_arguments(
self.g.view_copy.func, is_reverse=self.is_reverse
)
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
call_exprs = [
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
]
if not self.is_reverse and maybe_index is not None:
return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
else:
return f'{inner_call_name}({", ".join(call_exprs)});'
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse)
@dataclass(frozen=True)
class StructuredImplSignature:
g: NativeFunctionsGroup
name: str
def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
return f"TORCH_IMPL_FUNC({self.name})({args_str})"
def arguments(self) -> list[Binding]:
return structured.impl_arguments(self.g)
# Helper functions
def kernel_signature(
f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> NativeSignature | DispatcherSignature:
# Note [External Backends Follow Dispatcher API]
# Kernel signatures for in-tree backends follow the "native" API,
# while kernels for out-of-tree backends follow the dispatcher API.
# See the comments in `native.py` for details, but historically there have been
# some small differences in schema convention between them and the Dispatcher API.
# Any differences that require translating between the two will results in a runtime cost,
# so we'd like to keep the differences as small as possible.
# With external backends, we'd like to enforce that they write their kernels with schemas
# that match the Dispatcher API directly, if they can.
meta = backend_index.get_kernel(f)
symint = meta is not None and meta.supports_symint()
if symint:
assert (
f.func.has_symint()
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
if backend_index.external:
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
else:
return NativeSignature(f.func, prefix=prefix, symint=symint)
# Functions only, no types
from torchgen.api import (
cpp,
dispatcher,
functionalization,
native,
structured,
translate,
)

View File

@ -0,0 +1,191 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from dataclasses import dataclass
from torchgen.api.types.types_base import (
BaseCppType,
BaseCType,
boolT,
byteT,
charT,
CType,
doubleT,
floatT,
int32T,
longT,
shortT,
)
from torchgen.model import BaseTy, ScalarType
TENSOR_LIST_LIKE_CTYPES = [
"at::TensorList",
"const c10::List<::std::optional<at::Tensor>> &",
"const at::ITensorListRef &",
]
halfT = BaseCppType("at", "Half")
complexHalfT = BaseCppType(
"c10", "complex<c10::Half>"
) # stuffing template param here is an abuse
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
bfloat16T = BaseCppType("at", "BFloat16")
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
deviceIndexT = BaseCppType("at", "DeviceIndex")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT,
ScalarType.Char: charT,
ScalarType.Short: shortT,
ScalarType.Int: int32T,
ScalarType.Long: longT,
ScalarType.Half: halfT,
ScalarType.Float: floatT,
ScalarType.Double: doubleT,
ScalarType.ComplexHalf: complexHalfT,
ScalarType.ComplexFloat: complexFloatT,
ScalarType.ComplexDouble: complexDoubleT,
ScalarType.Bool: boolT,
ScalarType.Float8_e5m2: float8_e5m2T,
ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
ScalarType.Float8_e4m3fn: float8_e4m3fnT,
ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
}
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.Generator: generatorT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Dimname: dimnameT,
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.DeviceIndex: deviceIndexT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
BaseTy.QScheme: qschemeT,
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
}
# CTypes encode C++ type structure as needed for translation.
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ListCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::List<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"at::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class VectorizedCType(CType):
# This template is explicitly specialized, so the only valid
# elems are those we have specializations for (e.g., float, double, ...)
# scalar_t is also a common argument here (when we are codegen in
# a templated context)
elem: BaseCType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> CType:
return self

View File

@ -0,0 +1,276 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto, Enum
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
# An ArgName is just the str name of the argument in schema;
# but in some special circumstances, we may add a little extra
# context. The Enum SpecialArgName covers all of these cases;
# grep for their construction sites to see when they can occur.
class SpecialArgName(Enum):
possibly_redundant_memory_format = auto()
ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: str | None
name: str
def __str__(self) -> str:
if self.ns is None or self.ns == "":
return self.name
return f"{self.ns}::{self.name}"
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
# Templated types get their own dataclass, mainly to make namespace parsing easier.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# It would be more symmetric for this to be called intT, but it easy to mix
# this up with JIT int (which is int64_t in C++), so we intentionally don't
# define intT to make it obvious when you've stuffed it up
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
doubleT = BaseCppType("", "double")
floatT = BaseCppType("", "float")
boolT = BaseCppType("", "bool")
voidT = BaseCppType("", "void")
class CType(ABC):
@abstractmethod
def cpp_type(self, *, strip_ref: bool = False) -> str:
raise NotImplementedError
@abstractmethod
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
@abstractmethod
def remove_const_ref(self) -> CType:
return self
@dataclass(frozen=True)
class BaseCType(CType):
type: BaseCppType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return str(self.type)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")
def remove_const_ref(self) -> CType:
return self
@dataclass(frozen=True)
class ConstRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"const {self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class VectorCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::vector<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType(CType):
elem: CType
size: int
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> CType:
return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True)
class TupleCType(CType):
elems: list[CType]
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> CType:
return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True)
class MutRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"{self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref()
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
# semantic information about what it represents. For example, consider the
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
# semantic type also keeps track that this represents a "pin_memory"; you can't
# just use a random other boolean in a context where you need a "pin_memory"!
#
@dataclass(frozen=True)
class NamedCType:
name: ArgName
type: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return self.type.cpp_type(strip_ref=strip_ref)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> NamedCType:
return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> NamedCType:
return NamedCType(name, self.type)
# A binding represents any C++ binding site for a formal parameter.
# We don't distinguish between binding sites for different APIs;
# instead, all of the important distinctions are encoded in CType,
# which you can use to figure out if a given Binding is appropriate
# for use in another context. (See torchgen.api.translate)
@dataclass(frozen=True)
class Binding:
name: str
nctype: NamedCType
argument: Argument | TensorOptionsArguments | SelfArgument
# TODO: maybe don't represent default here
default: str | None = None
def rename(self, name: str) -> Binding:
return Binding(
name=name,
nctype=self.nctype,
argument=self.argument,
default=self.default,
)
@property
def type(self) -> str:
return self.nctype.cpp_type()
def no_default(self) -> Binding:
return Binding(
name=self.name,
nctype=self.nctype,
default=None,
argument=self.argument,
)
def decl(self, *, func_ptr_cast: bool = False) -> str:
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
# casting only needs to know the type
if func_ptr_cast:
return f"{self.type}"
else:
return f"{self.type} {self.name}{mb_default}"
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def decl_registration_declarations(self) -> str:
type_s = self.nctype.cpp_type_registration_declarations()
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
return f"{type_s} {self.name}{mb_default}"
def defn(self) -> str:
return f"{self.type} {self.name}"
def with_name(self, name: str) -> Binding:
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)
# An Expr is a C++ expression. It has a C++ string representing its syntax,
# as well as a CType saying what it provides.
@dataclass(frozen=True)
class Expr:
expr: str
type: NamedCType

View File

@ -0,0 +1,209 @@
from __future__ import annotations
from dataclasses import dataclass
import torchgen.api.types as api_types
from torchgen.api import cpp, structured
from torchgen.api.types import (
ArgName,
BaseCppType,
BaseCType,
Binding,
ConstRefCType,
CType,
NamedCType,
scalarT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
FunctionSchema,
NativeFunctionsGroup,
Type,
)
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
return f"ufunc_{func.name.name}_{dispatch_key}"
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
return schema_kernel_name(g.out.func, dispatch_key)
# Tensors are omitted (as they are stored in TensorIterator), everything else is
# passed along (technically, we can pass tensors along too, it just wastes
# argument registers)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
elif t == BaseType(BaseTy.Tensor):
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
if scalar_t == api_types.scalar_t:
return api_types.opmath_t
raise NotImplementedError
# NB: Tensors in constructor are stored in opmath_t, not scalar_t
# because Tensor in constructor = its a scalar tensor partially applied =
# it can be higher precision and we want to compute in that higher precision
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
elif t == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
if t == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(scalar_t))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# The actual ufunc template function the user writes. Everything here
# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, compute_t)
elif t == BaseType(BaseTy.Tensor):
return NamedCType(binds, compute_t)
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
return Binding(
nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
name=a.name,
default=None,
argument=a,
)
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
return Binding(
nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
name=a.name,
default=None,
argument=a,
)
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
return Binding(
nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
name=a.name,
default=None,
argument=a,
)
@dataclass(frozen=True)
class UfunctorBindings:
ctor: list[Binding]
apply: list[Binding]
# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments on a host-side constructor, and the rest in the device-side
# apply. E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
ctor = []
apply = []
for a in g.functional.func.arguments.flat_non_out:
if a.type.is_tensor_like():
if scalar_tensor_idx == 0:
# put it in the ctor anyway
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
scalar_tensor_idx = None
else:
if scalar_tensor_idx is not None:
scalar_tensor_idx -= 1
apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
else:
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
assert scalar_tensor_idx is None
return UfunctorBindings(ctor=ctor, apply=apply)
# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
# which do the actual computation in question. E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
# return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
return [
ufunc_argument(a, compute_t=compute_t)
for a in g.functional.func.arguments.flat_non_out
]
# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions. E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
# stubs drop all tensor arguments (they are implicit in the TensorIterator
# argument and keep everything else)
return [
r
for a in g.out.func.arguments.flat_non_out
if not a.type.is_tensor_like()
for r in structured.argument(a)
]

View File

@ -0,0 +1,249 @@
from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignatureGroup, CType
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is
# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the
# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register
# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase.
#
# Here's an example on how the codegen works:
#
# - Function Schema (source of truth)
#
# aten::empty.names(int[] size, *, Dimname[]? names,
# ScalarType? dtype=None, Layout? layout=None,
# Device? device=None, bool? pin_memory=None,
# MemoryFormat? memory_format=None) -> Tensor
# - Argument Conversion
# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
# - int[] size
# ```cpp
# const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
#
# std::vector<int64_t> size_vec;
# for (c10::IValue size_elem: size_list_in) {
# int64_t size_base = size_elem.to<int64_t>();
# size_vec.push_back(size_base);
# }
# at::ArrayRef<int64_t> size_list_out(size_vec);
# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
# Will be passed to unboxed kernel.
# ```
# - Dimname[]? names
# ```cpp
# ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
# if (names_opt.has_value()) {
# ~~~~~~~~~~~ <-- Unwrapping optional shell
# const c10::IValue names_opt_in = names_opt.value();
# const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
#
# std::vector<at::Dimname> names_vec;
# for (c10::IValue names_elem: names_list_in) {
# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
# at::Dimname names_base = names_elem.to<at::Dimname>();
# names_vec.push_back(names_base);
# }
# at::ArrayRef<at::Dimname> names_list_out(names_vec);
#
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
# } else {
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
# }
# ```
# - ScalarType? dtype (similarly for the rest of the arguments)
# ```cpp
# ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ScalarType> dtype_opt_out;
# if (dtype_opt.has_value()) {
# const c10::IValue dtype_opt_in = dtype_opt.value();
# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
# directly using ".to<T>()" API.
# dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
# } else {
# dtype_opt_out = ::std::optional<at::ScalarType>();
# }
# ```
#
# - Unboxed Kernel Call
# ```cpp
# auto result_ = torch::empty(
# size_list_out,
# names_opt_out,
# options,
# memory_format_opt_out
# );
# ```
#
# - Push Result Back to Stack
# ```cpp
# drop(stack, 7);
# pack(stack, std::move(result_));
# ```
connector = "\n\t"
# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
return f.func.name.unambiguous_name()
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
# we need the 'self' argument so method needs to be False
args = (
CppSignatureGroup.from_native_function(f, method=False)
.most_faithful_signature()
.arguments()
)
code_list = [
f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
for i in range(len(args))
] + [""]
binding_list = []
for arg in args:
# expecting only Argument
if not isinstance(arg.argument, Argument):
raise Exception( # noqa: TRY002
f"Unexpected argument type, expecting `Argument` but got {arg}"
)
argument: Argument = arg.argument
unboxed_name, _, code, decl = argumenttype_ivalue_convert(
argument.type,
argument.name,
mutable=argument.is_write,
)
code_list.extend(decl)
code_list.extend(code)
binding_list.append(arg.with_name(unboxed_name))
return binding_list, code_list
# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
# (1) the C++ code necessary to unbox the argument
# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
).type
if isinstance(t, BaseType):
out_name = f"{arg_name}_base"
code, decl = _gen_code_base_type(
arg_name=arg_name, out_name=out_name, ctype=ctype
)
elif isinstance(t, OptionalType):
out_name = f"{arg_name}_opt_out"
code, decl = _gen_code_optional_type(
arg_name=arg_name,
out_name=out_name,
t=t,
ctype=ctype,
)
elif isinstance(t, ListType):
out_name = f"{arg_name}_list_out"
code, decl = _gen_code_list_type(
arg_name=arg_name,
out_name=out_name,
t=t,
ctype=ctype,
)
else:
raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") # noqa: TRY002
return out_name, ctype, code, decl
def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
def _gen_code_optional_type(
arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in"
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return (
f"""
auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
{ctype.cpp_type(strip_ref=True)} {out_name};
if ({arg_name}_opt.has_value()) {{
const c10::IValue {in_name} = {arg_name}_opt.value();
{connector.join(res_code)}
{out_name} = {ctype.cpp_type(strip_ref=True)}({res_name});
}} else {{
{out_name} = {ctype.cpp_type(strip_ref=True)}();
}}
""".split(
"\n"
),
decl,
)
def _gen_code_list_type(
arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
# handle list type with size, e.g., bool[4]
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
code.extend(
f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
""".split(
"\n"
)
)
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif isinstance(t.elem, OptionalType):
code.extend(
f"""
{ctype.cpp_type(strip_ref=True)} {out_name};
for (c10::IValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{out_name}.push_back({res_name});
}}
""".split(
"\n"
)
)
else:
# use ArrayRef as default.
vec_name = arg_name + "_vec"
# need to bring vector instantiation out of scope so that ArrayRef has valid data
decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
code.extend(
f"""
for (c10::IValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
)
return code, decl