I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,10 @@
"""torchgen
This module contains code generation utilities for PyTorch. It is used to
build PyTorch from source, but may also be used for out-of-tree projects
that extend PyTorch.
Note well that we provide no BC guarantees for torchgen. If you're interested
in using torchgen and want the PyTorch team to be aware, please reach out
on GitHub.
"""


@@ -0,0 +1,149 @@
# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See
# https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436
# for details.
#
# The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py.
# Generally speaking, it is ok to add a new op to the list, but you need to run
# `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files.
# But it is NOT ok to remove an existing fallback op from the list, since that will break
# some existing AOTInductor-compiled models.
inductor_fallback_ops = {
"aten._adaptive_avg_pool2d_backward.default",
"aten._adaptive_avg_pool2d.default",
"aten._adaptive_avg_pool3d.default",
"aten._adaptive_avg_pool3d_backward.default",
"aten.adaptive_max_pool2d_backward.default",
"aten.adaptive_max_pool2d.default",
"aten.adaptive_max_pool3d.default",
"aten.adaptive_max_pool3d_backward.default",
"aten.addbmm.default",
"aten._addmm_activation.default",
"aten.addmm.out",
"aten.addmv.default",
"aten.angle.default",
"aten.avg_pool2d_backward.default",
"aten.avg_pool2d.default",
"aten.avg_pool3d_backward.default",
"aten.avg_pool3d.default",
"aten.bernoulli_.float",
"aten.bernoulli_.Tensor",
"aten.bmm.out",
"aten.bucketize.Tensor",
"aten.cat.default",
"aten._cdist_backward.default",
"aten._cdist_forward.default",
"aten.cholesky_inverse.default",
"aten.cholesky_solve.default",
"aten.convolution_backward.default",
"aten._cudnn_rnn.default",
"aten._cudnn_rnn_backward.default",
"aten.convolution.default",
"aten.cummax.default",
"aten.cummin.default",
"aten.cumprod.default",
"aten.cumsum.default",
"aten._efficient_attention_backward.default",
"aten._efficient_attention_forward.default",
"aten._efficientzerotensor.default",
"aten._embedding_bag.default",
"aten._embedding_bag_dense_backward.default",
"aten._embedding_bag_forward_only.default",
"aten._embedding_bag_per_sample_weights_backward.default",
"aten.exponential.default",
"aten._fft_c2c.default",
"aten._fft_r2c.default",
"aten._flash_attention_backward.default",
"aten._flash_attention_forward.default",
"aten.fractional_max_pool2d_backward.default",
"aten.fractional_max_pool2d.default",
"aten.fractional_max_pool3d.default",
"aten.fractional_max_pool3d_backward.default",
"aten._fused_moving_avg_obs_fq_helper.default",
"aten._fused_moving_avg_obs_fq_helper_functional.default",
"aten.gcd.default",
"aten.geqrf.default",
"aten.grid_sampler_2d_backward.default",
"aten.histc.default",
"aten.histogram.bin_ct",
"aten._histogramdd_bin_edges.default",
"aten._histogramdd_from_bin_cts.default",
"aten.index_put.default",
"aten.index_reduce.default",
"aten.index.Tensor",
"aten.kthvalue.default",
"aten.logcumsumexp.default",
"aten.lu_unpack.default",
"aten.masked_scatter.default",
"aten.masked_scatter_backward.default",
"aten.max_pool2d_with_indices_backward.default",
"aten.max_pool2d_with_indices.default",
"aten.max_pool3d_with_indices.default",
"aten.max_pool3d_with_indices_backward.default",
"aten.max_unpool2d.default",
"aten.max_unpool3d.default",
"aten.median.default",
"aten.mm.out",
"aten.mode.default",
"aten.mul.Scalar",
"aten.mul.Tensor",
"aten.nanmedian.default",
"aten.native_dropout.default",
"aten.normal_functional.default",
"aten.nonzero.default",
"aten.ormqr.default",
"aten._pdist_backward.default",
"aten._pdist_forward.default",
"aten.polar.default",
"aten.pow.Scalar",
"aten.pow.Tensor_Scalar",
"aten.pow.Tensor_Tensor",
"aten.rand.default",
"aten.rand.generator",
"aten.randint.default",
"aten.randint.generator",
"aten.randint.low",
"aten.randint.low_out",
"aten.randn.default",
"aten.randn.generator",
"aten.randperm.default",
"aten.repeat_interleave.Tensor",
"aten.replication_pad1d_backward.default",
"aten.replication_pad2d_backward.default",
"aten.reshape.default",
"aten.resize_.default",
"aten.resize_as_.default",
"aten._scaled_dot_product_efficient_attention_backward.default",
"aten._scaled_dot_product_efficient_attention.default",
"aten._scaled_dot_product_flash_attention_backward.default",
"aten._scaled_dot_product_flash_attention.default",
"aten._scaled_dot_product_cudnn_attention_backward.default",
"aten._scaled_dot_product_cudnn_attention.default",
"aten._scaled_dot_product_flash_attention_for_cpu_backward.default",
"aten._scaled_dot_product_flash_attention_for_cpu.default",
"aten._scaled_mm.default",
"aten.scatter_reduce.two_out",
"aten.scatter.src_out",
"aten.scatter.value_out",
"aten.searchsorted.default",
"aten._segment_reduce_backward.default",
"aten.segment_reduce.default",
"aten.slice.Tensor",
"aten.soft_margin_loss_backward.default",
"aten.sort.default",
"aten.sort.stable",
"aten._sparse_coo_tensor_with_dims_and_tensors.default",
"aten._thnn_fused_lstm_cell.default",
"aten.topk.default",
"aten._to_sparse.default",
"aten.to_sparse.default",
"aten.triangular_solve.default",
"aten._trilinear.default",
"aten.uniform.default",
"aten.upsample_bicubic2d_backward.default",
"aten.upsample_linear1d_backward.default",
"aten.upsample_trilinear3d_backward.default",
"aten.view_as_complex.default",
"aten.view_as_real.default",
"aten.view.dtype",
"aten.zeros.names",
}
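# Illustrative sketch (not part of the upstream file): a minimal consistency check one could run
# before regenerating the C shim headers with `python torchgen/gen.py --update-aoti-c-shim`.
# The helper name `_check_fallback_op_names` is hypothetical.
def _check_fallback_op_names() -> None:
    # Every entry is expected to be a fully qualified "aten.<op>.<overload>" name.
    for qualified_name in sorted(inductor_fallback_ops):
        namespace, _, remainder = qualified_name.partition(".")
        assert namespace == "aten" and remainder, f"unexpected fallback op name: {qualified_name}"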


@@ -0,0 +1,870 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import cast, Sequence
from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsViewGroup,
SchemaKind,
Type,
)
from torchgen.utils import IDENT_REGEX
# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
# The NamedCType holds the updated name and C++ type of the attribute;
# for the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`
nctype: NamedCType
# The expression to read the derived property at save time, e.g.:
# `other.scalar_type()`.
expr: str
# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
# The formula string (legit C++ expression).
# Note that expressions against input arguments have been replaced with the
# corresponding saved attributes.
# E.g.:
# raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
# here: `mul_tensor_backward(grad, self, other_scalar_type)`
formula: str
# The formula string before input argument replacement
original_formula: str
# Names of the arguments for which this formula calculates derivatives.
var_names: tuple[str, ...]
# Saved inputs that are referenced by the formula.
saved_inputs: tuple[SavedAttribute, ...]
# Saved outputs that are referenced by the formula.
saved_outputs: tuple[SavedAttribute, ...]
# Gradients that are referenced by name in the formula.
named_gradients: set[str]
# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
# The formula string (legit C++ expression).
# Note that special keywords such as "linear" or "element_wise" have been
# replaced by the automatically generated formula.
formula: str
# Name of the output arguments for which this formula calculates forward
# derivatives
var_names: tuple[str, ...]
# Type of the output arguments for which this formula calculates forward
# derivatives
var_types: tuple[Type, ...]
# Inputs for which the forward derivatives are required for this formula
required_inputs_fw_grad: tuple[str, ...] | None
# Inputs for which the primal is required for this formula
required_inputs_primal: tuple[str, ...] | None
# Flag to specify if this formula requires the original value of self
# This is only used by inplace operations
required_original_self_value: bool
# If this formula is specified in derivatives.yaml or if we are re-using the
# out of place formula for inplace
is_reusing_outplace_formula: bool
# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
# The base name read from derivatives.yaml.
name: str
# The matching native function.
#
# There can be multiple NativeFunction having the same base name:
# - different overloads with different types of input arguments;
# - in-place/out/functional variants of the same function;
#
# We first use the schema string (under the 'name' key) in derivatives.yaml
# to find the NativeFunction having the same schema string.
# Then we find the in-place/out/functional variants of the matching function.
# Among these variants, we choose the one having the same name as the
# derivatives.yaml entry. If there is no exact match, then we choose the
# in-place variant.
# TODO: maybe the logic to search for all variants is no longer necessary?
func: NativeFunction
# The name of the generated autograd function.
# It's set only if we will calculate a derivative, i.e.
# 'args_with_derivatives' is not empty.
op: str | None
# The derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable inputs
derivatives: Sequence[Derivative]
# The forward derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable outputs
forward_derivatives: Sequence[ForwardDerivative]
# The union of 'saved_inputs' of all 'derivatives'.
all_saved_inputs: Sequence[SavedAttribute]
# The union of 'saved_outputs' of all 'derivatives'.
all_saved_outputs: Sequence[SavedAttribute]
# All named gradients that are available for use, in the same
# order as in the grads vector.
available_named_gradients: Sequence[str]
# The named gradients that are used in any of the derivatives.
# Invariant: all(name in available_named_gradients for name in used_named_gradients)
used_named_gradients: set[str]
# The function's input arguments for which it calculates derivatives.
# It's the union of 'var_names' of all 'derivatives', sorted by the
# argument order in the function schema.
args_with_derivatives: Sequence[Binding]
# Names of arguments whose derivative formula is 'non_differentiable'.
non_differentiable_arg_names: Sequence[str]
# Raw data read from derivatives.yaml.
output_differentiability: list[bool] | None
# output_differentiability in derivatives.yaml can be a list of
# conditions that express if the output is differentiable. In this case,
# the number of conditions must match the number of outputs
# (NB: we only support one condition right now).
# output_differentiability gets populated with True for each condition,
# while output_differentiability_conditions gets populated with the conditions
output_differentiability_conditions: list[str] | None
@property
def has_derivatives(self) -> bool:
return len(self.args_with_derivatives) > 0
# Generates a new DifferentiabilityInfo using the exact same set of derivative information,
# but with a new operator name.
# This is used when generating "copy" variants of view ops,
# which are able to use the exact same derivative formula as the original view op
# See Note [Codegen'd {view}_copy Operators]
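# For example (illustrative): given the info registered for `transpose.int`, this helper
# produces an identical entry under the name `transpose_copy.int`, with `op` suffixed by
# "_copy", so the generated transpose_copy operator reuses the same derivative formulae.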
def create_view_copy_from_view_derivative(
self, g: NativeFunctionsViewGroup
) -> DifferentiabilityInfo | None:
if g.view_copy is None:
return None
f = g.view_copy
name_split_by_period = self.name.split(".", maxsplit=2)
# Append a "_copy" to the base name of the operator (but keep the overload name the same)
view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
name_split_by_period[1:]
)
view_copy_op_name = None if self.op is None else f"{self.op}_copy"
return DifferentiabilityInfo(
# Use the "_copy" version of name/func/op
name=view_copy_name,
func=f,
op=view_copy_op_name,
# But keep all derivative info the same
derivatives=self.derivatives,
forward_derivatives=self.forward_derivatives,
all_saved_inputs=self.all_saved_inputs,
all_saved_outputs=self.all_saved_outputs,
available_named_gradients=self.available_named_gradients,
used_named_gradients=self.used_named_gradients,
args_with_derivatives=self.args_with_derivatives,
non_differentiable_arg_names=self.non_differentiable_arg_names,
output_differentiability=self.output_differentiability,
output_differentiability_conditions=self.output_differentiability_conditions,
)
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
if info is None:
return False
for derivative in info.derivatives:
formula = derivative.formula
if re.search(IDENT_REGEX.format(ident), formula):
return True
return False
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "retain_variables")
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "grad")
# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
# context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
# `cpp.return_names()` method.
# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
# `output_differentiability` field defined in derivatives.yaml (if specified),
# and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
name: str
type: Type
# TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
cpp_type: str
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
info: dict[str, DifferentiabilityInfo] | None
fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
"""How are we going to call the underlying implementation of a
declaration? There are two strategies:
- use_derived: we want to call the implementation on CPUDoubleType
(or a similar, derived Type instance). Because these derived
instances deal in Tensors, not Variables (it's a completely different
object, so it doesn't dispatch back to VariableType), code on
this dispatch path needs to wrap/unwrap tensors. If the
derived implementation takes and returns tensors, the
implementation is usually differentiable (although we also use
the derived dispatch path for non-differentiable functions
that we still want to dispatch on the derived Type instance;
e.g., size())
- use_type: we want to call the implementation on Type, because
it is implemented concretely, and the functions it invokes will
get dispatched back to VariableType (which will ensure that they
are differentiable.)
"""
# fn is derived as long as any of its per-key differentiability infos
# has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
# and ADInplaceOrViewType. We want to generate these functions as long as a
# derivative is defined for ANY dispatch key.
if fn.func.is_abstract or (
fn.info is not None and any(info.has_derivatives for info in fn.info.values())
):
# If the function is abstract (not implemented on at::Type), we must
# call the implementation on the derived type with unpacked tensors.
# If the function has a derivative specified and is concrete, we could
# call either implementation. We prefer calling the derived
# type's implementation with unpacked tensors because it is more
# performant in some cases: any internal calls to other ATen functions
# won't have the history tracked.
# If the function has a type dispatched argument (i.e. is a factory),
# we prefer calling the derived type's implementation both because it is
# more performant and to ensure factory functions return tensors with _version
# of 0 (probably not strictly necessary, but nice to have to keep versions simple
# to understand).
return "use_derived"
else:
# If the function is concrete (we don't have to override it) and we
# didn't declare it in derivatives.yaml, we'll assume that it is
# actually implemented out of differentiable functions. (This
# assumption might not hold, but then you'll see gradcheck fail.)
return "use_type"
def is_foreach_func(f: NativeFunction) -> bool:
return f.func.name.name.base.startswith("_foreach_")
# note(crcrpar): Most foreach functions can reference an out-of-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such a reference in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
"_foreach_add.Tensor",
"_foreach_mul.Tensor",
"_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the nonforeach versions support.
_skip_argument_len_check = {
"_foreach_add.Scalar",
"_foreach_add_.Scalar",
"_foreach_add.ScalarList",
"_foreach_add_.ScalarList",
"_foreach_sub.Scalar",
"_foreach_sub_.Scalar",
"_foreach_sub.ScalarList",
"_foreach_sub_.ScalarList",
}
# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
# references to generate derivatives.
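# For example (illustrative): `_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1)`
# can use `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)` as its reference, since the
# base names match ("add"), the flat non-out argument counts match, and each list element type
# lines up with the corresponding reference argument type.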
def is_reference_for_foreach(
f: NativeFunction,
function_schema: FunctionSchema,
) -> bool:
return (
f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
and (
not function_schema.name.name.inplace
or str(f.func.name) in _foreach_with_inplace_ref
)
and (
str(f.func.name) in _skip_argument_len_check
or len(f.func.arguments.flat_non_out)
== len(function_schema.arguments.flat_non_out)
)
and all(
ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
for arg, ref_arg in zip(
f.func.arguments.flat_non_out,
function_schema.arguments.flat_non_out,
)
)
)
# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
foreach_function: NativeFunction,
functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
non_functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
],
dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
The second return value indicates whether the info is generated in this function.
"""
ref_diff_info: DifferentiabilityInfo | None = None
for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
# note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
# while the info of `zero_` is in non_functional_info_by_signature
if (
ref_diff_info is None
and foreach_function.func.kind() == SchemaKind.inplace
and str(foreach_function.func.name) in _foreach_with_inplace_ref
):
for function_schema, diff_info in non_functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
continue
ref_diff_info = diff_info[dispatch_key]
if ref_diff_info is not None:
break
if ref_diff_info is None:
return None, False
# Non-out-of-place (i.e. in-place) reuses the existing Derivative.
if foreach_function.func.kind() == SchemaKind.inplace:
return ref_diff_info, False
map_refarg2foreacharg, map_name2arg = {}, {}
for i, (arg, ref_arg) in enumerate(
zip(
foreach_function.func.arguments.flat_non_out,
function_schema.arguments.flat_non_out,
)
):
map_refarg2foreacharg[ref_arg.name] = arg.name
map_name2arg[arg.name] = arg
all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
modified_derivative_formulas = []
for i, derivative in enumerate(ref_diff_info.derivatives):
modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
"result", "result[i]"
)
saved_inputs, saved_outputs = [], []
# note(crcrpar): This context seems necessary to call `cpp.argument_type`
with local.parametrize(
use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
):
for ref_input in derivative.saved_inputs:
ref_input_jit_name = ref_input.expr.split(".")[0]
mapped_name = map_refarg2foreacharg[ref_input_jit_name]
if isinstance(map_name2arg[mapped_name].type, ListType):
mapped_expr = mapped_name + "[i]"
else:
mapped_expr = mapped_name
new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
modified_formula = modified_formula.replace(
cast(str, ref_input.nctype.name), new_expr
)
nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
canonical_nctype = NamedCType(
nctype.name, nctype.type.remove_const_ref()
)
saved_inputs.append(
SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
)
for ref_output in derivative.saved_outputs:
if ref_output.nctype.name == "result":
saved_outputs.append(
SavedAttribute(
nctype=NamedCType(
name="result", type=BaseCType(tensorListT)
),
expr="result",
)
)
else:
raise RuntimeError(f"unexpected saved output: {ref_output.nctype.name}")
var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
all_var_names.extend(var_names)
all_saved_inputs.extend(saved_inputs)
all_saved_outputs.extend(saved_outputs)
modified_derivative = Derivative(
formula=modified_formula,
original_formula=derivative.formula,
var_names=tuple(var_names),
saved_inputs=tuple(saved_inputs),
saved_outputs=tuple(saved_outputs),
named_gradients=set(),
)
modified_derivative_formulas.append(modified_derivative)
with local.parametrize(
use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
):
args_with_derivatives = [
Binding(
name=arg.name,
nctype=cpp.argument_type(arg, binds=arg.name),
argument=arg,
default=None,
)
for arg in foreach_function.func.arguments.flat_non_out
if arg.name in all_var_names
]
forward_derivatives: list[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: list[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: list[str] = []
required_inputs_primal: list[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
required_inputs_primal = list(fw_derivative.required_inputs_primal)
modified_formula = fw_derivative.formula
# Foreach's result is TensorList
if "result" in modified_formula:
modified_formula = fw_derivative.formula.replace("result", "result[i]")
for foreach_arg, ref_arg in zip(
foreach_function.func.arguments.flat_non_out,
ref_diff_info.func.func.arguments.flat_non_out,
):
# Modify reference forward formula
if (
isinstance(foreach_arg.type, ListType)
and not foreach_arg.type.is_tensor_like()
):
# Assuming ScalarList
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name + "[i]"
)
elif foreach_arg.type.is_tensor_like():
# Assuming TensorList / Tensor
# assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
assert isinstance(foreach_arg.type, ListType) or (
foreach_arg.type == BaseType(BaseTy.Tensor)
and str(foreach_function.func.name) in _foreach_with_tensor_overload
), f"{foreach_function.func.name}, {foreach_arg.type}"
for suffix in ("_p", "_t"):
curr_expr = ref_arg.name + suffix
if curr_expr in modified_formula:
new_expr = foreach_arg.name + suffix
modified_formula = modified_formula.replace(curr_expr, new_expr)
else:
# Assuming Scalar
if foreach_arg.name != ref_arg.name:
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name
)
# note(crcrpar): there should exist a cooler way...
for i, name in enumerate(var_names):
if name == ref_arg.name:
var_names[i] = foreach_arg.name
var_types[i] = foreach_arg.type
for i, name in enumerate(required_inputs_fw_grad):
if name == ref_arg.name:
required_inputs_fw_grad[i] = foreach_arg.name
for i, name in enumerate(required_inputs_primal):
if name == ref_arg.name:
required_inputs_primal[i] = foreach_arg.name
forward_derivatives.append(
ForwardDerivative(
formula=modified_formula,
var_names=tuple(var_names),
var_types=tuple(var_types),
required_inputs_fw_grad=tuple(required_inputs_fw_grad),
required_inputs_primal=tuple(required_inputs_primal),
required_original_self_value=fw_derivative.required_original_self_value,
is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
)
)
return (
DifferentiabilityInfo(
name=foreach_function.func.name.name.base,
func=foreach_function,
op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
derivatives=modified_derivative_formulas,
forward_derivatives=forward_derivatives,
all_saved_inputs=tuple(set(all_saved_inputs)),
all_saved_outputs=tuple(set(all_saved_outputs)),
available_named_gradients=(),
used_named_gradients=set(),
args_with_derivatives=args_with_derivatives,
non_differentiable_arg_names=[],
output_differentiability=None,
output_differentiability_conditions=None,
),
True,
)
def match_differentiability_info(
native_functions: list[NativeFunction],
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
"""
functional_info_by_signature = {
schema.signature(strip_default=True): info_dict
for schema, info_dict in differentiability_infos.items()
if schema.kind() == SchemaKind.functional
}
non_functional_info_by_signature = {
schema.signature(strip_default=True): info_dict
for schema, info_dict in differentiability_infos.items()
if schema.kind() != SchemaKind.functional
}
def find_info(
f: NativeFunction,
) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
# Don't bother matching info to generated out= variants
if "generated" in f.tags and f.func.kind() == SchemaKind.out:
return None, False
# (1) Check for an exact match
if f.func in differentiability_infos:
return differentiability_infos[f.func], True
# (2) If no exact match, check if the out-of-place variant
# of this operator has a match.
# i.e. mul() for mul_() or mul_out()
# note(crcrpar): Check whether `f` is a foreach function, because in-place foreach functions use the
# backward defined for the existing native functions instead of their out-of-place counterparts.
f_sig = f.func.signature(strip_default=True)
if f_sig in functional_info_by_signature and not is_foreach_func(f):
return functional_info_by_signature[f_sig], False
# (3) Some operators have a derivative explicitly defined for the mutable
# variant, but get a code-generated out-of-place variant which does *not*
# come with a derivative formula.
# For the generated out-of-place variant, use the mutable variant's formula
# if it exists.
if "generated" in f.tags and f_sig in non_functional_info_by_signature:
info_dict = non_functional_info_by_signature[f_sig]
# See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
assert not any(
any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
for info in info_dict.values()
), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used automatically by its functional variant ("{str(f.func)}").
This is not currently supported (we'd need to fix up the formula in the codegen)."""
return info_dict, False
# (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
if is_foreach_func(f):
assert f.func not in differentiability_infos
diff_info, is_generated = gen_foreach_derivativeinfo(
f,
functional_info_by_signature,
non_functional_info_by_signature,
)
if diff_info is None:
return None, False
# TODO(crcrpar): Avoid hard coding "Default" ideally.
diff_info_dict = {"Default": diff_info}
if is_generated:
differentiability_infos[f.func] = diff_info_dict
functional_info_by_signature[f.func] = diff_info_dict
return diff_info_dict, is_generated
return None, False
result: list[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info_dict, is_exact_match = find_info(f)
# Currently, the '.strides()' to 'strides_or_error' replacement does not support
# 'self' derivatives of an inplace function, so we must check for this case.
if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
for info in info_dict.values():
for derivative in info.derivatives:
if "self" in derivative.var_names:
for saved_input in derivative.saved_inputs:
assert "strides_or_error" not in saved_input.expr, (
"Calling '.strides()' in the 'self' derivative formula of an "
f"in-place function is not supported: {f.func}"
)
if not info_dict:
result.append(
NativeFunctionWithDifferentiabilityInfo(
func=f, info=None, fw_derivatives=None
)
)
continue
fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
for key, info in info_dict.items():
if not info.forward_derivatives:
fw_derivative_dict[key] = []
continue
forward_derivatives = info.forward_derivatives
# For functions that have a single def for out-of-place and inplace (like abs())
if f.func.kind() == SchemaKind.inplace:
# For inplace functions there is a little bit of work to do:
# 1) Validate the formula and make sure the input that is modified is not used:
# - If there is a formula for the inplace variant of the function (is_exact_match == True) then
# we make sure that the original value of the input that is being modified inplace (self_p) is
# not used in the formula. Note that the formula can use "original_self_p" here and that would
# trigger a clone of the original input.
# - If we are re-using the out of place formula (is_exact_match == False) then we replace every
# occurrence of self_p and self_t by original_self_p and original_self_t. These will be
# populated by cloned version of the original input (either the clone done by the backward AD
# logic if self is also used in a backward formula or a special clone that we add).
# 2) At this point, there cannot be a self_p in the formula.
# 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
# simply called self (as it is modified inplace).
# 4) Update the required primals data in case it used to contain "result" but should now contain
# "self"
# 5) If it is not an exact match, the user formula is not modifying the existing forward grad
# inplace as it should. So add some code that makes sure that we do so if the forward grad
# already exists.
assert (
len(info.forward_derivatives) == 1
) # Only single output inplace should exist
fw_info = info.forward_derivatives[0]
formula = fw_info.formula
def replace_self_with_original_self(formula: str, postfix: str) -> str:
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}original_self{postfix}{m.group(2)}"
return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
if re.search(IDENT_REGEX.format("self_p"), formula):
if is_exact_match:
# For manually defined formulas, don't allow the original value to be used
raise RuntimeError(
f'The formula for "{f.func.name}" is using the original value of self '
"that is being modified inplace. This would lead to wrong forward gradients. "
'Please use "result" in the formula only.'
)
else:
# When the original formula is out of place, we save a clone of the primal
# value to be able to access this value if needed
# replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
formula = replace_self_with_original_self(formula, "_p")
formula = replace_self_with_original_self(formula, "_t")
# replace "result" from the formula by "self_p"
def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}self_p{m.group(2)}"
formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
required_primals = fw_info.required_inputs_primal
if re.search(IDENT_REGEX.format("self_p"), formula):
required_primals = (
required_primals + ("self",) if required_primals else ("self",)
)
if not is_exact_match:
# NOTE [In-place forward AD formula Optimization]
#
# This optimization transforms the formula to directly do inplace, i.e.
# instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
#
# 1) the formula satisfies the pattern: "self_t.op(*args)"
# 2) "op" in (1) needs to be the same as the op the derivative is for
#
# (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
# If there is a need, we can relax (2) to allow any op that has an in-place variant
is_single_method_on_self_t = False
directly_do_inplace = False
op_name: str | None = None
between_parens: str | None = None
match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
if match:
op_name, between_parens = match.group(1), match.group(2)
# We want to...
# Match: self_t.op1(other_p.op2(arg))
# Avoid: self_t.op1(args) + self_t.op2(args)
# Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
def check_parens_nest_level_gt_zero(s: str) -> bool:
level = 1
for ch in s:
if ch == ")":
level -= 1
if level == 0:
return False
if ch == "(":
level += 1
return True
is_single_method_on_self_t = check_parens_nest_level_gt_zero(
between_parens
)
directly_do_inplace = (
is_single_method_on_self_t and op_name == info.name
)
if directly_do_inplace:
assert op_name is not None
assert between_parens is not None
formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
else:
# Make sure that the forward grad is modified inplace when the original formula
# is out of place
formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
required_original_self_value = bool(
re.search(IDENT_REGEX.format("original_self_p"), formula)
) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
forward_derivatives = [
ForwardDerivative(
formula=formula,
var_names=("self",),
var_types=fw_info.var_types,
required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
required_inputs_primal=required_primals,
required_original_self_value=required_original_self_value,
is_reusing_outplace_formula=not is_exact_match,
),
]
fw_derivative_dict[key] = forward_derivatives
result.append(
NativeFunctionWithDifferentiabilityInfo(
func=f, info=info_dict, fw_derivatives=fw_derivative_dict
)
)
return result
def is_differentiable(
name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
return type.is_tensor_like() and (
info is None or name not in info.non_differentiable_arg_names
)
def gen_differentiable_outputs(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
f = fn.func
info = fn.info[key] if fn.info else None
outputs: list[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
)
for name, ret in zip(cpp.return_names(f), f.func.returns)
]
output_differentiability = info.output_differentiability if info else None
if output_differentiability is not None:
if len(output_differentiability) != len(outputs):
raise RuntimeError(
f"The length of output_differentiability ({len(output_differentiability)}), "
f"does not match the number of outputs ({len(outputs)})."
)
differentiable_outputs: list[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError(
"output_differentiability=False for inplace operation (version_counter won't get updated)"
)
for differentiable, output in zip(output_differentiability, outputs):
if differentiable:
differentiable_outputs.append(output)
return differentiable_outputs
candidate_differentiable_outputs = list(
filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
)
if uses_single_grad(info):
return candidate_differentiable_outputs[:1]
else:
return candidate_differentiable_outputs


@@ -0,0 +1,472 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api.types import (
ArgName,
ArrayCType,
ArrayRefCType,
BaseCType,
BaseTypeToCppMapping,
Binding,
boolT,
ConstRefCType,
CType,
dimnameListT,
intArrayRefT,
iTensorListRefT,
ListCType,
longT,
MutRefCType,
NamedCType,
OptionalCType,
optionalIntArrayRefT,
optionalSymIntArrayRefT,
scalarT,
SpecialArgName,
symIntArrayRefT,
SymIntT,
tensorListT,
tensorOptionsT,
tensorT,
TupleCType,
VectorCType,
voidT,
)
from torchgen.model import (
Argument,
Arguments,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
OptionalType,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
# a single C++ type TensorOptions (the native functions API
# also has this, but tensor options is really most relevant
# for the C++ API; it makes calling kwarg factory functions
# pleasant)
#
# - defaulting lives here (in fact, the dispatcher is completely
# oblivious of defaults!)
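#   For example (illustrative): the JIT schema
#     add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   surfaces in the C++ API roughly as
#     at::Tensor at::add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1);
#   with the default for `alpha` materialized here rather than in the dispatcher.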
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide
def name(
func: FunctionSchema,
*,
faithful_name_for_out_overloads: bool = False,
symint_overload: bool = False,
) -> str:
name = str(func.name.name)
if symint_overload:
name += "_symint"
if func.is_out_fn():
if faithful_name_for_out_overloads:
name += "_outf"
else:
name += "_out"
return name
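# For example (illustrative): for the out-overload schema `add.out`, name() yields "add_out" by
# default, "add_outf" when faithful_name_for_out_overloads=True, and "add_symint_out" when
# symint_overload=True.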
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
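# For example (illustrative): `int` maps to `int64_t`, `bool` maps to `bool`, and `SymInt` maps
# to `c10::SymInt` when symint=True (otherwise `int64_t`); `Tensor` and `Scalar` are not value
# types, so None is returned for them.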
def valuetype_type(
t: Type,
*,
binds: ArgName,
mutable: bool = True,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
elif str(t) == "SymInt":
if symint:
return NamedCType(binds, BaseCType(SymIntT))
else:
return NamedCType(binds, BaseCType(longT))
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
"string ref->value conversion: not implemented yet"
)
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if str(t.elem) == "bool":
assert t.size is not None
return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the resulting CType is not a non-owning reference type.
# For example, we'll return std::vector<int64_t> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t,
binds=binds,
mutable=mutable,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == "Tensor":
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(
binds, MutRefCType(BaseCType(tensorT))
) # TODO: fix this discrepancy
else:
return NamedCType(
binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
)
elif str(t.elem) == "Scalar":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
if symint:
return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
else:
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: remove these special cases, ArrayRef fallthrough works fine
if str(t.elem) == "int":
if remove_non_owning_ref_types:
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "SymInt":
if remove_non_owning_ref_types:
if symint:
return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
else:
return NamedCType(binds, VectorCType(BaseCType(longT)))
else:
if symint:
return NamedCType(binds, BaseCType(symIntArrayRefT))
else:
return NamedCType(binds, BaseCType(intArrayRefT))
if str(t.elem) == "Tensor":
if local.use_ilistref_for_tensor_lists():
return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
else:
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Scalar":
return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
elif str(t.elem) == "Dimname":
return NamedCType(binds, BaseCType(dimnameListT))
elif str(t.elem) == "Tensor?":
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
# placeholder is ignored
# NB: symint is ALWAYS respected for return types. So symint argument
# here is IGNORED
r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
if r is not None:
return r.type
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
# Note [Tensor Copy Returns]
# Currently, we use "Argument.is_write" to determine
# whether or not Tensor return types should be copies or references.
# If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
elif isinstance(t, OptionalType):
elem = returntype_type(t.elem, mutable=mutable)
if str(t.elem) == "Tensor":
return OptionalCType(elem)
raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
return returntype_type(r.type, mutable=r.is_write, symint=symint)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0], symint=symint)
else:
return TupleCType([return_type(r, symint=symint) for r in rs])
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
# TODO: Consider incorporating this into the data model
if f.func.name.name.inplace:
assert i == 0, "illegal inplace function with multiple returns"
name = "self"
# If this is an out function, the name is the name of the
# corresponding output function (r.name will get recorded
# in field_name later.)
elif f.func.is_out_fn():
name = f.func.arguments.out[i].name
# If the return argument is explicitly named...
elif r.name:
name_conflict = any(
r.name == a.name for a in f.func.schema_order_arguments()
)
if name_conflict and not f.func.is_out_fn():
name = f"{r.name}_return"
else:
name = r.name
# If there is no explicit name and no fallback name was passed in, we just name the output result,
# unless it's a multi-return, in which case it's result0,
# result1, etc (zero-indexed)
else:
name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
returns.append(name)
return returns
JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "::std::nullopt", # UGH this one is type directed
"Mean": "at::Reduction::Mean",
"[]": "{}",
"contiguous_format": "c10::MemoryFormat::Contiguous",
"long": "at::kLong",
}
# Convert a JIT default into C++ expression representing the default
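# A few illustrative conversions, following the logic below:
#   "True"               -> "true"
#   "None" for Tensor?   -> "{}"
#   "[0, 1]" for int[]   -> "{0, 1}"
#   "'reflect'" for str  -> '"reflect"'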
def default_expr(d: str, t: Type, *, symint: bool) -> str:
if d == "None" and str(t) == "Tensor?":
return "{}"
if isinstance(t, BaseType) and t.name is BaseTy.str:
# Schema allows single quotes but C++ needs double
if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
s = ""
i = 1
while i + 1 < len(d):
if d[i] != "\\":
if d[i] == '"':
s += '\\"'
else:
s += d[i]
i += 1
else:
if d[i + 1] == "'":
s += "'"
else:
s += d[i : i + 2]
i += 2
return f'"{s}"'
if isinstance(t, OptionalType):
if d == "None":
return "::std::nullopt"
return default_expr(d, t.elem, symint=symint)
if isinstance(t, ListType):
if d.startswith("[") and d.endswith("]"):
return "{" + d[1:-1] + "}"
elif symint and d.isdigit() and str(t.elem) == "SymInt":
return f"c10::SymInt({d})"
elif t.size is None:
# NOTE: Sized lists can have scalar defaults
raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
return JIT_TO_CPP_DEFAULT.get(d, d)
# Convert an argument into its C++ API form
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
*,
cpp_no_default_args: set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
) -> list[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
symint=symint,
has_tensor_options=has_tensor_options,
)
if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type, symint=symint)
return [
Binding(
nctype=argument_type(a, binds=binds, symint=symint),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, TensorOptionsArguments):
if faithful:
return (
sub_argument(a.dtype)
+ sub_argument(a.layout)
+ sub_argument(a.device)
+ sub_argument(a.pin_memory)
)
else:
default = None
# Enforced by NativeFunction.__post_init__
assert "options" not in cpp_no_default_args
if all(x.default == "None" for x in a.all()):
default = "{}"
elif a.dtype.default == "long":
default = "at::kLong" # TODO: this is wrong
return [
Binding(
nctype=NamedCType("options", BaseCType(tensorOptionsT)),
name="options",
default=default,
argument=a,
)
]
elif isinstance(a, SelfArgument):
if method:
# Caller is responsible for installing implicit this in context!
return []
else:
return sub_argument(a.argument)
else:
assert_never(a)
def arguments(
arguments: Arguments,
*,
faithful: bool,
symint: bool = False,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
else:
args.extend(arguments.out)
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r
for a in args
for r in argument(
a,
faithful=faithful,
symint=symint,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,
)
]


@@ -0,0 +1,120 @@
from __future__ import annotations
import itertools
from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
from torchgen.model import (
Argument,
FunctionSchema,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never, concatMap
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API so that the unboxed API
# better aligns with the boxed API. The dispatcher API hooks heavily
# into our template based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
#
# Prominent characteristics of the dispatcher API:
#
# - dtype, layout, device and pin_memory are represented as separate
# arguments.
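#   For example (illustrative): for a factory schema along the lines of
#     empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
#           Device? device=None, bool? pin_memory=None) -> Tensor
#   jit_arguments() flattens the grouped TensorOptions back into the four separate
#   dtype/layout/device/pin_memory arguments, so arguments() yields one Binding per
#   dispatcher-level argument rather than a single TensorOptions binding.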
#
def name(func: FunctionSchema) -> str:
return cpp.name(func)
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType:
# This is a faux ami. If it makes sense in the future to add
# more special cases here, or invert things so cpp.argument_type
# calls this, or just completely inline the function, please do
# it.
return cpp.argumenttype_type(
t,
mutable=mutable,
binds=binds,
symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
def argument_type(
a: Argument,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType:
return argumenttype_type(
a.type,
mutable=a.is_write,
binds=binds,
remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
)
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
# At present, there is no difference. But there could be!
return cpp.returns_type(rs, symint=symint)
def jit_arguments(func: FunctionSchema) -> list[Argument]:
def to_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Argument]:
if isinstance(a, Argument):
return [a]
elif isinstance(a, SelfArgument):
return [a.argument]
elif isinstance(a, TensorOptionsArguments):
return [a.dtype, a.layout, a.device, a.pin_memory]
else:
assert_never(a)
return list(
concatMap(
to_argument,
itertools.chain(
func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
),
)
)
def argument(
a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
return Binding(
nctype=argument_type(
a,
binds=a.name,
remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
),
name=a.name,
argument=a,
)
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
return [argument(a, symint=symint) for a in jit_arguments(func)]


@@ -0,0 +1,199 @@
from __future__ import annotations
from torchgen.api import dispatcher
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
ConstRefCType,
CType,
longT,
NamedCType,
tensorT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
NativeFunction,
NativeFunctionsViewGroup,
)
# This file describes the translation of JIT schema to APIs used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
# (following the dispatcher convention), the logic here for the reverse lambda
# is responsible for generating both the call-site, and the declarations
# (which are implemented manually in the at::functionalization::impl namespace).
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
# return name(inner_arguments);
# }
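# As a rough illustrative sketch (not generated output), the forward lambda for a view op taking
# (Tensor(a) self, int dim) might look something like:
#   [reapply_views, dim](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
#       return at::_ops::<view_op>::call(base, dim);
#   }
# with `base` substituted for the original `self` argument and the remaining arguments captured
# by value.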
# Define some specific lambda input arguments.
base_binding = Binding(
name="base",
nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
mutated_view_binding = Binding(
name="mutated_view",
nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
mutated_view_idx_binding = Binding(
name="mutated_view_idx",
nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
argument=Argument(
name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
),
default=None,
)
reapply_views_binding = Binding(
name="reapply_views",
nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
argument=Argument(
name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
),
default=None,
)
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
inverse_return_mode_binding = Binding(
name="inverse_return_mode",
nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
argument=Argument(
name="inverse_return_mode",
# NB: not actually a bool but it doesn't matter because this isn't used
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)
# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
g: NativeFunctionsViewGroup,
*,
is_reverse: bool,
include_namespace: bool,
reapply_views: bool | None = None,
) -> str:
if reapply_views is None:
# reapply_views is only important for the fwd lambda,
# since we always plumb the runtime "reapply_views" argument into the reverse function.
assert is_reverse
if is_reverse:
return reverse_name(g.view, include_namespace)
# in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
assert include_namespace
assert g.view_copy is not None
api_name = (
g.view.func.name.unambiguous_name()
if reapply_views
else g.view_copy.func.name.unambiguous_name()
)
return f"at::_ops::{api_name}::call"
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
# for the reverse: we plumb the "reapply_views" flag into that function and support
# both copy and non-copy variants. (We could avoid doing that, but that would require
# writing out twice as many view inverse functions).
api_name = f.func.name.unambiguous_name()
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
if include_namespace:
return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
else:
return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# so any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>).
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings
def returns_type(func: FunctionSchema) -> CType:
# Assertion: all view ops return tensor-like outputs
assert len(func.returns) >= 1
for ret in func.returns:
assert ret.type.is_tensor_like()
# However, the return type of the lambda is always an individual tensor.
# For multi-tensor outputs, each tensor needs to be tracked individually.
return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
def inner_call_index(func: FunctionSchema) -> Binding | None:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
len(func.returns) == 1 and func.returns[0].type.is_list_like()
):
return mutated_view_idx_binding
return None
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]
# The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
# Both of these follow the dispatcher API.
non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
if not is_reverse:
# the forward lambda swaps out the original tensor argument with the lambda arg "base"
return [base_binding] + non_self_bindings
else:
# the reverse lambda does the same, but with an additional "mutated_view" arg.
# Additionally, we have a calling convention: for view ops that return multiple tensor outputs,
# their corresponding view_inverse function takes in an additional index argument.
index_binding = inner_call_index(func)
if index_binding is not None:
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
index_binding,
] + non_self_bindings
else:
return [
base_binding,
mutated_view_binding,
inverse_return_mode_binding,
] + non_self_bindings
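# Minimal sketch (an illustration added here, not generated code): how the reverse
# lambda's inner-call bindings are assembled for a given schema. `func` is assumed
# to be a FunctionSchema whose first argument is the tensor being viewed.
def _example_reverse_inner_arguments(func: FunctionSchema) -> list[str]:
    bindings = inner_arguments(func, is_reverse=True)
    # For multi-output views (e.g. split), the list additionally contains the
    # mutated_view_idx binding right after inverse_return_mode.
    return [b.name for b in bindings]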

View File

@ -0,0 +1,467 @@
from __future__ import annotations
from typing import Any
from torchgen.api.types import (
BaseCppType,
BaseCType,
boolT,
CType,
deviceT,
doubleT,
generatorT,
layoutT,
ListCType,
longT,
memoryFormatT,
NamedCType,
OptionalCType,
scalarT,
scalarTypeT,
stringT,
SymIntT,
VectorCType,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
ListType,
OperatorName,
OptionalType,
Return,
TensorOptionsArguments,
Type,
)
_valueT: BaseCppType | None = None
# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
global _valueT
if not _valueT:
raise NotImplementedError(
"The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
)
return _valueT
def setValueT(val: BaseCppType) -> None:
global _valueT
_valueT = val
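# Hedged usage sketch: a lazy tensor backend is expected to register its IR value
# type before running the codegen (e.g. from run_gen_lazy_tensor). The namespace
# and class name below are assumptions used only for illustration.
def _example_configure_value_type() -> BaseCppType:
    setValueT(BaseCppType("torch::lazy", "Value"))
    return getValueT()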
# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
"""
This function takes a type from NativeFunctions and converts it for use with
lazy tensor codegen.
Type conversion for lazy currently consists of
(1) changing at::Tensors into lazy::Values
(2) wrapping everything in a BaseCType
(3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
(1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
There is special handling for Optional[Tensor] or List[Tensor], etc.; hence 'tensor-like'.
This is incomplete: there are assertions in places where it's expected that more types
will need to be added as the codegen is used with more operators.
"""
if isinstance(typ, BaseType):
if typ.name == BaseTy.Tensor:
return BaseCType(getValueT())
elif typ.name == BaseTy.Scalar:
if properties.TreatScalarsAsConstants:
return BaseCType(scalarT)
# at::scalar has special handling,
# and is wrapped in a lazy::Value just like at::tensor
return BaseCType(getValueT())
elif typ.name == BaseTy.ScalarType:
return BaseCType(scalarTypeT)
elif typ.name == BaseTy.int:
return BaseCType(longT)
elif typ.name == BaseTy.SymInt:
if symint:
return BaseCType(getValueT())
else:
return BaseCType(longT)
elif typ.name == BaseTy.bool:
return BaseCType(boolT)
elif typ.name == BaseTy.float:
return BaseCType(doubleT)
elif typ.name == BaseTy.str:
return BaseCType(stringT)
elif typ.name == BaseTy.Device:
return BaseCType(deviceT)
elif typ.name == BaseTy.Generator:
return BaseCType(generatorT)
elif typ.name == BaseTy.Layout:
return BaseCType(layoutT)
elif typ.name == BaseTy.MemoryFormat:
return BaseCType(memoryFormatT)
else:
raise AssertionError(f"TODO add support for type {repr(typ)}")
elif isinstance(typ, OptionalType):
return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
elif isinstance(typ, ListType):
if str(typ.elem) == "Tensor?":
# TODO(whc) is this actually correct? or should it use a Vector like above
return ListCType(OptionalCType(BaseCType(getValueT())))
elif str(typ.elem) == "Tensor":
# this is a TensorList which comes in from GetTensorList as a Value
return BaseCType(tensorListValueT)
elif typ.elem == BaseType(BaseTy.SymInt):
# TODO: return a value type. The problem here is analogous to
# the problem with tensorListValueT: if you have SymInt[] you
# cannot conveniently save the list of Value directly, as nodes
# expect to save values as a vector for ALL arguments. So you
# need a separate IR node that represents all of the size nodes
# assembled into a list. I'm not an LTC dev so I don't want to
# figure it out right now. Y'all figure it out...
return VectorCType(BaseCType(longT))
else:
return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
else:
raise AssertionError(f"unrecognized type {repr(typ)}")
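# Illustrative sketch of the conversions above (added for documentation only).
# The Tensor case assumes the IR value type has already been registered via
# setValueT(); the first two cases do not depend on it.
def _example_process_ir_type(properties: LazyIrProperties) -> None:
    # int -> int64_t
    assert process_ir_type(BaseType(BaseTy.int), properties, symint=False) == BaseCType(longT)
    # int[] -> ::std::vector<int64_t> (reference types become owning value types)
    assert process_ir_type(
        ListType(BaseType(BaseTy.int), None), properties, symint=False
    ) == VectorCType(BaseCType(longT))
    # Tensor -> the configured lazy IR value type
    assert process_ir_type(BaseType(BaseTy.Tensor), properties, symint=False) == BaseCType(getValueT())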
# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
being Tensor-like, but assumes the type has already been transformed.
"""
if isinstance(typ, BaseCType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
return (
typ.type == getValueT()
or (typ.type == scalarT and not treat_scalars_as_constants)
or typ.type == SymIntT
)
elif typ == VectorCType(BaseCType(SymIntT)):
# TODO: report True for this
return False
elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
return isValueType(typ.elem, properties)
return False
def isSymIntType(typ: Type) -> bool:
return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
def isWrappedScalarType(typ: Type) -> bool:
"""
Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
Since we literally change the type from scalarT to valueT, information is lost.
This function helps build a list of wrapped scalars to save that information
"""
if isinstance(typ, BaseType):
# I am regretting my naming conventions, but now we are wrapping at::scalar in
# lazy value, while preserving other 'scalar' types as scalars in the IR
return typ.name == BaseTy.Scalar
elif isinstance(typ, (OptionalType, ListType)):
return isWrappedScalarType(typ.elem)
return False
# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
if isinstance(typ, BaseType):
return typ.name == BaseTy.Generator
elif isinstance(typ, (OptionalType)):
return isGeneratorType(typ.elem)
return False
# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
name: str
orig_type: Type
lazy_type_: CType | None
is_wrapped_scalar: bool
is_generator: bool
# TODO: this is a lie; it is false for symint lists
is_symint_or_list: bool
# Whether or not we are treating this argument as a SymInt
symint: bool
# true if this argument is or contains a lazy IR value
is_lazy_value: bool
def __init__(
self, arg: Argument, properties: LazyIrProperties, *, symint: bool
) -> None:
self.name = arg.name
self.orig_type = arg.type
self.symint = symint
self.is_optional = isinstance(arg.type, OptionalType)
self.is_generator = isGeneratorType(arg.type)
self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
self.is_symint_or_list = symint and (
isSymIntType(arg.type)
or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
# TODO: lists of symints are not currently treated as value types
# or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
)
self.is_lazy_value = isValueType(self.lazy_type, properties)
@property
def lazy_type(self) -> CType:
assert (
self.lazy_type_ is not None
), f"Attempted to access lazy_type for invalid argument {self.name}"
return self.lazy_type_
class LazyIrProperties:
"""Collection of properties for an IR node
The property groups are listed below. Each group is mutually
exclusive, meaning that only one property from each group can be True
at any one time. The properties can be accessed as if they were normal
attributes. The mutual exclusivity is automatically handled.
"""
Properties: tuple[tuple[str, ...], ...] = (
(
"ShapePrecompute", # Assume shape has been precomputed
"ShapeCompute", # Need to compute the shape on construction
"ShapeCache", # Utilize the shape cache to defer computation
),
(
"Lower", # Codegen full lower function
"LowerDeclOnly", # Codegen only lower function declaration
),
(
"CanBeReused", # Codegen full reuse function
"CanBeReusedDeclOnly", # Codegen only reuse function declaration
),
(
"CreateFn", # Codegen full create function
"CreateFnDeclOnly", # Codegen only create function declaration
),
(
"TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values
),
)
def __init__(self, *default_properties: str) -> None:
properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
LazyIrProperties.Properties
)
self.__dict__["properties"] = properties
for p in default_properties:
setattr(self, p, True)
def __getattr__(self, key: str) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
return properties[values] == key
return self.__getattribute__(key)
def __setattr__(self, key: str, value: Any) -> Any:
properties = self.__dict__["properties"]
for values in LazyIrProperties.Properties:
if key in values:
properties[values] = key if value else None
return value
raise KeyError(f"Invalid property: {key}")
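# Small illustrative sketch (not used anywhere) of the mutual-exclusivity rule
# described in the docstring above: turning on one property of a group clears
# whichever property of that group was previously set.
def _example_lazy_ir_properties() -> None:
    props = LazyIrProperties("ShapePrecompute", "Lower")
    assert props.ShapePrecompute and not props.ShapeCache
    props.ShapeCache = True  # flips the shape group from ShapePrecompute to ShapeCache
    assert props.ShapeCache and not props.ShapePrecompute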
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
# The name of the operator this function schema describes.
name: OperatorName
positional_args: tuple[LazyArgument, ...]
keyword_args: tuple[LazyArgument, ...]
# TODO: Need to handle collisions with argument names at some point
returns: tuple[Return, ...]
# if this schema has a Generator arg, list its orig ctype/name but don't
# build a LazyArgument since lazy IR doesn't support it
generator_arg: NamedCType | None = None
# original function schema
func: FunctionSchema
# Whether or not we are code-genning for SymInt
symint: bool
properties: LazyIrProperties = LazyIrProperties(
# default properties
"ShapePrecompute",
"Lower",
"CanBeReused",
)
opkind: str | None = None
def __init__(
self,
func: FunctionSchema,
properties: LazyIrProperties | None = None,
*,
symint: bool,
) -> None:
if properties:
self.properties = properties
self.func = func
self.symint = symint
positional_args: list[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = func.arguments.self_arg.argument
positional_args.append(
LazyArgument(arg, self.properties, symint=symint)
)
elif getattr(func.arguments, arg_field) is not None:
positional_args.extend(
LazyArgument(arg, self.properties, symint=symint)
for arg in getattr(func.arguments, arg_field)
)
self.positional_args = tuple(positional_args)
keyword_args: list[LazyArgument] = []
for arg_field in [
"pre_tensor_options_kwarg_only",
"tensor_options",
"post_tensor_options_kwarg_only",
"out",
]:
curr_args = getattr(func.arguments, arg_field)
if curr_args is not None:
if isinstance(curr_args, TensorOptionsArguments):
curr_args = curr_args.all()
for arg in curr_args:
if isGeneratorType(arg.type):
assert (
self.generator_arg is None
), "We expect there is only one generator arg"
self.generator_arg = NamedCType(
arg.name, arg.type # type:ignore[arg-type]
)
keyword_args.extend(
LazyArgument(arg, self.properties, symint=symint)
for arg in curr_args
)
self.keyword_args = tuple(keyword_args)
self.name = func.name
self.returns = func.returns
@property
def node_name(self) -> str:
"""
Return camel-case version of op in node.
Note: This function also appends any `overload_name` in the operation.
For example, if the op is `bitwise_and.Tensor`, the returned name
will be `BitwiseAndTensor`.
"""
op_name = f"{self.name.name}_{self.name.overload_name}".lower()
return "".join(word.capitalize() or "" for word in op_name.split("_"))
@property
def aten_name(self) -> str:
return str(self.name.name)
@property
def base_name(self) -> str:
return f"{self.name.name.base}"
def filtered_args(
self,
positional: bool = True,
keyword: bool = True,
values: bool = True,
scalars: bool = True,
generator: bool = True,
) -> list[LazyArgument]:
# This function maintains the sorted order of arguments but provides different filtered views.
# Some parts of the code care about kwargs vs args (TS lowerings),
# other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
# Generators are special cased, as they are needed for fallback/shape-inference but not supported
# in TS lowerings and therefore also omitted from lazy IR.
args: list[LazyArgument] = []
if positional:
args.extend(self.positional_args)
if keyword:
args.extend(self.keyword_args)
if values and scalars and generator:
return args
elif values and scalars:
return [a for a in args if not a.is_generator]
elif values:
return [a for a in args if a.is_lazy_value]
elif scalars:
return [
a
for a in args
if not a.is_lazy_value and (generator or not a.is_generator)
]
return []
@property
def positional_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=True, scalars=False
)
@property
def positional_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=False, scalars=True
)
@property
def keyword_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=True, scalars=False
)
@property
def keyword_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=False, scalars=True
)
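# Illustrative sketch (added example, not used by the codegen): the different
# filtered views over a schema's arguments exposed by filtered_args(). `schema`
# is assumed to be a fully constructed LazyIrSchema.
def _example_filtered_args(schema: LazyIrSchema) -> None:
    all_args = schema.filtered_args()  # everything, in schema order
    value_args = schema.filtered_args(values=True, scalars=False)
    scalar_args = schema.filtered_args(values=False, scalars=True, generator=False)
    # Generators are excluded from both filtered lists.
    assert len(value_args) + len(scalar_args) <= len(all_args)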

View File

@ -0,0 +1,13 @@
from torchgen.model import NativeFunctionsGroup
# Follows dispatcher calling convention, but:
# - Mutable arguments not allowed. Meta functions are always
# written in functional form. Look at FunctionSchema.signature()
# - No tensor returns; instead we return a TensorMeta describing
# the tensor in question
def name(g: NativeFunctionsGroup) -> str:
# use the overload name from the functional version
return str(g.functional.func.name).replace(".", "_")

View File

@ -0,0 +1,155 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import (
ArgName,
BaseCType,
Binding,
boolT,
ConstRefCType,
CType,
deviceT,
layoutT,
ListCType,
MutRefCType,
NamedCType,
OptionalCType,
scalarT,
scalarTypeT,
tensorT,
)
from torchgen.model import (
Argument,
FunctionSchema,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is symint aware, you will get the non-SymInt variant for some
# dispatch entries and SymInt for others.
def name(func: FunctionSchema) -> str:
name = str(func.name.name)
# TODO: delete this!
if func.is_out_fn():
name += "_out"
if func.name.overload_name:
name += f"_{func.name.overload_name}"
return name
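# Hedged example (illustration only; the schema string is an assumption): the
# default native name appends "_out" for out= variants and then the overload
# name, so "add.out(...)" maps to "add_out_out".
def _example_native_name() -> str:
    schema = FunctionSchema.parse(
        "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
    )
    return name(schema)  # "add_out_out"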
def argumenttype_type(
t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
if str(t) == "Tensor?":
tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(tensor_type))
else:
return NamedCType(binds, ConstRefCType(tensor_type))
elif str(t) == "Tensor?[]":
return NamedCType(
binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
)
elif str(t) == "Scalar":
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
elif str(t) == "Scalar?":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
return cpp.returns_type(rs, symint=symint)
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
def argument(
a: Argument | SelfArgument | TensorOptionsArguments,
*,
is_out: bool,
symint: bool,
) -> list[Binding]:
# Ideally, we NEVER default native functions. However, there are a number
# of functions that call native:: directly and rely on the defaulting
# existing. So for BC, we generate defaults for non-out variants (but not
# for out variants, where it is impossible to generate an appropriate
# default)
should_default = not is_out
if isinstance(a, Argument):
default: str | None = None
if should_default and a.default is not None:
default = cpp.default_expr(a.default, a.type, symint=symint)
return [
Binding(
nctype=argument_type(a, binds=a.name, symint=symint),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, SelfArgument):
# Erase SelfArgument from the distinction
return argument(a.argument, is_out=is_out, symint=symint)
elif isinstance(a, TensorOptionsArguments):
default = None
if should_default:
default = "{}"
# TODO: Not sure why the arguments assigned here are for
# TensorOptionsArguments and not the constituent pieces. It seems
# to matter
return [
Binding(
nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
name="dtype",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
name="layout",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
name="device",
default=default,
argument=a,
),
Binding(
nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
name="pin_memory",
default=default,
argument=a,
),
]
else:
assert_never(a)
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(func.arguments.non_out)
args.extend(func.arguments.out)
return [
r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
]
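# Sketch (added for illustration): for a factory-style schema, the native API
# scatters TensorOptions into four individual optional bindings, as implemented
# in argument() above. The schema string is an assumption for demonstration.
def _example_native_arguments() -> list[str]:
    schema = FunctionSchema.parse(
        "ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, "
        "Device? device=None, bool? pin_memory=None) -> Tensor"
    )
    # Expected to yield ["size", "dtype", "layout", "device", "pin_memory"].
    return [b.name for b in arguments(schema, symint=False)]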

File diff suppressed because it is too large

View File

@ -0,0 +1,157 @@
from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.types import (
ArgName,
ArrayRefCType,
BaseCType,
Binding,
ConstRefCType,
dimnameListT,
intArrayRefT,
iOptTensorListRefT,
iTensorListRefT,
NamedCType,
OptionalCType,
optionalIntArrayRefT,
optionalScalarRefT,
optionalTensorRefT,
scalarT,
tensorT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunctionsGroup,
OptionalType,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the structured functions API.
# This is similar to native API, but a number of historical problems with native
# API have been fixed.
# Translation of types occurring in JIT arguments to a C++ argument type.
# NB: For now, mutable doesn't do anything; but it could if we make
# some more nominal types
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
# If it's a value type, do the value type translation
# NB: structured kernels ALWAYS have symint off, since they involve actual
# kernels that require real ints. The one exception is the
# CompositeExplicitAutograd and the meta function (which could
# hypothetically be SymInt), but for simplicity we plan for these to just
# be handled in Python
r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if t.elem == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(optionalTensorRefT))
elif t.elem == BaseType(BaseTy.Scalar):
return NamedCType(binds, BaseCType(optionalScalarRefT))
elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
return NamedCType(binds, BaseCType(optionalIntArrayRefT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if t.elem == BaseType(BaseTy.Tensor):
return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
return NamedCType(binds, BaseCType(iOptTensorListRefT))
# TODO: delete these special cases; see torchgen.api.cpp--these
# must be changed in tandem, but there are problems; see
# https://github.com/pytorch/pytorch/pull/51485
elif str(t.elem) == "int":
return NamedCType(binds, BaseCType(intArrayRefT))
elif str(t.elem) == "Dimname":
return NamedCType(binds, BaseCType(dimnameListT))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
# returns_type intentionally omitted, because structured kernels never "return";
# instead, they always indirectly report their outputs (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).
# Structured kernels are never defaulted
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
if isinstance(a, Argument):
return [
Binding(
nctype=argument_type(a, binds=a.name),
name=a.name,
default=None,
argument=a,
)
]
elif isinstance(a, SelfArgument):
return argument(a.argument)
elif isinstance(a, TensorOptionsArguments):
raise AssertionError("structured kernels don't support TensorOptions yet")
else:
assert_never(a)
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if g.out.precomputed:
# A list of parameters for the impl function with
# certain parameters replaced with precomputed counterparts
# as specified in native_functions.yaml.
non_out_args_replaced: list[
Argument | TensorOptionsArguments | SelfArgument
] = []
for a in g.out.func.arguments.non_out:
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
# If a is in precompute.replace, append the parameters
# that should replace it onto non_out_args_replaced.
non_out_args_replaced.extend(g.out.precomputed.replace[a.name])
else:
# If not, push a as it is.
non_out_args_replaced.append(a)
args.extend(non_out_args_replaced)
# g.out.precomputed.add is the list of parameters that are added
# without replacement after the non out args and just before the out args
args.extend(g.out.precomputed.add)
else:
args.extend(g.out.func.arguments.non_out)
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.functional.func.arguments.non_out)
return [r for arg in args for r in argument(arg)]
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]
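# Minimal sketch (illustrative only): the three argument lists produced above for
# a structured operator group `g`. impl_arguments() can differ from the others
# when precomputed elements replace some of the original parameters.
def _example_structured_bindings(g: NativeFunctionsGroup) -> dict[str, list[str]]:
    return {
        "meta": [b.name for b in meta_arguments(g)],
        "impl": [b.name for b in impl_arguments(g)],
        "out": [b.name for b in out_arguments(g)],
    }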

View File

@ -0,0 +1,433 @@
from __future__ import annotations
from typing import NoReturn, Sequence
from torchgen.api.types import (
ArrayRefCType,
BaseCType,
Binding,
boolT,
ConstRefCType,
deviceT,
Expr,
intArrayRefT,
iOptTensorListRefT,
layoutT,
ListCType,
longT,
memoryFormatT,
MutRefCType,
NamedCType,
opmath_t,
OptionalCType,
optionalIntArrayRefT,
optionalScalarRefT,
optionalSymIntArrayRefT,
optionalTensorRefT,
scalar_t,
scalarT,
scalarTypeT,
SpecialArgName,
symIntArrayRefT,
SymIntT,
tensorOptionsT,
tensorT,
VectorCType,
)
# This file implements a small program synthesis engine that implements
# conversions between one API to another.
#
# The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
# represents a C++ type, plus semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is
# "bool", but its C++ semantic type also keeps track that this represents a
# "pin_memory"; you can't just use a random other boolean in a context where you
# need a "pin_memory"!
#
# The translator takes a list of needed NamedCTypes, and then figures out how
# to construct expressions with these NamedCTypes from the given bindings. Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other in scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial action:
#
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
# in the context, instead, "options" is, and you need to extract
# it from there. (Gather)
#
# - Need the "context" binding? Well, maybe "context" isn't available
# in the context, and you need to construct it from "dtype", "device",
# etc. (Scatter)
#
# - Need the "memory_format" binding? Well, actually, it's available
# from both "memory_format" and "options", so you had better make sure
# they are consistent. (Join)
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
class UnsatError(RuntimeError):
pass
# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) that
# have all of the types of goals. You may want to use this function if
# you're generating code for a function like:
#
# void f({args}) {
# g({exprs}); // g is a different API
# }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically you need less information:
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
bindings: Sequence[Expr | Binding],
goals: Sequence[NamedCType | Binding],
*,
method: bool = False,
allow_expensive_conversions: bool = False,
) -> list[Expr]:
binding_exprs: list[Expr] = []
for b in bindings:
if isinstance(b, Binding):
binding_exprs.append(
Expr(
expr=b.name,
type=b.nctype,
)
)
else:
binding_exprs.append(b)
goal_ctypes: list[NamedCType] = []
for g in goals:
if isinstance(g, Binding):
goal_ctypes.append(g.nctype)
else:
goal_ctypes.append(g)
# Add all the bindings to the context
ctx: dict[NamedCType, str] = {}
for b in binding_exprs:
ctx[b.type] = b.expr
# While we're at it, do some simple forward inference, looking through
# constructors.
#
# NB: When should you do forward inference versus backward inference?
# The general idea:
#
# - Backward inference WHEN the goal gets smaller
# - Forward inference WHEN the hypothesis gets smaller
#
# This helps ensure termination: backward inference starts with a goal
# and tries to make it simpler and simpler until it's trivial; if the
# goal can grow in size, we blow up to a really huge goal size.
# Similarly, with forward inference we take hypotheses and decompose
# them into simpler hypotheses; if hypotheses could expand in size,
# we also have potential nontermination. (In the code below, forward
# inference is only ever carried out at a single step, but you could
# imagine repeated application of forward inference being profitable.)
#
# A good starting point in the literature for exploring more about proof
# search are these lecture notes
# https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
#
# TODO: My kingdom for a pattern matcher
# https://www.python.org/dev/peps/pep-0634/
#
# TODO: This could get us in recomputation trouble if b.expr is nontrivial.
# Fix this by implementing some sort of sharing so that if multiple
# goals share the same expression, we only compute it once. This seems
# to matter in practice as compiler is often unwilling to CSE nontrivial
# expressions like scalar.to<scalar_t>()
t = b.type
if (
isinstance(t, ConstRefCType)
and isinstance(t.elem, OptionalCType)
and isinstance(t.elem.elem, BaseCType)
and str(t.elem.elem.type) == "at::Tensor"
):
ctx[
NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
ctx[
NamedCType(t.name, BaseCType(optionalTensorRefT))
] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
if t.type == ConstRefCType(BaseCType(scalarT)):
ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
ctx[
NamedCType(t.name, BaseCType(optionalScalarRefT))
] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
if t.type == BaseCType(scalar_t):
ctx[
NamedCType(t.name, BaseCType(opmath_t))
] = f"static_cast<opmath_t>({b.expr})"
# [Note: IOptTensorListRef]
if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
ctx[
NamedCType(t.name, BaseCType(iOptTensorListRefT))
] = f"at::IOptTensorListRef({b.expr})"
# Add implicit bindings if the generated code is inside a Tensor method
if method:
ctx[
NamedCType("self", MutRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
ctx[
NamedCType("self", ConstRefCType(BaseCType(tensorT)))
] = "const_cast<Tensor&>(*this)"
# This is better! Byte-for-byte compat
# ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
def unsat(goal: NamedCType) -> NoReturn:
ctx_desc = "\n".join(
f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
)
raise UnsatError(
f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:
{ctx_desc}
This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
)
# A shitty backtracking search implementation. It's shitty because it
# does backtracking via stack (bad idea!) and for the most part tries to
# avoid backtracking. In particular, if
# direct=True, we won't try to do any fancy synthesis, just trivial
# conversions (e.g., "T a" is OK for "const T& a"). So all of the
# existing rules in this function simply try to solve immediately,
# and bail if things don't work out.
def solve(goal: NamedCType, *, direct: bool) -> str:
def direct_solve(goal: NamedCType) -> str:
return solve(goal, direct=True)
if goal in ctx:
# Trivial
return ctx[goal]
# const & is satisfied with mutable &
if isinstance(goal.type, ConstRefCType):
try:
# WARNING: not strictly decreasing; be careful not
# to add a direct conversion that satisfies
# mutable& with const&
return solve(
NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
)
except UnsatError:
pass
# mutable & is satisfied with value
if isinstance(goal.type, MutRefCType):
try:
return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
except UnsatError:
pass
# TODO: These are referentially equal, shouldn't have to do this;
# ensuring we don't use type synonym IntArrayRef in codegen would
# help
if goal.type == ArrayRefCType(BaseCType(longT)):
return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
if direct:
unsat(goal)
# For now, all of these rules are mutually exclusive.
if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
memory_format = direct_solve(
NamedCType(
SpecialArgName.possibly_redundant_memory_format,
OptionalCType(BaseCType(memoryFormatT)),
)
)
# No need to join "memory_format" and "options" if the target API takes "options" directly.
# Otherwise it will cause the redundant memory_format error.
if options_ctype in goal_ctypes:
return memory_format
try:
options = direct_solve(options_ctype)
return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
except UnsatError:
return memory_format
elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
dtype = direct_solve(
NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
)
pin_memory = direct_solve(
NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
)
device = direct_solve(
NamedCType("device", OptionalCType(BaseCType(deviceT)))
)
layout = direct_solve(
NamedCType("layout", OptionalCType(BaseCType(layoutT)))
)
return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
try:
options = direct_solve(options_ctype)
return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.scalar_type()"
elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
try:
options = direct_solve(options_ctype)
return f"{options}.layout_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.layout()"
elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
try:
options = direct_solve(options_ctype)
return f"{options}.device_opt()"
except UnsatError:
out_tensor = direct_solve(out_tensor_ctype)
return f"{out_tensor}.device()"
elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
try:
options = direct_solve(options_ctype)
return f"{options}.pinned_memory_opt()"
except UnsatError:
# If we're calling a factory op from its out= variant,
# We don't actually care about the value of pin_memory.
out_tensor = direct_solve(out_tensor_ctype)
return "::std::nullopt"
# We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
elif goal.type == BaseCType(intArrayRefT):
try:
return direct_solve(NamedCType(goal.name, longVec_ctype))
except UnsatError:
# We can also go SymIntArrayRef -> IntArrayRef
symIntArrayRef_type = direct_solve(
NamedCType(goal.name, BaseCType(symIntArrayRefT))
)
return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::fromIntArrayRefSlow({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):
return direct_solve(NamedCType(goal.name, BaseCType(longT)))
elif goal.type == OptionalCType(BaseCType(SymIntT)):
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(longT)))
)
return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(longT):
symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
elif goal.type == OptionalCType(BaseCType(longT)):
argname = direct_solve(
NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
)
return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
elif goal.type == BaseCType(optionalIntArrayRefT):
try:
return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
except UnsatError:
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
)
return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalSymIntArrayRefT):
# TODO: You might also want to solve this from longSymVec_ctype or
# an optional version of it
argname = direct_solve(
NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
)
return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
elif goal.type == BaseCType(optionalScalarRefT):
return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
elif goal.type == BaseCType(optionalTensorRefT):
return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
# Note [translation from C++ reference to value types]
# The below cases are all for when we have an argument with a reference type,
# and a corresponding goal with a value type.
# These are needed when we populate the inputs to a lambda capture and we need
# to guarantee the lifetime of each captured argument.
# We guard it with an explicit kwarg because converting to a value type is expensive
# (e.g., O(n) to convert an IntArrayRef to a vector<int>),
# so the caller of translate() should be explicit that they need it.
if allow_expensive_conversions:
if goal.type == VectorCType(BaseCType(longT)):
intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
argname = direct_solve(intArrayRef_ctype)
return f"{argname}.vec()"
if goal.type == VectorCType(BaseCType(SymIntT)):
symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
argname = direct_solve(symIntArrayRef_ctype)
return f"{argname}.vec()"
elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
optionalIntArrayRef_ctype = NamedCType(
goal.name, BaseCType(optionalIntArrayRefT)
)
argname = direct_solve(optionalIntArrayRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(scalarT)):
optionalScalarRef_ctype = NamedCType(
goal.name, BaseCType(optionalScalarRefT)
)
argname = direct_solve(optionalScalarRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
elif goal.type == OptionalCType(BaseCType(tensorT)):
optionalTensorRef_ctype = NamedCType(
goal.name, BaseCType(optionalTensorRefT)
)
argname = direct_solve(optionalTensorRef_ctype)
return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
# Technically, we also need to handle cases of C++ containers holding reference types.
# But there currently aren't any ops that require lambda capture codegen
# with arguments like ::std::vector<IntArrayRef>.
# If that changes, we'll have to add the translation here.
# We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
# We could probably generalize this to non-tensor types too.
if goal.type == MutRefCType(BaseCType(tensorT)):
const_ref_tensor_ctype = NamedCType(
goal.name, ConstRefCType(BaseCType(tensorT))
)
argname = direct_solve(const_ref_tensor_ctype)
return f"const_cast<Tensor&>({argname})"
unsat(goal)
return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
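# Hedged usage sketch (added for illustration): synthesizing a callee's "dtype"
# argument when the caller only has a TensorOptions binding in scope. The binding
# and goal below are assumptions chosen to exercise the direct-solve rules above.
def _example_translate_dtype() -> list[Expr]:
    bindings = [
        Expr("options", NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))),
    ]
    goals = [NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))]
    # Yields a single Expr along the lines of
    #   c10::optTypeMetaToScalarType(options.dtype_opt())
    return translate(bindings, goals)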

View File

@ -0,0 +1,5 @@
from torchgen.api.types.types import *
from torchgen.api.types.types_base import *
from torchgen.api.types.signatures import * # usort: skip

View File

@ -0,0 +1,426 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING
from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING:
from torchgen.model import (
BackendIndex,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
@dataclass(frozen=True)
class CppSignature:
"""
A CppSignature represents a single overload in the C++ API. For
any given function schema, there may be multiple CppSignatures
corresponding to it, based on how we desugar to C++. See also
CppSignatureGroup.
"""
# The schema this signature is derived from
func: FunctionSchema
# Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
method: bool
# Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool
# Is this a symint C++ signature. For BC reasons, functions that take
# SymInts still present as int64_t in C++, and the SymInt variant is
# offered at a different overload name
#
# NB: If a function RETURNS a SymInt, this is ALWAYS false
symint: bool
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
# Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that
# lets manual C++ binding implementors access the binding that would
# have been automatically generated
fallback_binding: bool = False
# Return the unpacked argument structure of this signature,
# discarding information about which arguments are semantically
# related to each other.
def arguments(self) -> Sequence[Binding]:
return cpp.arguments(
self.func.arguments,
faithful=self.faithful,
symint=self.symint,
method=self.method,
cpp_no_default_args=self.cpp_no_default_args,
)
def name(self, *, suppress_symint_suffix: bool = False) -> str:
n = cpp.name(
self.func,
faithful_name_for_out_overloads=self.faithful,
symint_overload=False if suppress_symint_suffix else self.symint,
)
if self.fallback_binding:
n = f"__dispatch_{n}"
return n
# Render the C++ declaration for this signature
def decl(
self,
*,
name: str | None = None,
prefix: str = "",
is_redispatching_fn: bool = False,
suppress_symint_suffix: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
cpp_args = [a.decl() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
cpp_args_str = ", ".join(cpp_args)
if name is None:
name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
return f"{returns_type} {name}({cpp_args_str})"
# Render the C++ definition for this signature, not including
# the body (with curly braces)
def defn(
self,
*,
name: str | None = None,
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
returns_type = cpp.returns_type(
self.func.returns, symint=self.symint
).cpp_type()
cpp_args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
cpp_args_str = ", ".join(cpp_args)
if name is None:
name = prefix + self.name()
return f"{returns_type} {name}({cpp_args_str})"
def ptr_type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
args_types_str = ", ".join(a.type for a in self.arguments())
return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
# Represents the group of all CppSignatures associated with a
# FunctionSchema. Right now, that's the regular, user-visible
# signature, as well as a "faithful" signature which doesn't
# have grouping.
@dataclass(frozen=True)
class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
faithful_signature: CppSignature | None
symint_signature: CppSignature | None
symint_faithful_signature: CppSignature | None
def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature:
return self.faithful_signature
else:
return self.signature
def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
if symint:
if self.symint_signature:
yield self.symint_signature
if self.symint_faithful_signature:
yield self.symint_faithful_signature
@staticmethod
def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> CppSignatureGroup:
func = f.func
def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
return CppSignature(
func=func,
faithful=faithful,
symint=symint,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
faithful_signature: CppSignature | None = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature
signature, faithful_signature = make_sigs(symint=False)
symint_signature: CppSignature | None = None
symint_faithful_signature: CppSignature | None = None
if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True)
return CppSignatureGroup(
func=func,
signature=signature,
faithful_signature=faithful_signature,
symint_signature=symint_signature,
symint_faithful_signature=symint_faithful_signature,
)
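# Hedged sketch (illustration only): enumerating every C++ overload generated for
# a NativeFunction `f`: the user-facing signature, the faithful signature, and
# the SymInt variants when the schema contains SymInts. Nothing here constructs
# a NativeFunction; the caller is assumed to supply one.
def _example_enumerate_cpp_overloads(f: NativeFunction) -> list[str]:
    group = CppSignatureGroup.from_native_function(f, method=False)
    return [sig.decl() for sig in group.signatures()]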
@dataclass(frozen=True)
class DispatcherSignature:
# The schema this signature is derived from
func: FunctionSchema
# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""
symint: bool = True
def arguments(self) -> list[Binding]:
return dispatcher.arguments(self.func, symint=self.symint)
def name(self) -> str:
return self.prefix + dispatcher.name(self.func)
def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(
self, name: str | None = None, *, is_redispatching_fn: bool = False
) -> str:
args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
args = ["c10::DispatchKeySet dispatchKeySet"] + args
args_str = ", ".join(args)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def exprs(self) -> list[Expr]:
return [Expr(a.name, a.nctype) for a in self.arguments()]
def returns_type(self) -> CType:
return dispatcher.returns_type(self.func.returns, symint=self.symint)
def ptr_type(self) -> str:
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})"
# Return the C++ function type, e.g., something like int(bool)
def type(self) -> str:
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
@staticmethod
def from_schema(
func: FunctionSchema, *, prefix: str = "", symint: bool = True
) -> DispatcherSignature:
return DispatcherSignature(func, prefix, symint)
@dataclass(frozen=True)
class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema
symint: bool
prefix: str = ""
def name(self) -> str:
return self.prefix + native.name(self.func)
def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def ptr_type(self) -> str:
# don't include defaults in type signature!
args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
def arguments(self) -> list[Binding]:
return native.arguments(self.func, symint=self.symint)
def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint)
def dispatcher_exprs(self) -> list[Expr]:
return translate.translate(
self.arguments(), dispatcher.arguments(self.func), method=False
)
@dataclass(frozen=True)
class ViewInverseSignature:
g: NativeFunctionsViewGroup
def name(self) -> str:
return functionalization.reverse_name(self.g.view, include_namespace=False)
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
decls = [
a.decl()
for a in functionalization.inner_arguments(
self.g.view.func, is_reverse=True
)
]
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
@dataclass(frozen=True)
class FunctionalizationLambda:
g: NativeFunctionsViewGroup
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
# allow_expensive_conversions is set because we want to convert
# some reference types (IntArrayRef) to value types (vector<int64_t>).
capture_exprs = translate.translate(
outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
)
return capture_exprs
def decl(self) -> str:
return_type = functionalization.returns_type(self.g.view.func)
capture_str = ", ".join(
f"{val.type.name} = {val.expr}" for val in self.captures()
)
decls = [
a.decl()
for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
include_namespace=True,
reapply_views=reapply_views,
)
arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
capture_ctx = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
)
full_ctx = arg_ctx + capture_ctx
assert self.g.view_copy is not None
call_bindings = functionalization.inner_arguments(
self.g.view_copy.func, is_reverse=self.is_reverse
)
maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
call_exprs = [
e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
]
if not self.is_reverse and maybe_index is not None:
return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
else:
return f'{inner_call_name}({", ".join(call_exprs)});'
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse)
@dataclass(frozen=True)
class StructuredImplSignature:
g: NativeFunctionsGroup
name: str
def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
return f"TORCH_IMPL_FUNC({self.name})({args_str})"
def arguments(self) -> list[Binding]:
return structured.impl_arguments(self.g)
# Helper functions
def kernel_signature(
f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> NativeSignature | DispatcherSignature:
# Note [External Backends Follow Dispatcher API]
# Kernel signatures for in-tree backends follow the "native" API,
# while kernels for out-of-tree backends follow the dispatcher API.
# See the comments in `native.py` for details, but historically there have been
# some small differences in schema convention between them and the Dispatcher API.
# Any differences that require translating between the two will result in a runtime cost,
# so we'd like to keep the differences as small as possible.
# With external backends, we'd like to enforce that they write their kernels with schemas
# that match the Dispatcher API directly, if they can.
meta = backend_index.get_kernel(f)
symint = meta is not None and meta.supports_symint()
if symint:
assert (
f.func.has_symint()
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
if backend_index.external:
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
else:
return NativeSignature(f.func, prefix=prefix, symint=symint)
# Functions only, no types
from torchgen.api import (
cpp,
dispatcher,
functionalization,
native,
structured,
translate,
)
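# Hedged sketch (illustration only): building a DispatcherSignature directly from
# a parsed schema and rendering its declaration. The schema string is an
# assumption; FunctionSchema is imported locally because the module-level import
# above is guarded by TYPE_CHECKING.
def _example_dispatcher_signature_decl() -> str:
    from torchgen.model import FunctionSchema

    schema = FunctionSchema.parse("sin(Tensor self) -> Tensor")
    sig = DispatcherSignature.from_schema(schema)
    # Renders something like "at::Tensor sin(const at::Tensor & self)".
    return sig.decl()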

View File

@ -0,0 +1,191 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
This file defines data model classes for the torchgen typing system, as well as some base types such as int32_t.
`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from dataclasses import dataclass
from torchgen.api.types.types_base import (
BaseCppType,
BaseCType,
boolT,
byteT,
charT,
CType,
doubleT,
floatT,
int32T,
longT,
shortT,
)
from torchgen.model import BaseTy, ScalarType
TENSOR_LIST_LIKE_CTYPES = [
"at::TensorList",
"const c10::List<::std::optional<at::Tensor>> &",
"const at::ITensorListRef &",
]
halfT = BaseCppType("at", "Half")
complexHalfT = BaseCppType(
"c10", "complex<c10::Half>"
) # stuffing template param here is an abuse
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
bfloat16T = BaseCppType("at", "BFloat16")
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
tensorT = BaseCppType("at", "Tensor")
optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
tensorListT = BaseCppType("at", "TensorList")
iTensorListRefT = BaseCppType("at", "ITensorListRef")
iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
dimnameT = BaseCppType("at", "Dimname")
dimnameListT = BaseCppType("at", "DimnameList")
dimVectorT = BaseCppType("at", "DimVector")
layoutT = BaseCppType("at", "Layout")
deviceT = BaseCppType("at", "Device")
deviceIndexT = BaseCppType("at", "DeviceIndex")
scalarT = BaseCppType("at", "Scalar")
optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
memoryFormatT = BaseCppType("at", "MemoryFormat")
qschemeT = BaseCppType("at", "QScheme")
storageT = BaseCppType("at", "Storage")
streamT = BaseCppType("at", "Stream")
intArrayRefT = BaseCppType("at", "IntArrayRef")
optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
tensorOptionsT = BaseCppType("at", "TensorOptions")
typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
tensorGeometryT = BaseCppType("at", "TensorGeometry")
SymIntT = BaseCppType("c10", "SymInt")
symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
# Types representing template parameters. Technically, we probably shouldn't
# represent them this way in codegen, but it was pretty convenient.
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT,
ScalarType.Char: charT,
ScalarType.Short: shortT,
ScalarType.Int: int32T,
ScalarType.Long: longT,
ScalarType.Half: halfT,
ScalarType.Float: floatT,
ScalarType.Double: doubleT,
ScalarType.ComplexHalf: complexHalfT,
ScalarType.ComplexFloat: complexFloatT,
ScalarType.ComplexDouble: complexDoubleT,
ScalarType.Bool: boolT,
ScalarType.Float8_e5m2: float8_e5m2T,
ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
ScalarType.Float8_e4m3fn: float8_e4m3fnT,
ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
}
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.Generator: generatorT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Dimname: dimnameT,
BaseTy.DimVector: dimVectorT,
BaseTy.Layout: layoutT,
BaseTy.Device: deviceT,
BaseTy.DeviceIndex: deviceIndexT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
BaseTy.QScheme: qschemeT,
BaseTy.Storage: storageT,
BaseTy.Stream: streamT,
BaseTy.SymInt: SymIntT,
}
# CTypes encode C++ type structure as needed for translation.
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ListCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"c10::List<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"at::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class VectorizedCType(CType):
# This template is explicitly specialized, so the only valid
# elems are those we have specializations for (e.g., float, double, ...)
# scalar_t is also a common argument here (when we are codegen in
# a templated context)
elem: BaseCType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> CType:
return self
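# Illustrative sketch (not part of the original file): a small, never-called
# self-check showing how the CTypes above compose into C++ type strings and how
# the mappings above are consulted. Uses only names defined/imported in this module.
def _example_ctype_composition() -> None:
    # a `Tensor?[]` schema argument is modeled as a list of optional tensors
    opt_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
    assert opt_tensor_list.cpp_type() == "c10::List<::std::optional<at::Tensor>>"
    # remove_const_ref recurses through the element types and is a no-op here
    assert opt_tensor_list.remove_const_ref() == opt_tensor_list
    # schema-level `int` is 64-bit in C++, and dtype Half renders as at::Half
    assert BaseTypeToCppMapping[BaseTy.int] is longT
    assert str(ScalarTypeToCppMapping[ScalarType.Half]) == "at::Half"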

View File

@ -0,0 +1,276 @@
"""
Where should I add a new type? `types_base.py` vs `types.py`
`types_base.py` defines the data model classes for the torchgen typing system, as well as some base types such as int32_t.
`types.py` defines the ATen Tensor type and some c10 types, along with signatures that use these types.
The difference between the two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto, Enum
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
# An ArgName is just the str name of the argument in schema;
# but in some special circumstances, we may add a little extra
# context. The Enum SpecialArgName covers all of these cases;
# grep for their construction sites to see when they can occur.
class SpecialArgName(Enum):
possibly_redundant_memory_format = auto()
ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: str | None
name: str
def __str__(self) -> str:
if self.ns is None or self.ns == "":
return self.name
return f"{self.ns}::{self.name}"
# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
# Templated types get their own dataclass, mainly to make namespace parsing easier.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# It would be more symmetric for this to be called intT, but it is easy to mix
# this up with JIT int (which is int64_t in C++), so we intentionally don't
# define intT to make it obvious when you've stuffed it up
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
doubleT = BaseCppType("", "double")
floatT = BaseCppType("", "float")
boolT = BaseCppType("", "bool")
voidT = BaseCppType("", "void")
class CType(ABC):
@abstractmethod
def cpp_type(self, *, strip_ref: bool = False) -> str:
raise NotImplementedError
@abstractmethod
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
@abstractmethod
def remove_const_ref(self) -> CType:
return self
@dataclass(frozen=True)
class BaseCType(CType):
type: BaseCppType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return str(self.type)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")
def remove_const_ref(self) -> CType:
return self
@dataclass(frozen=True)
class ConstRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"const {self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class VectorCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::vector<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType(CType):
elem: CType
size: int
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"::std::array<{self.elem.cpp_type()},{self.size}>"
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> CType:
return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True)
class TupleCType(CType):
elems: list[CType]
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> CType:
return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True)
class MutRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
return self.elem.cpp_type(strip_ref=strip_ref)
return f"{self.elem.cpp_type()} &"
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref()
# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
# semantic information about what it represents. For example, consider the
# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
# semantic type also keeps track that this represents a "pin_memory"; you can't
# just use a random other boolean in a context where you need a "pin_memory"!
#
@dataclass(frozen=True)
class NamedCType:
name: ArgName
type: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
return self.type.cpp_type(strip_ref=strip_ref)
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> NamedCType:
return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> NamedCType:
return NamedCType(name, self.type)
# A binding represents any C++ binding site for a formal parameter.
# We don't distinguish between binding sites for different APIs;
# instead, all of the important distinctions are encoded in CType,
# which you can use to figure out if a given Binding is appropriate
# for use in another context. (See torchgen.api.translate)
@dataclass(frozen=True)
class Binding:
name: str
nctype: NamedCType
argument: Argument | TensorOptionsArguments | SelfArgument
# TODO: maybe don't represent default here
default: str | None = None
def rename(self, name: str) -> Binding:
return Binding(
name=name,
nctype=self.nctype,
argument=self.argument,
default=self.default,
)
@property
def type(self) -> str:
return self.nctype.cpp_type()
def no_default(self) -> Binding:
return Binding(
name=self.name,
nctype=self.nctype,
default=None,
argument=self.argument,
)
def decl(self, *, func_ptr_cast: bool = False) -> str:
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
# casting only needs to know the type
if func_ptr_cast:
return f"{self.type}"
else:
return f"{self.type} {self.name}{mb_default}"
# For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
# TODO: Kill this when we eventually remove it!
def decl_registration_declarations(self) -> str:
type_s = self.nctype.cpp_type_registration_declarations()
mb_default = ""
if self.default is not None:
mb_default = f"={self.default}"
return f"{type_s} {self.name}{mb_default}"
def defn(self) -> str:
return f"{self.type} {self.name}"
def with_name(self, name: str) -> Binding:
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)
# An Expr is a C++ expression. It has a C++ string representing its syntax,
# as well as a CType saying what it provides.
@dataclass(frozen=True)
class Expr:
expr: str
type: NamedCType
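# Illustrative sketch (not part of the original file): a small, never-called
# self-check showing how the pieces above fit together for an argument like
# `bool pin_memory` (const-ref is used here purely to exercise the wrappers).
def _example_named_ctype() -> None:
    # the semantic name travels with the C++ type
    pin_memory = NamedCType("pin_memory", ConstRefCType(BaseCType(boolT)))
    assert pin_memory.cpp_type() == "const bool &"
    # stripping const-ref recovers the bare value type and keeps the name
    stripped = pin_memory.remove_const_ref()
    assert stripped.cpp_type() == "bool" and stripped.name == "pin_memory"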

View File

@ -0,0 +1,209 @@
from __future__ import annotations
from dataclasses import dataclass
import torchgen.api.types as api_types
from torchgen.api import cpp, structured
from torchgen.api.types import (
ArgName,
BaseCppType,
BaseCType,
Binding,
ConstRefCType,
CType,
NamedCType,
scalarT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
FunctionSchema,
NativeFunctionsGroup,
Type,
)
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
return f"ufunc_{func.name.name}_{dispatch_key}"
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
return schema_kernel_name(g.out.func, dispatch_key)
# Tensors are omitted (as they are stored in the TensorIterator); everything else is
# passed along (technically, we could pass tensors along too, it would just waste
# argument registers)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
elif t == BaseType(BaseTy.Tensor):
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
if scalar_t == api_types.scalar_t:
return api_types.opmath_t
raise NotImplementedError
# NB: Tensors in constructor are stored in opmath_t, not scalar_t
# because a Tensor argument in the constructor means it's a scalar tensor that has been partially applied;
# it can therefore be higher precision, and we want to compute in that higher precision
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
elif t == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
if t == BaseType(BaseTy.Tensor):
return NamedCType(binds, BaseCType(scalar_t))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# The actual ufunc template function the user writes. Everything here
# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
# in CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
return r
if t == BaseType(BaseTy.Scalar):
return NamedCType(binds, compute_t)
elif t == BaseType(BaseTy.Tensor):
return NamedCType(binds, compute_t)
else:
raise AssertionError(f"unrecognized type {repr(t)}")
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
return Binding(
nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
name=a.name,
default=None,
argument=a,
)
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
return Binding(
nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
name=a.name,
default=None,
argument=a,
)
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
return Binding(
nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
name=a.name,
default=None,
argument=a,
)
@dataclass(frozen=True)
class UfunctorBindings:
ctor: list[Binding]
apply: list[Binding]
# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments on a host-side constructor, and the rest in the device-side
# apply. E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
ctor = []
apply = []
for a in g.functional.func.arguments.flat_non_out:
if a.type.is_tensor_like():
if scalar_tensor_idx == 0:
# put it in the ctor anyway
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
scalar_tensor_idx = None
else:
if scalar_tensor_idx is not None:
scalar_tensor_idx -= 1
apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
else:
ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
assert scalar_tensor_idx is None
return UfunctorBindings(ctor=ctor, apply=apply)
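# Illustrative sketch (added, not in the original file): for a group like
#   add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)
# called with scalar_tensor_idx=1 (i.e. `other` is the scalar tensor partially
# applied on the host), the split computed above is
#   ctor  -> [other (as opmath_t), alpha (as opmath_t)]
#   apply -> [self (as scalar_t)]
# which corresponds to the CUDAFunctorOnSelf_add example in the comment above.
# With scalar_tensor_idx=None, both tensors land in `apply` and only `alpha`
# goes to the ctor.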
# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
# which do the actual computation in question. E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
# return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
return [
ufunc_argument(a, compute_t=compute_t)
for a in g.functional.func.arguments.flat_non_out
]
# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions. E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
# stubs drop all tensor arguments (they are implicit in the TensorIterator
# argument and keep everything else)
return [
r
for a in g.out.func.arguments.flat_non_out
if not a.type.is_tensor_like()
for r in structured.argument(a)
]
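# Illustrative sketch (added, not in the original file): for the out schema
#   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out)
# the tensor arguments are dropped (they are implicit in the TensorIterator)
# and only `alpha` survives, matching the `const Scalar& alpha` parameter of
# the structured_binary_fn_alpha example above.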

View File

@ -0,0 +1,249 @@
from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignatureGroup, CType
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
# This file generates the code for unboxing wrappers, i.e., the glue logic that unboxes a boxed operator and converts
# the ivalues from the stack into the correct arguments for the unboxed kernel, based on the corresponding JIT schema.
# This codegen is an alternative to the existing C++ metaprogramming approach for generating unboxing wrappers, but it
# gets the job done statically. The generated unboxing wrappers are useful when we need to register a fixed set of
# operators known at compile time, and can therefore save some time during the runtime initialization phase.
#
# Here's an example on how the codegen works:
#
# - Function Schema (source of truth)
#
# aten::empty.names(int[] size, *, Dimname[]? names,
# ScalarType? dtype=None, Layout? layout=None,
# Device? device=None, bool? pin_memory=None,
# MemoryFormat? memory_format=None) -> Tensor
# - Argument Conversion
# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
# - int[] size
# ```cpp
# const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
#
# std::vector<int64_t> size_vec;
# for (c10::IValue size_elem: size_list_in) {
# int64_t size_base = size_elem.to<int64_t>();
# size_vec.push_back(size_base);
# }
# at::ArrayRef<int64_t> size_list_out(size_vec);
# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
# Will be passed to unboxed kernel.
# ```
# - Dimname[]? names
# ```cpp
# ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
# if (names_opt.has_value()) {
# ~~~~~~~~~~~ <-- Unwrapping optional shell
# const c10::IValue names_opt_in = names_opt.value();
# const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
#
# std::vector<at::Dimname> names_vec;
# for (c10::IValue names_elem: names_list_in) {
# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
# at::Dimname names_base = names_elem.to<at::Dimname>();
# names_vec.push_back(names_base);
# }
# at::ArrayRef<at::Dimname> names_list_out(names_vec);
#
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
# } else {
# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
# }
# ```
# - ScalarType? dtype (similarly for the rest of the arguments)
# ```cpp
# ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
# ::std::optional<at::ScalarType> dtype_opt_out;
# if (dtype_opt.has_value()) {
# const c10::IValue dtype_opt_in = dtype_opt.value();
# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
# directly using ".to<T>()" API.
# dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
# } else {
# dtype_opt_out = ::std::optional<at::ScalarType>();
# }
# ```
#
# - Unboxed Kernel Call
# ```cpp
# auto result_ = torch::empty(
# size_list_out,
# names_opt_out,
# options,
# memory_format_opt_out
# );
# ```
#
# - Push Result Back to Stack
# ```cpp
# drop(stack, 7);
# pack(stack, std::move(result_));
# ```
connector = "\n\t"
# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
return f.func.name.unambiguous_name()
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
# we need the 'self' argument so method needs to be False
args = (
CppSignatureGroup.from_native_function(f, method=False)
.most_faithful_signature()
.arguments()
)
code_list = [
f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
for i in range(len(args))
] + [""]
binding_list = []
for arg in args:
# expecting only Argument
if not isinstance(arg.argument, Argument):
raise Exception( # noqa: TRY002
f"Unexpected argument type, expecting `Argument` but got {arg}"
)
argument: Argument = arg.argument
unboxed_name, _, code, decl = argumenttype_ivalue_convert(
argument.type,
argument.name,
mutable=argument.is_write,
)
code_list.extend(decl)
code_list.extend(code)
binding_list.append(arg.with_name(unboxed_name))
return binding_list, code_list
# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
# (1) the name of the newly created unboxed variable
# (2) its CType
# (3) the C++ code necessary to unbox the argument
# (4) any declarations that must be hoisted to the enclosing scope (e.g. backing vectors for ArrayRefs)
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
).type
if isinstance(t, BaseType):
out_name = f"{arg_name}_base"
code, decl = _gen_code_base_type(
arg_name=arg_name, out_name=out_name, ctype=ctype
)
elif isinstance(t, OptionalType):
out_name = f"{arg_name}_opt_out"
code, decl = _gen_code_optional_type(
arg_name=arg_name,
out_name=out_name,
t=t,
ctype=ctype,
)
elif isinstance(t, ListType):
out_name = f"{arg_name}_list_out"
code, decl = _gen_code_list_type(
arg_name=arg_name,
out_name=out_name,
t=t,
ctype=ctype,
)
else:
raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") # noqa: TRY002
return out_name, ctype, code, decl
def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
def _gen_code_optional_type(
arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in"
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return (
f"""
auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
{ctype.cpp_type(strip_ref=True)} {out_name};
if ({arg_name}_opt.has_value()) {{
const c10::IValue {in_name} = {arg_name}_opt.value();
{connector.join(res_code)}
{out_name} = {ctype.cpp_type(strip_ref=True)}({res_name});
}} else {{
{out_name} = {ctype.cpp_type(strip_ref=True)}();
}}
""".split(
"\n"
),
decl,
)
def _gen_code_list_type(
arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
# handle list type with size, e.g., bool[4]
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
code.extend(
f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
""".split(
"\n"
)
)
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif isinstance(t.elem, OptionalType):
code.extend(
f"""
{ctype.cpp_type(strip_ref=True)} {out_name};
for (c10::IValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{out_name}.push_back({res_name});
}}
""".split(
"\n"
)
)
else:
# use ArrayRef as default.
vec_name = arg_name + "_vec"
# the backing vector must be declared in the enclosing scope (via `decl`) so that the ArrayRef still points at valid data
decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
code.extend(
f"""
for (c10::IValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
)
return code, decl

View File

@ -0,0 +1,99 @@
from __future__ import annotations
import re
from typing import Mapping, Sequence
# Match $identifier or ${identifier} and replace it with the corresponding value in env.
# If the identifier appears at the start of a line (preceded only by whitespace)
# and its value is a list, it is treated as a block substitution: each element
# of the list is placed on its own line, indented to the identifier's depth.
# If the identifier appears after non-whitespace on a line and its value is a
# list, the elements are joined with commas; ${,foo} inserts a comma before the
# list (if the list is non-empty) and ${foo,} inserts one after.
class CodeTemplate:
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
substitution = re.compile(substitution_str, re.MULTILINE)
pattern: str
filename: str
@staticmethod
def from_file(filename: str) -> CodeTemplate:
with open(filename) as f:
return CodeTemplate(f.read(), filename)
def __init__(self, pattern: str, filename: str = "") -> None:
self.pattern = pattern
self.filename = filename
def substitute(
self, env: Mapping[str, object] | None = None, **kwargs: object
) -> str:
if env is None:
env = {}
def lookup(v: str) -> object:
assert env is not None
return kwargs[v] if v in kwargs else env[v]
def indent_lines(indent: str, v: Sequence[object]) -> str:
return "".join(
[indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip()
def replace(match: re.Match[str]) -> str:
indent = match.group(1)
key = match.group(2)
comma_before = ""
comma_after = ""
if key[0] == "{":
key = key[1:-1]
if key[0] == ",":
comma_before = ", "
key = key[1:]
if key[-1] == ",":
comma_after = ", "
key = key[:-1]
v = lookup(key)
if indent is not None:
if not isinstance(v, list):
v = [v]
return indent_lines(indent, v)
elif isinstance(v, list):
middle = ", ".join([str(x) for x in v])
if len(v) == 0:
return middle
return comma_before + middle + comma_after
else:
return str(v)
return self.substitution.sub(replace, self.pattern)
if __name__ == "__main__":
c = CodeTemplate(
"""\
int foo($args) {
$bar
$bar
$a+$b
}
int commatest(int a${,stuff})
int notest(int a${,empty,})
"""
)
print(
c.substitute(
args=["hi", 8],
bar=["what", 7],
a=3,
b=4,
stuff=["things...", "others"],
empty=[],
)
)

View File

@ -0,0 +1,130 @@
from __future__ import annotations
import contextlib
import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
import torchgen.local as local
from torchgen.model import (
BackendIndex,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
from torchgen.utils import context, S, T
# Helper functions for defining generators on things in the model
F = TypeVar(
"F",
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
Union[NativeFunction, NativeFunctionsGroup],
Union[NativeFunction, NativeFunctionsViewGroup],
)
F2 = TypeVar(
"F2",
NativeFunction,
NativeFunctionsGroup,
Optional[NativeFunction],
bool,
str,
)
F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
@contextlib.contextmanager
def native_function_manager(
g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
if isinstance(g, NativeFunctionsGroup):
# By default, we associate all errors for structured native functions
# with the out variant. In some cases, it might be better to have
# a more specific place to hang things; if so, use
# native_function_manager again on the inside
f = g.out
elif isinstance(g, NativeFunctionsViewGroup):
# We associate errors with the view operator
f = g.view
else:
f = g
with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
with local.parametrize(
use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
use_ilistref_for_tensor_lists=f.part_of_structured_group,
):
yield
# Given a function that operates on NativeFunction, wrap it into a new function
# that sets some appropriate context managers for that native function.
# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
# (you will get an error if we try to access the local variables without having
# set them).
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
@functools.wraps(func)
def wrapper(f: F) -> T:
with native_function_manager(f):
return func(f)
return wrapper
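# Illustrative sketch (added, not in the original file): a typical helper
# wrapped with the decorator above. While the body runs, torchgen.local flags
# are parametrized and any error is attributed to the right
# native_functions.yaml entry. The helper itself is hypothetical.
@with_native_function
def _example_compute_decl(f: NativeFunction) -> str:
    # unambiguous_name() is the overload-qualified name, e.g. "add_Tensor"
    return f"void {f.func.name.unambiguous_name()}();"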
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
@functools.wraps(func)
def wrapper(f: F, f2: F2) -> T:
# The first native_function is assumed to be the one with the appropriate context.
with native_function_manager(f):
return func(f, f2)
return wrapper
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
@functools.wraps(func)
def wrapper(slf: S, f: F) -> T:
with native_function_manager(f):
return func(slf, f)
return wrapper
def method_with_nested_native_function(
func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
@functools.wraps(func)
def wrapper(slf: S, f: F3) -> T:
with native_function_manager(f[0]):
return func(slf, f)
return wrapper
# Convenience decorator for functions that explicitly take in a BackendIndex,
# instead of indirectly taking one in as a closure
def with_native_function_and_index(
func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
@functools.wraps(func)
def wrapper(f: F, backend_index: BackendIndex) -> T:
with native_function_manager(f):
return func(f, backend_index)
return wrapper
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices(
func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
@functools.wraps(func)
def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
with native_function_manager(f):
return func(f, backend_indices)
return wrapper

View File

@ -0,0 +1,19 @@
from torchgen.dest.lazy_ir import (
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
GenLazyIR as GenLazyIR,
GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
)
from torchgen.dest.native_functions import (
compute_native_function_declaration as compute_native_function_declaration,
)
from torchgen.dest.register_dispatch_key import (
gen_registration_headers as gen_registration_headers,
gen_registration_helpers as gen_registration_helpers,
RegisterDispatchKey as RegisterDispatchKey,
)
from torchgen.dest.ufunc import (
compute_ufunc_cpu as compute_ufunc_cpu,
compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
compute_ufunc_cuda as compute_ufunc_cuda,
)

View File

@ -0,0 +1,707 @@
from __future__ import annotations
import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
getValueT,
isValueType,
LazyArgument,
LazyIrProperties,
LazyIrSchema,
tensorListValueT,
)
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
deviceT,
DispatcherSignature,
kernel_signature,
NativeSignature,
OptionalCType,
VectorCType,
)
from torchgen.context import method_with_native_function
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
Argument,
BackendIndex,
BackendMetadata,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
)
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
"""
Given a LazyArgument,
generate a C++ string for materializing an rvalue of that arg for passing into
a lazy Node constructor.
"""
# TODO: Matching on CType seems wrong; should be matching on Type
if isValueType(arg.lazy_type):
if isinstance(arg.lazy_type, BaseCType):
if arg.is_wrapped_scalar:
return f"node_{arg.name}"
elif arg.lazy_type.type is tensorListValueT:
return f"lazy_{arg.name}_tensorlist"
elif arg.is_symint_or_list:
return f"GetSymIntValue({arg.name})"
return f"lazy_{arg.name}->GetIrValue()"
elif isinstance(arg.lazy_type, OptionalCType):
if arg.is_symint_or_list:
# TODO: I don't understand when you should put lazy_ in the name
# or not
return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
elif arg.is_wrapped_scalar:
return f"node_{arg.name}"
return (
f"lazy_{arg.name} ? "
f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
"::std::nullopt"
)
else:
raise AssertionError(
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
else:
# NB: this is here because right now we aren't treating SymInt[] as a
# value type; when we do this needs to move above
# NB: we cannot test arg.lazy_type as we've already specified it is an
# int64_t and so we cannot distinguish between SymInt and int64_t
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
BaseTy.SymInt
):
if arg.symint:
return f"GetSymIntArrayRefValue({arg.name})"
else:
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
arg.lazy_type.elem, BaseCType
):
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
elif (
isinstance(arg.lazy_type, OptionalCType)
and isinstance(arg.lazy_type.elem, VectorCType)
and isinstance(arg.lazy_type.elem.elem, BaseCType)
):
return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
else:
return f"{arg.name}"
def node_ctor_inputs(schema: LazyIrSchema) -> str:
"""
Produce a formatted string with the arguments as passed into the constructor of a node class.
"""
node_ctor_values = [
node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
]
return ", ".join(node_ctor_values)
def gen_fallback_code(
schema: LazyIrSchema,
sig: DispatcherSignature | NativeSignature,
overload_name: str,
) -> str:
"""
Generate code that falls back to eager conditioned on a predicate
"""
dispatcher_sig = DispatcherSignature.from_schema(schema.func)
exprs = translate(sig.arguments(), dispatcher_sig.arguments())
fallback_args = ",\n ".join([a.expr for a in exprs])
if len(overload_name):
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
else:
aten_op_str = f"ATEN_OP({schema.aten_name})"
return f"""
if (force_eager_fallback({aten_symbol(schema)})) {{
return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
{fallback_args}
);
}}
"""
def aten_symbol(schema: LazyIrSchema) -> str:
missing_interned_strings = {
"sigmoid_backward",
}
if schema.aten_name in missing_interned_strings:
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
if not schema.aten_name.startswith("at::"):
return f"at::aten::{schema.aten_name}"
else:
return schema.aten_name
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta"
unwrapped_tensor_args.append(
f"auto {unwrapped_name} = to_meta({arg.name});"
)
context.append(arg.with_name(unwrapped_name))
else:
context.append(arg)
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
return unwrap_tensor_args_str, context
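# Illustrative sketch (added, not in the original file): for a dispatcher
# signature with arguments (const at::Tensor & self, int64_t dim), the pieces
# returned above are roughly
#   "auto self_meta = to_meta(self);"                       # conversion string
#   [<self binding renamed to self_meta>, <dim binding>]    # context bindings
# and translate() later uses that context to call the at::meta:: kernel on
# meta tensors.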
@dataclass(frozen=True)
class GenLazyIR(ABC):
backend_index: BackendIndex
backend_name: str
node_base: str
use_lazy_shape: bool
@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
metadata = self.backend_index.get_kernel(
f.functional if isinstance(f, NativeFunctionsGroup) else f
)
schema = LazyIrSchema(
func, symint=metadata is not None and metadata.supports_symint()
)
return self.gen(schema)
# there is no lowering functionality generated unless this IR base class is subclassed and
# implemented as a backend-specific node
def lowering_function(self, schema: LazyIrSchema) -> str:
return ""
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return ""
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return f"""bool CanBeReused({node_ctor_args}) const {{
return false;
}}"""
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# backends can customize the way the node base class constructor is called,
# as long as all of its arguments can be generated from information available from the schema
base_ctor_value_args_list = []
for arg in value_args:
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
base_ctor_value_args_list.append(f"{arg.name}")
elif isinstance(arg.lazy_type, OptionalCType):
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
else:
raise AssertionError(
f"Unsupported type ({arg.lazy_type}) - add support if necessary"
)
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
scalar_args = schema.filtered_args(values=False, scalars=True)
# Shape construction.
# Conditionally build shape depending on specified shape property
if schema.properties.ShapePrecompute:
shape_ctor_arg = "std::move(shapes),"
elif schema.properties.ShapeCompute:
shape_args = [a.name for a in value_args]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
elif schema.properties.ShapeCache:
shape_args = [f"operand({i})" for i in range(len(value_args))]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
else:
shape_ctor_arg = ""
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
return f"""{self.node_base}(
{schema.node_name}::ClassOpKind(),
OpList{{{base_ctor_value_args}}},
{shape_ctor_arg}
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""
def gen(self, schema: LazyIrSchema) -> list[str]:
opkind = schema.opkind or aten_symbol(schema)
# For now we just want one IR class decl (and soon after, the method defs as well),
# and we use the functional version, not out/inplace.
all_args = schema.filtered_args()
scalar_args = schema.filtered_args(values=False, scalars=True)
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
reuse_ctor_args = ", ".join(ctor_args)
if self.use_lazy_shape and schema.properties.ShapePrecompute:
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
node_ctor_args = ", ".join(ctor_args)
scalar_initializers = ",\n ".join(
[
# This code is just special casing the mapping from string_view -> strings
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.name}({a.name})"
for a in scalar_args
]
)
if len(scalar_initializers):
scalar_initializers = f",\n {scalar_initializers}"
scalar_decls = "\n ".join(
[
f"std::string {a.name};"
if a.lazy_type.cpp_type() == "c10::string_view"
else f"::std::optional<std::string> {a.name};"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.lazy_type.cpp_type()} {a.name};"
for a in scalar_args
]
)
optional_values = [
arg.name
for arg in schema.filtered_args(values=True, scalars=False)
if isinstance(arg.lazy_type, OptionalCType)
]
has_optional_decls = "\n ".join(
[f"bool has_{value}: 1;" for value in optional_values]
)
has_optional_defs = "\n ".join(
[f"has_{value} = !!{value};" for value in optional_values]
)
members_to_string = []
for arg in scalar_args:
if isinstance(arg.lazy_type, OptionalCType):
value = f"{arg.name}.value()"
if arg.is_generator:
value = '"torch.Generator()"'
members_to_string.append(
f"""if ({arg.name}.has_value()) {{
ss << ", {arg.name}=" << {value};
}} else {{
ss << ", {arg.name}=null";
}}"""
)
else:
members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
members_to_string_str = "\n ".join(members_to_string)
return [
f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static torch::lazy::OpKind ClassOpKind() {{
return torch::lazy::OpKind({opkind});
}}
{schema.node_name}({node_ctor_args})
: {self.node_base_ctor_call(schema)}{scalar_initializers}
{{
{has_optional_defs}
}}
std::string ToString() const override {{
std::stringstream ss;
ss << {self.node_base}::ToString();
{members_to_string_str}
return ss.str();
}}
{self.create_function(schema, reuse_ctor_args)}
{self.can_be_reused_function(schema, reuse_ctor_args)}
{self.lowering_function(schema)}
{scalar_decls}
{has_optional_decls}
}};
""",
]
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
def lowering_function(self, schema: LazyIrSchema) -> str:
signature = """
torch::lazy::TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const override"""
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif schema.properties.Lower:
return f"""{signature} {{
{ts_lowering_body(schema)}
}}
"""
else:
return ""
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"static NodePtr Create({node_ctor_args})"
if schema.properties.CreateFnDeclOnly:
return f"{signature};"
elif not schema.properties.CreateFn:
return ""
return f"""{signature} {{
return ReuseOrMakeNode<{schema.node_name}>(data);
}}"""
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"bool CanBeReused({node_ctor_args}) const"
if schema.properties.CanBeReusedDeclOnly:
return f"{signature};"
elif not schema.properties.CanBeReused:
return ""
value_comparison = []
for arg in itertools.chain(schema.positional_values, schema.keyword_values):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
)
else:
value_comparison.append(f"operand(i++) == {arg.name}")
for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
)
else:
value_comparison.append(f"this->{arg.name} == {arg.name}")
value_comparison_str = " &&\n ".join(value_comparison)
return f"""{signature} {{
size_t i = 0;
return ({value_comparison_str});
}}"""
@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
class_method_name: str
backend_index: BackendIndex
tensor_class: str
gen_forced_fallback_code: bool
backend_namespace: str
get_tensorlist: str
get_tensor_or_wrap_number: str
try_get_tensor: str
metrics_counter: str
create_tensor: str
create_from_first_tensor: bool
create_aten_from_ltc_tensor: str
tuple_aten_from_ltc_tensors: str
lazy_tensor_ptr: str
get_device_fn: str
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# Generates lazy_{name} variables for LazyTensors wrapping input tensors
lazy_tensor_decls: list[str] = []
for arg in value_args:
if arg.is_wrapped_scalar:
if isinstance(arg.lazy_type, OptionalCType):
lazy_tensor_decls.append(
f"""auto node_{arg.name} = {arg.name} ?
std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
::std::nullopt;"""
)
else:
lazy_tensor_decls.append(
f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
)
elif arg.is_symint_or_list:
continue # values are extracted in isValueType
elif isinstance(arg.lazy_type, BaseCType):
if arg.lazy_type.type is tensorListValueT:
lazy_tensor_decls.append(
f"auto lazy_{arg.name}_tensorlist = "
f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
)
else:
lazy_tensor_decls.append(
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
)
elif isinstance(arg.lazy_type, OptionalCType):
assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
# TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
# until we encounter a real world example.
lazy_tensor_decls.append(
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
)
else:
raise AssertionError(
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
return ("\n ").join(lazy_tensor_decls)
def force_eager_fallback(
self,
func: NativeFunction,
schema: LazyIrSchema,
metadata: BackendMetadata,
sig: DispatcherSignature | NativeSignature,
) -> str:
if self.gen_forced_fallback_code:
return gen_fallback_code(
schema, sig, overload_name=func.func.name.overload_name
)
return ""
def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
return f"{self.metrics_counter};"
def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
scalar_args = schema.filtered_args(values=False, scalars=True)
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
optional_device = OptionalCType(BaseCType(deviceT))
optional_devices = [
a.name for a in scalar_args if a.lazy_type == optional_device
]
assert (
len(value_types_names) > 0 or len(optional_devices) > 0
), "Expected at least one Value or Device type"
get_device_str = (
f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
)
return f"""auto common_device = {get_device_str};
TORCH_INTERNAL_ASSERT(common_device);
"""
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
all_args = schema.filtered_args()
returns_length = len(schema.returns)
# call the meta kernel if it exists, to compute output shape/dtype for our IR
# Note [Generated LTC Shape Functions]
# LTC uses meta tensors from core to do shape inference when possible, and otherwise
# we generate a shape function declaration that needs to be manually implemented.
# How do we detect which ops are eligible to use meta tensors?
# In general we should be able to use meta tensors not just on structured operators,
# but also on composite operators that are implemented in terms of structured kernels.
# We don't currently have a way of knowing at codegen time which ops are implemented that way.
# This is the case for all view and view_copy operators however, so we're going to
# use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
is_view_copy_op = "view_copy" in func.tags
is_structured = func.structured or func.structured_delegate is not None
if is_structured or is_view_copy_op:
meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
if returns_length > 1:
def this_shape(i: int) -> str:
return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
# Convert tensor args to the meta device and call it.
# (We can't pass in the input tensors directly, because they are "functional wrappers".
# If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
# Even at::meta:: functions might redispatch, e.g. if they call into view ops.
dispatcher_sig = DispatcherSignature.from_schema(func.func)
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
e.expr
for e in translate(
meta_call_ctx, dispatcher_sig.arguments(), method=False
)
]
if is_view_copy_op:
# view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
assert func.has_composite_explicit_autograd_non_functional_kernel
dispatch_ns = "compositeexplicitautogradnonfunctional"
else:
dispatch_ns = "meta"
aten_name = schema.aten_name
# TODO: this is trolling
if func.func.has_symint() and metadata.supports_symint():
aten_name += "_symint"
shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
{meta_out}"""
else:
shape_sig = ComputeShapeSignature(
metadata.kernel, func, symint=metadata.supports_symint()
)
shape_str = f"""
auto shapes = {shape_sig.shape_call};"""
shape_str += f"""
TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
# Calculating which dimensions are symbolic
func_schema_str = "aten::" + str(func.func)
shape_str += f"""
if(torch::lazy::symbolicShapeEnabled()){{
std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
const char* schema_str = "{func_schema_str}";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}}
"""
return shape_str
def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
node_ctor_input_str = node_ctor_inputs(schema)
return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
if (!node) {{
{self.shape_inference(func, schema)}
node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
CacheNode(node);
}}
"""
def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
# xla uses an instance method for tensor creation, for the time being
if self.create_from_first_tensor:
# TODO(whc) remove this if XLA switches to using static method for creation
assert (
first_tensor_name is not None
), "Requires first tensor to create lazy tensor"
return f"{first_tensor_name}.{self.create_tensor}"
return f"{self.backend_namespace}::{self.create_tensor}"
def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
returns_length = len(schema.returns)
value_args = schema.filtered_args(values=True, scalars=False)
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
if returns_length > 1:
assert (
len(value_types_names) > 0
), "Code below assumes there is at least one tensor arg"
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
for (int i = 0; i < {returns_length}; i++) {{
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
}}
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
if schema.name.name.inplace or func.func.is_out_fn():
assert returns_length == 1, (
"We assumed there was no such case where an op is an in-place variant "
f"and has tuple outputs, but got tuple of len {returns_length}."
)
bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
auto& result = {first_tensor_name};"""
bridge_str += """
return result;"""
return bridge_str
@method_with_native_function
def __call__(self, func: NativeFunction) -> list[str]:
sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
return [
f"""\
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
{self.force_eager_fallback(func, schema, metadata, sig)}
{self.metrics(func, schema)}
{self.get_device(func, schema)}
{self.lazy_tensor_decls(func, schema)}
{self.build_ir_node(func, schema)}
{self.return_aten_tensor(func, schema)}
}}\n
"""
]
class ComputeShapeSignature:
"""
Here we use the base name as the suffix of the signature to avoid generating a separate shape signature for in-place variants.
"""
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join(
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
)
self.__call_args = ", ".join(
[f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
)
self.__kernel_name = kernel_name
def __decl_suffix(self) -> str:
return f"{self.__kernel_name}({self.__dispatch_args})"
def __call_suffix(self) -> str:
return f"{self.__kernel_name}({self.__call_args})"
@property
def shape_decl(self) -> str:
return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
@property
def shape_call(self) -> str:
return f"torch::lazy::compute_shape_{self.__call_suffix()}"
@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
backend_index: BackendIndex
tensor_class: str
@method_with_native_function
def __call__(self, f: NativeFunction) -> list[str]:
metadata = self.backend_index.get_kernel(f)
assert metadata is not None
# See Note [Generated LTC Shape Functions]
is_view_copy_op = "view_copy" in f.tags
is_structured = f.structured or f.structured_delegate is not None
if is_structured or is_view_copy_op:
return []
else:
shape_sig = ComputeShapeSignature(
metadata.kernel, f, symint=metadata.supports_symint()
)
return ["\n".join([f"{shape_sig.shape_decl};"])]
def generate_non_native_lazy_ir_nodes(
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
"""Generate the non-native lazy IR node classes"""
nodes = []
for op in non_native:
# Set default properties for Non-Native IRs
properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
for p in op.get("properties", []):
setattr(properties, p, True)
# non-native is assumed to want symint bindings if you wrote symint
schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
schema.opkind = op.get("opkind")
nodes.append(gen_lazy_ir.gen(schema)[0])
return nodes
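# Illustrative sketch (added, not in the original file): each `non_native` entry
# is a dict parsed from the backend's yaml. A hypothetical entry looks roughly like
#   {"func": "scalar(Scalar value, ScalarType type) -> Tensor",
#    "opkind": "at::prim::Constant",
#    "properties": ["ShapeCompute"]}
# The listed properties are toggled on top of the default LazyIrProperties
# ("ShapeCache", "CanBeReused", "LowerDeclOnly") before the IR class is generated.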

View File

@ -0,0 +1,48 @@
from torchgen.api.lazy import LazyArgument, LazyIrSchema
from torchgen.api.types import OptionalCType
def ts_lowering_body(schema: LazyIrSchema) -> str:
# for now, we just want one IR class decl and soon after also the method defs
# and we use the functional version not out/inplace.
emplace_arguments = []
def get_value(arg: LazyArgument) -> str:
if isinstance(arg.lazy_type, OptionalCType):
return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
return "loctx->GetOutputOp(operand(i++))"
for arg in schema.positional_args:
if arg.is_lazy_value:
emplace_arguments.append(get_value(arg))
continue
emplace_arguments.append(f'"{arg.name}", {arg.name}')
emplace_arguments_str = "\n ".join(
[f"arguments.emplace_back({a});" for a in emplace_arguments]
)
emplace_kwarg_values = [
f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
]
emplace_kwarg_scalars = [
f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
]
emplace_kwarguments = "\n ".join(
[
f"kwarguments.emplace_back({a});"
for a in emplace_kwarg_values + emplace_kwarg_scalars
]
)
return f"""\
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
"""

View File

@ -0,0 +1,63 @@
from __future__ import annotations
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.types import kernel_signature
from torchgen.context import with_native_function_and_index
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.utils import mapMaybe
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
sig = kernel_signature(f, backend_index)
metadata = backend_index.get_kernel(f)
if metadata is None:
return None
if "legacy::" in metadata.kernel:
return None
else:
prefix = "static" if backend_index.external else "TORCH_API"
return f"{prefix} {sig.decl(name=metadata.kernel)};"
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
meta_name = meta.name(g)
out_args = structured.impl_arguments(g)
metadata = backend_index.get_kernel(g)
if metadata is None:
return []
prefix = "" if backend_index.external else "TORCH_API "
return [
f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({', '.join(a.decl() for a in out_args)});
}};
"""
]
# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
metadata = backend_index.get_kernel(g)
if isinstance(g, NativeFunctionsGroup):
if metadata is not None and metadata.structured:
if backend_index.external:
# Structured hasn't been tested with external backends yet.
raise AssertionError(
"Structured external backend functions are not implemented yet."
)
else:
return gen_structured(g, backend_index)
else:
return list(
mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
)
else:
x = gen_unstructured(g, backend_index)
return [] if x is None else [x]
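# Illustrative sketch (not part of torchgen): mapMaybe (imported from torchgen.utils
# above) maps a function over an iterable and drops None results, which is how
# gen_unstructured's "skip this kernel" cases fall out of the declaration list.
# Minimal stand-in definition and usage, under the assumption that the real helper
# behaves like Haskell's mapMaybe.
from typing import Callable, Iterable, Iterator, Optional, TypeVar

_T = TypeVar("_T")
_S = TypeVar("_S")


def _sketch_map_maybe(fn: Callable[[_T], Optional[_S]], xs: Iterable[_T]) -> Iterator[_S]:
    for x in xs:
        y = fn(x)
        if y is not None:
            yield y


if __name__ == "__main__":
    decls = list(
        _sketch_map_maybe(
            lambda k: None if k.startswith("legacy::") else f"TORCH_API void {k}();",
            ["add_out", "legacy::old_kernel", "mul_out"],
        )
    )
    print(decls)  # ['TORCH_API void add_out();', 'TORCH_API void mul_out();']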

File diff suppressed because it is too large

View File

@ -0,0 +1,551 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
Expr,
NamedCType,
opmath_t,
scalar_t,
StructuredImplSignature,
VectorizedCType,
)
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
NativeFunctionsGroup,
ScalarType,
UfuncKey,
)
from torchgen.utils import OrderedSet
if TYPE_CHECKING:
from torchgen.api.ufunc import UfunctorBindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
# TODO: use BackendIndex
# dispatch_key: DispatchKey # only CPU/CUDA right now
# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because they are only instantiated with a
# concrete dtype at their use sites. A functor looks something like this:
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
# : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
@dataclass(frozen=True)
class UfunctorSignature:
g: NativeFunctionsGroup
scalar_tensor_idx: int | None
name: str
def arguments(self) -> UfunctorBindings:
return ufunc.ufunctor_arguments(
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
)
def fields(self) -> list[Binding]:
# fields are renamed to have a trailing underscore, as is conventional
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
def returns_type(self) -> CType:
# TODO: don't hardcode; return type will be inferred based on tags on
# the native function
return BaseCType(scalar_t)
def decl_fields(self) -> str:
return "\n".join(f"{f.type} {f.name};" for f in self.fields())
def inline_defn_ctor(self) -> str:
args_str = ", ".join(a.decl() for a in self.arguments().ctor)
# NB: hypothetically could do this with translate but the
# transition here is very regular
init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
return f"{self.name}({args_str}) : {init_str} {{}}"
def decl_apply(self) -> str:
args_str = ", ".join(a.decl() for a in self.arguments().apply)
return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
@dataclass(frozen=True)
class UfuncSignature:
g: NativeFunctionsGroup
name: str
compute_t: CType
def arguments(self) -> list[Binding]:
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
def call(self, ctx: Sequence[Binding | Expr]) -> str:
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# steps:
# 1. take the functional signature
# 2. use api.ufunc to convert it to template signature. this establishes
# the type of the template function
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
# this establish context in which we call the template signature
#
# StructuredImplSignature context
# ~> functor constructor sig
#
# Functor constructor context
# ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
# ~> template sig
#
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
num_tensors = sum(
1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
)
return num_tensors == 2
def compute_ufunc_cuda_functors(
g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
# First, build the functors.
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: list[str] = []
loops = g.out.ufunc_inner_loop
scalar_tensor_idx_lookup = {
UfuncKey.CUDAFunctorOnSelf: 1,
UfuncKey.CUDAFunctorOnOther: 0,
UfuncKey.CUDAFunctor: None,
}
if eligible_for_binary_scalar_specialization(g):
keys = [
UfuncKey.CUDAFunctorOnSelf,
UfuncKey.CUDAFunctorOnOther,
UfuncKey.CUDAFunctor,
]
else:
keys = [UfuncKey.CUDAFunctor]
for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
assert k not in loops, f"cannot use {k} on non-binary function"
for k in keys:
# If the key was directly defined, skip functor codegen; we assume the
# user has already done it for us
if k in loops:
ufunctor_sig = UfunctorSignature(
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
)
for dtype in loops[k].supported_dtypes:
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
continue
# Note [ScalarOnly and Generic must match names for CUDA]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Otherwise, look in ANY of the generic entries. For simplicity of
# codegen, if both ScalarOnly and Generic are defined, the ufunc name
# must match (if they didn't match, we'd have to generate distinct
# functors per dtype, which is awful, so we're not going to do it unless
# someone really forces us to)
ufunc_name = None
supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
if lk not in loops:
continue
if ufunc_name is None:
ufunc_name = loops[lk].name
else:
# See Note [ScalarOnly and Generic must match names for CUDA]
assert (
ufunc_name == loops[lk].name
), "ScalarOnly and Generic must have same ufunc name"
supported_dtypes |= loops[lk].supported_dtypes
assert ufunc_name is not None
name = f"{k}_{ufunc_name}"
ufunctor_sig = UfunctorSignature(
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
)
for dtype in supported_dtypes:
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
ufunc_sig = UfuncSignature(
g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
)
apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
ufunctors.append(
f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
using opmath_t = at::opmath_type<scalar_t>;
{ufunctor_sig.decl_fields()}
{ufunctor_sig.inline_defn_ctor()}
__device__ {ufunctor_sig.decl_apply()} {{
return {ufunc_sig.call(apply_ctx)};
}}
}};
"""
)
return ufunctor_sigs, "\n".join(ufunctors)
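# Illustrative sketch (not part of torchgen): the loop above indexes one functor
# signature per (dtype, ufunc key). The toy version below shows the naming scheme
# ("<Key>_<ufunc name>") and the setdefault-based index for a hypothetical "add"
# ufunc supporting Float and Int.
def _sketch_functor_index(ufunc_name: str, keys: list, dtypes: list) -> dict:
    index: dict = {}
    for key in keys:
        functor_name = f"{key}_{ufunc_name}"
        for dtype in dtypes:
            index.setdefault(dtype, {})[key] = functor_name
    return index


if __name__ == "__main__":
    print(
        _sketch_functor_index(
            "add",
            ["CUDAFunctorOnSelf", "CUDAFunctorOnOther", "CUDAFunctor"],
            ["Float", "Int"],
        )
    )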
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
scalar_idx: int
ctor_tensor: str
ufunc_key: UfuncKey
BinaryScalarSpecializationConfigs = [
BinaryScalarSpecializationConfig(
scalar_idx=0,
ctor_tensor="self",
ufunc_key=UfuncKey.CUDAFunctorOnOther,
),
BinaryScalarSpecializationConfig(
scalar_idx=1,
ctor_tensor="other",
ufunc_key=UfuncKey.CUDAFunctorOnSelf,
),
]
def compute_ufunc_cuda_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfunctorSignature],
parent_ctx: Sequence[Binding],
) -> str:
body = "using opmath_t = at::opmath_type<scalar_t>;"
body += "if (false) {}\n" # for ease of codegen
for config in BinaryScalarSpecializationConfigs:
if config.ufunc_key not in inner_loops:
continue
ufunctor_sig = inner_loops[config.ufunc_key]
scalar_idx = config.scalar_idx + 1
# Make a copy and at the same time widen the type (not permissible
# without copy; we don't want to mutate the input argument anyway)
ctx: list[Expr | Binding] = list(parent_ctx)
ctx.append(
Expr(
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
)
)
ufunctor_ctor_exprs_str = ", ".join(
a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
)
# NB: ufunctor must be allocated before iter.remove_operand is called,
# as it relies on iter
body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
{ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
iter.remove_operand({scalar_idx});
gpu_kernel(iter, ufunctor);
}}"""
ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
ufunctor_ctor_exprs_str = ", ".join(
a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
)
body += f"""
else {{
gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
return body
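# Illustrative sketch (not part of torchgen): the body above is assembled as a leading
# "if (false) {}" followed by zero or more "else if (iter.is_cpu_scalar(i))" branches
# and a final "else" fallback, so every branch can be emitted uniformly. The toy
# assembler below reproduces that shape with placeholder branch bodies.
def _sketch_dispatch_chain(scalar_branches: list) -> str:
    body = "if (false) {}\n"
    for idx in scalar_branches:
        body += f"else if (iter.is_cpu_scalar({idx})) {{ /* specialized kernel */ }}\n"
    body += "else { /* generic kernel */ }\n"
    return body


if __name__ == "__main__":
    print(_sketch_dispatch_chain([1, 2]))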
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
# First, build the functors, indexing them by dtype
ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
# Next, build the conditionals
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
dtype_cases = []
for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
dtype_cases.append(
f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
}}
)
"""
)
dtype_cases_str = "\n".join(dtype_cases)
stub_sig = StubSignature(g)
return f"""
{ufunctors}
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
{dtype_cases_str}
);
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
{sig.defn()} {{
{stub_sig.direct_call(sig.arguments())};
}}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@dataclass(frozen=True)
class StubSignature:
g: NativeFunctionsGroup
@property
def name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_stub"
@property
def kernel_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_kernel"
@property
def type_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_fn"
def arguments(self) -> list[Binding]:
return ufunc.stub_arguments(self.g)
def type(self) -> str:
cpp_args = self.arguments()
return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
def dispatch_decl(self) -> str:
return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
def dispatch_defn(self) -> str:
return f"DEFINE_DISPATCH({self.name})"
def kernel_defn(self) -> str:
return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
def type_defn(self) -> str:
return f"using {self.type_name} = {self.type()}"
# must be called from context where this is TensorIteratorBase*
def call(self, ctx: Sequence[Binding]) -> str:
return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# used in CUDA to skip the unnecessary dynamic dispatch
def direct_call(self, ctx: Sequence[Binding]) -> str:
return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
stub_sig = StubSignature(g)
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};
{sig.defn()} {{
{stub_sig.call(sig.arguments())};
}}
"""
def compute_ufunc_cpu_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfuncSignature],
parent_ctx: Sequence[Binding],
) -> str:
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
scalar_loop = inner_loops[UfuncKey.CPUScalar]
vec_loop = None
if UfuncKey.CPUVector in inner_loops:
vec_loop = inner_loops[UfuncKey.CPUVector]
# NB: We DON'T use translate here, because translate is
# incapable of CSE'ing the scalar accesses in case it is also
# used by Vectorized; also, the unpacking here is very simple
# and only affects Scalar; everything else is implicitly captured
# by the lambda
# Setup scalar in scope
body = []
ctx = []
for b in parent_ctx:
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
BaseTy.Scalar
):
continue
body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
if vec_loop is not None:
for b in parent_ctx:
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
BaseTy.Scalar
):
continue
body.append(
f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
)
ctx.append(
Expr(
f"_v_{b.name}",
NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
)
)
# Setup lambda signature
# NB: simplified version of ufunctor_arguments
scalar_bindings = []
vec_bindings = []
for a in g.functional.func.arguments.flat_non_out:
if not a.type.is_tensor_like():
continue
assert a.type == BaseType(BaseTy.Tensor)
scalar_bindings.append(
Binding(
name=a.name,
nctype=NamedCType(a.name, BaseCType(scalar_t)),
argument=a,
)
)
if vec_loop is not None:
vec_bindings.append(
Binding(
name=a.name,
nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
argument=a,
)
)
def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
r: list[Expr | Binding] = []
r.extend(ctx)
r.extend(b)
return r
body_str = "\n".join(body)
if vec_loop is not None:
return f"""
{body_str}
cpu_kernel_vec(iter,
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
else:
return f"""
{body_str}
cpu_kernel(iter,
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
stub_sig = StubSignature(g)
# Reindex the ufunc by dtypes; processing generic/scalaronly as well
loops = g.out.ufunc_inner_loop
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
lks = []
# ORDER MATTERS: this specifies overriding precedence
if k in loops: # should happen rarely
lks.append(k)
if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
lks.append(UfuncKey.ScalarOnly)
if UfuncKey.Generic in loops:
lks.append(UfuncKey.Generic)
# TODO: don't hardcode ufunc:: namespace here, should be centralized somehow
for lk in lks:
for dtype in loops[lk].supported_dtypes:
compute_t: CType
if k is UfuncKey.CPUScalar:
compute_t = BaseCType(scalar_t)
elif k is UfuncKey.CPUVector:
compute_t = VectorizedCType(BaseCType(scalar_t))
else:
raise AssertionError
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
if k not in inner_ufunc_sigs:
inner_ufunc_sigs[k] = UfuncSignature(
g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
)
# Build the conditionals
dtype_cases = []
for dtype, inner_ufunc_sigs in ufunc_sigs.items():
dtype_cases.append(
f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
}}
)
"""
)
dtype_cases_str = "\n".join(dtype_cases)
return f"""
namespace {{
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
{dtype_cases_str}
);
}}
}} // anonymous namespace
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""

View File

@ -0,0 +1,149 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from torchgen import dest
# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target
if TYPE_CHECKING:
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used on
# the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
if Variant.function not in f.variants:
return None
sig = DispatcherSignature.from_schema(
f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
)
assert sig is not None
if len(f.func.returns) == 0:
ret_name = ""
elif len(f.func.returns) == 1:
if f.func.arguments.out:
ret_name = f.func.arguments.out[0].name
else:
ret_name = next(
(
a.name
for a in f.func.arguments.flat_non_out
if a.type == f.func.returns[0].type
),
"",
)
if not ret_name:
# if return type is tensor
if f.func.returns[0].type == BaseType(BaseTy.Tensor):
# Returns an empty tensor
ret_name = "at::Tensor()"
else:
raise Exception( # noqa: TRY002
f"Can't handle this return type {f.func}"
) # noqa: TRY002
elif len(f.func.arguments.out) == len(f.func.returns):
# Returns a tuple of out arguments
tensor_type = "at::Tensor &"
comma = ", "
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
{comma.join([r.name for r in f.func.arguments.out])}
)"""
else:
assert all(
a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
), f"Only support tensor returns but got {f.func.returns}"
# Returns a tuple of empty tensors
tensor_type = "at::Tensor"
comma = ", "
ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
{comma.join(["at::Tensor()" for _ in f.func.returns])}
)"""
ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
return f"""
{sig.defn()} {{
{ret_str}
}}
"""
def gen_custom_ops_registration(
*,
native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
rocm: bool,
) -> tuple[str, str]:
"""
Generate custom ops registration code for dest.RegisterDispatchKey.
:param native_functions: a sequence of `NativeFunction`
:param selector: for selective build.
:param kernel_index: kernels for all the ops.
:param rocm: bool for dest.RegisterDispatchKey.
:return: generated C++ code to register custom operators into PyTorch
"""
# convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
# TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
dispatch_key = DispatchKey.CPU
backend_index = kernel_index._to_backend_index()
static_init_dispatch_registrations = ""
ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function)
for namespace, functions in ns_grouped_native_functions.items():
if len(functions) == 0:
continue
dispatch_registrations_body = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
functions,
)
)
)
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
anonymous_definition = "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
symint=False,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
native_functions,
)
)
)
return anonymous_definition, static_init_dispatch_registrations

View File

@ -0,0 +1,370 @@
from __future__ import annotations
from typing import Sequence
from torchgen import local
from torchgen.api.types import (
ArgName,
BaseCType,
Binding,
ConstRefCType,
CType,
MutRefCType,
NamedCType,
SpecialArgName,
TupleCType,
VectorCType,
voidT,
)
from torchgen.executorch.api.types import (
ArrayRefCType,
BaseTypeToCppMapping,
OptionalCType,
scalarT,
tensorListT,
tensorT,
)
from torchgen.model import (
Argument,
Arguments,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels,
since in Executorch CppSignature is the same as NativeSignature.
Differences between this file and torchgen.api.cpp:
- Executorch doesn't support TensorOptions, but we keep that logic here to stay compatible with
torchgen.api.cpp, so that we can support things like ATen mode (running ATen kernels in Executorch).
- Executorch doesn't support Dimname.
- Executorch runtime doesn't support SymInt, will treat it as int.
"""
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
t: Type,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
# For SymInt we simply treat it as int.
elif str(t) == "SymInt":
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
if remove_non_owning_ref_types:
if t.name == BaseTy.str:
raise AssertionError(
"string ref->value conversion: not implemented yet"
)
# All other BaseType currently map directly to BaseCppTypes.
return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem, binds=binds)
if elem is None:
return None
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
if str(t.elem) == "bool":
assert t.size is not None
return NamedCType(
binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]))
)
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the returned CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType:
# If it's a value type, do the value type translation
r = valuetype_type(
t,
binds=binds,
remove_non_owning_ref_types=remove_non_owning_ref_types,
)
if r is not None:
return r
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
elif t.name == BaseTy.Scalar:
return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
else:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == "Tensor":
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(
binds, MutRefCType(BaseCType(tensorT))
) # TODO: fix this discrepancy
else:
return NamedCType(
binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
)
elif str(t.elem) == "Scalar":
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, OptionalCType(elem.type))
elif isinstance(t, ListType):
# TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
if str(t.elem) == "Tensor":
return NamedCType(binds, BaseCType(tensorListT))
elif str(t.elem) == "Dimname":
raise NotImplementedError("Executorch doesn't support Dimname")
elif str(t.elem) == "Tensor?":
return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
return NamedCType(binds, ArrayRefCType(elem.type))
else:
raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool) -> CType:
# placeholder is ignored
r = valuetype_type(t, binds="__placeholder__")
if r is not None:
return r.type
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
# Note [Tensor Copy Returns]
# Currently, we use "Argument.is_write" to determine
# whether or not Tensor return types should be copies or references.
# If that ever changes, take a look at other locations of this note!
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:
return BaseCType(scalarT)
elif isinstance(t, ListType):
assert (
not mutable
), "Native functions should never return a mutable tensor list. They should return void."
elem = returntype_type(t.elem, mutable=False)
assert t.size is None, f"fixed size list returns not supported: {t}"
return VectorCType(elem)
raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return) -> CType:
return returntype_type(r.type, mutable=r.is_write)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> CType:
if len(rs) == 0:
return BaseCType(voidT)
elif len(rs) == 1:
return return_type(rs[0])
else:
return TupleCType([return_type(r) for r in rs])
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
# TODO: Consider incorporating this into the data model
if f.func.name.name.inplace:
assert i == 0, "illegal inplace function with multiple returns"
name = "self"
# If we are out function, the name is the name of the
# corresponding output function (r.name will get recorded
# in field_name later.)
elif f.func.is_out_fn():
name = f.func.arguments.out[i].name
# If the return argument is explicitly named...
elif r.name:
name_conflict = any(
r.name == a.name for a in f.func.schema_order_arguments()
)
if name_conflict and not f.func.is_out_fn():
name = f"{r.name}_return"
else:
name = r.name
# If there is no explicit name and no fallback name was passed in, we just name the output result,
# unless it's a multi-return, in which case it's result0,
# result1, etc (zero-indexed)
else:
name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
returns.append(name)
return returns
JIT_TO_CPP_DEFAULT = {
"False": "false",
"True": "true",
"None": "torch::executorch::nullopt", # UGH this one is type directed
"[]": "{}",
"contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
"long": "torch::executorch::kLong",
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
if d == "None" and str(t) == "Tensor?":
return "{}"
if isinstance(t, BaseType) and t.name is BaseTy.str:
# Schema allows single quotes but C++ needs double
if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
s = ""
i = 1
while i + 1 < len(d):
if d[i] != "\\":
if d[i] == '"':
s += '\\"'
else:
s += d[i]
i += 1
else:
if d[i + 1] == "'":
s += "'"
else:
s += d[i : i + 2]
i += 2
return f'"{s}"'
if isinstance(t, OptionalType):
if d == "None":
return "torch::executor::nullopt"
return default_expr(d, t.elem)
if isinstance(t, ListType):
if d.startswith("[") and d.endswith("]"):
return "{" + d[1:-1] + "}"
elif t.size is None:
# NOTE: Sized lists can have scalar defaults
raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
return JIT_TO_CPP_DEFAULT.get(d, d)
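# Illustrative sketch (not part of torchgen): the string branch of default_expr above
# turns a schema default written with single quotes into a double-quoted C++ literal,
# escaping embedded double quotes. A stripped-down copy of just that loop, for a quick
# sanity check:
def _sketch_requote(d: str) -> str:
    assert len(d) >= 2 and d[0] == "'" and d[-1] == "'"
    s = ""
    i = 1
    while i + 1 < len(d):
        if d[i] != "\\":
            s += '\\"' if d[i] == '"' else d[i]
            i += 1
        else:
            s += "'" if d[i + 1] == "'" else d[i : i + 2]
            i += 2
    return f'"{s}"'


if __name__ == "__main__":
    print(_sketch_requote("'linear'"))      # -> "linear"
    print(_sketch_requote("'say \"hi\"'"))  # embedded double quotes get escaped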
# Convert an argument into its C++ API form
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
*,
cpp_no_default_args: set[str],
method: bool,
faithful: bool,
has_tensor_options: bool,
) -> list[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
method=method,
faithful=faithful,
has_tensor_options=has_tensor_options,
)
if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [
Binding(
nctype=argument_type(a, binds=binds),
name=a.name,
default=default,
argument=a,
)
]
elif isinstance(a, TensorOptionsArguments):
raise NotImplementedError("Need to implement type resolution for TensorOptions")
elif isinstance(a, SelfArgument):
if method:
# Caller is responsible for installing implicit this in context!
return []
else:
return sub_argument(a.argument)
else:
assert_never(a)
def arguments(
arguments: Arguments,
*,
faithful: bool,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
else:
args.extend(arguments.out)
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r
for a in args
for r in argument(
a,
faithful=faithful,
method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args,
)
]

View File

@ -0,0 +1,4 @@
from torchgen.executorch.api.types.types import *
from torchgen.executorch.api.types.signatures import * # usort: skip

View File

@ -0,0 +1,76 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torchgen.api.cpp as aten_cpp
from torchgen.executorch.api.types.types import contextArg
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType
from torchgen.model import FunctionSchema, NativeFunction
@dataclass(frozen=True)
class ExecutorchCppSignature:
"""
This signature is merely a CppSignature with Executorch types (optionally
contains KernelRuntimeContext as well). The inline definition of
CppSignature is generated in Functions.h and it's used by unboxing
functions.
"""
# The schema this signature is derived from
func: FunctionSchema
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""
def arguments(self, *, include_context: bool = True) -> list[Binding]:
return ([contextArg] if include_context else []) + et_cpp.arguments(
self.func.arguments,
faithful=True, # always faithful, out argument at the end
method=False, # method not supported
cpp_no_default_args=self.cpp_no_default_args,
)
def name(self) -> str:
return self.prefix + aten_cpp.name(
self.func,
faithful_name_for_out_overloads=True,
)
def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
args_str = ", ".join(
a.decl() for a in self.arguments(include_context=include_context)
)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
args = [a.defn() for a in self.arguments()]
args_str = ", ".join(args)
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def returns_type(self) -> CType:
return et_cpp.returns_type(self.func.returns)
@staticmethod
def from_native_function(
f: NativeFunction, *, prefix: str = ""
) -> ExecutorchCppSignature:
return ExecutorchCppSignature(
func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
)
# Imported at the bottom of the file to avoid a circular import: et_cpp itself
# imports from torchgen.executorch.api.types.
from torchgen.executorch.api import et_cpp

View File

@ -0,0 +1,83 @@
from __future__ import annotations
from dataclasses import dataclass
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
CType,
doubleT,
Expr,
longT,
MutRefCType,
NamedCType,
)
from torchgen.model import BaseTy
halfT = BaseCppType("torch::executor", "Half")
bfloat16T = BaseCppType("torch::executor", "BFloat16")
stringT = BaseCppType("torch::executor", "string_view")
scalarTypeT = BaseCppType("torch::executor", "ScalarType")
tensorT = BaseCppType("torch::executor", "Tensor")
tensorListT = BaseCppType("torch::executor", "TensorList")
scalarT = BaseCppType("torch::executor", "Scalar")
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
optionalT = BaseCppType("torch::executor", "optional")
contextT = BaseCppType("torch::executor", "KernelRuntimeContext")
contextExpr = Expr(
expr="context",
type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
)
contextArg = Binding(
name="context",
nctype=contextExpr.type,
argument=None, # type: ignore[arg-type]
default=None,
)
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
BaseTy.str: stringT,
BaseTy.ScalarType: scalarTypeT,
BaseTy.Tensor: tensorT,
BaseTy.Scalar: scalarT,
BaseTy.MemoryFormat: memoryFormatT,
}
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"torch::executor::optional<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>"
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref())
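# Illustrative sketch (not part of torchgen): OptionalCType and ArrayRefCType above
# follow a common recipe for wrapping an element CType: render the wrapper around the
# element's cpp_type() and recurse for remove_const_ref(). The stand-in below shows
# the same recipe for a hypothetical span-like wrapper, using duck-typed element
# classes instead of the real CType base class.
from dataclasses import dataclass


@dataclass(frozen=True)
class _SketchBaseCType:
    name: str

    def cpp_type(self) -> str:
        return self.name

    def remove_const_ref(self) -> "_SketchBaseCType":
        return self


@dataclass(frozen=True)
class _SketchSpanCType:
    elem: _SketchBaseCType

    def cpp_type(self) -> str:
        # Hypothetical wrapper type name, for illustration only.
        return f"torch::executor::Span<{self.elem.cpp_type()}>"

    def remove_const_ref(self) -> "_SketchSpanCType":
        return _SketchSpanCType(self.elem.remove_const_ref())


if __name__ == "__main__":
    print(_SketchSpanCType(_SketchBaseCType("int64_t")).cpp_type())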

View File

@ -0,0 +1,230 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType, NamedCType
connector = "\n\t"
# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
return f.func.name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
"""
Takes a sequence of Bindings and unboxes EValues into these Bindings. Returns generated code that performs the unboxing.
A sample of the generated code:
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
void mul_out(EValue** stack) {
EValue& self = *stack[0];
EValue& other = *stack[1];
EValue& out = *stack[2];
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
torch::executor::mul_outf(self_base, other_base, out_base);
}
"""
# This is a callable that converts a JIT argument into its C++ type.
# Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
argument_type_gen: Callable[
...,
NamedCType,
]
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(
self, args: Sequence[Binding]
) -> tuple[list[Binding], list[str]]:
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
binding_list = []
for arg in args:
# expecting only Argument
if not isinstance(arg.argument, Argument):
raise Exception( # noqa: TRY002
f"Unexpected argument type, expecting `Argument` but got {arg}"
)
argument: Argument = arg.argument
unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
argument.type, argument.name, mutable=argument.is_write
)
code_list.extend(decl)
code_list.extend(code)
binding_list.append(arg.with_name(unboxed_name))
return binding_list, code_list
def argumenttype_evalue_convert(
self, t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
"""
Takes the type, name and mutability of an argument and generates:
(1) the name and CType of the newly created unboxed variable
(2) the C++ code (and any out-of-scope declarations) necessary to unbox the argument
:param t: a `Type` of an argument
:param arg_name: argument name
:param mutable: boolean for whether this argument type is mutable
:return: a tuple of (unboxed variable name, its CType, unboxing code lines, declaration lines)
"""
ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
if isinstance(t, BaseType):
out_name = f"{arg_name}_base"
code, decl = self._gen_code_base_type(
arg_name=arg_name, out_name=out_name, ctype=ctype
)
elif isinstance(t, OptionalType):
out_name = f"{arg_name}_opt_out"
code, decl = self._gen_code_optional_type(
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
)
elif isinstance(t, ListType):
out_name = f"{arg_name}_list_out"
code, decl = self._gen_code_list_type(
arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
)
else:
raise Exception( # noqa: TRY002
f"Cannot handle type {t}. arg_name: {arg_name}"
) # noqa: TRY002
return out_name, ctype, code, decl
def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
def _gen_code_optional_type(
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in"
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name
)
return (
f"""
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split(
"\n"
),
decl,
)
def _gen_code_list_type(
self, arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = []
res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
t.elem, elem_name
)
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
code.extend(
f"""
auto {out_name} = {arg_name}.toTensorList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and (
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
):
code.extend(
f"""
auto {out_name} = {arg_name}.toIntList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
code.extend(
f"""
auto {out_name} = {arg_name}.toDoubleList();
""".split(
"\n"
)
)
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
# handle list type with size, e.g., bool[4]
code.extend(
f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
{out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
""".split(
"\n"
)
)
# pytorch codegen:
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
elif (
isinstance(t.elem, OptionalType)
and isinstance(t.elem.elem, BaseType)
and t.elem.elem.name == BaseTy.Tensor
):
code.extend(
f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
""".split(
"\n"
)
)
else:
# use ArrayRef as default.
vec_name = arg_name + "_vec"
# need to bring vector instantiation out of scope so that ArrayRef has valid data
decl.append(
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
)
code.extend(
f"""
for (EValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
"\n"
)
)
return code, decl

View File

@ -0,0 +1,220 @@
# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
from __future__ import annotations
import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
)
from torchgen.utils import assert_never
KERNEL_KEY_VERSION = 1
# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
Byte = 0
Char = 1
Short = 2
Int = 3
Long = 4
Float = 6
Double = 7
Bool = 11
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
arg_name: str
dtype: str
# The order of the dimensions if entry is a Tensor
dim_order: tuple[int, ...]
def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value
dim_str = str(self.dim_order)[1:-1].replace(" ", "")
return f"{dtype_str};{dim_str}"
@dataclass(frozen=True)
class ETKernelKey:
# If arg_meta is left empty, the key is expected to be constructed with default = True (catch-all key)
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
# Indicator for this kernel being used as a catch all
default: bool = False
version: int = KERNEL_KEY_VERSION
@staticmethod
def gen_from_yaml(
args: dict[str, tuple[str, str]],
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: dict[str, list[int]],
) -> list[ETKernelKey]:
"""Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned because each dtype permutation allowed by
type_alias_map is materialized as its own kernel key.
Args:
args: Mapping from argument name to kernel specs
Kernel specs are a tuple of (dtype, dim_order).
Currently tuple entries must be aliased via the alias map arguments
type_alias_map: Mapping from type alias to potential type enums
i.e { T0 : [Double, Int] } means T0 can be either Double or Int
Used for lookup by args
dim_order_alias_map: Mapping from alias to a list of dimension orders
Used for lookup by args
"""
# Cast dim order values to int
dim_order_alias_map = {
k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
}
kernel_keys = []
# Get all used Dtype Alias
dtype_alias_used = set()
for type_alias, dim_order in args.values():
# Enforce usage of alias initially
# TODO: Support inlined arguments
assert type_alias in type_alias_map, "Undefined type alias: " + str(
type_alias
)
assert (
dim_order in dim_order_alias_map
), "Undefined dim_order alias: " + str(dim_order)
dtype_alias_used.add(type_alias)
# Generate all permutations of dtype alias values
alias_dtypes = [
[(alias, dtype) for dtype in type_alias_map[alias]]
for alias in dtype_alias_used
]
alias_permutations = [
dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
]
# Using each alias value permutation, generate kernel keys
op_arg_cache = {}
for permutation in alias_permutations:
arg_list = []
for arg_name, arg_spec in args.items():
dtype = permutation[arg_spec[0]]
dim_order = dim_order_alias_map[arg_spec[1]] # type: ignore[assignment]
if (
cache_key := (arg_name, dtype, tuple(dim_order))
) not in op_arg_cache:
op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key) # type: ignore[arg-type]
arg_list.append(op_arg_cache[cache_key])
kernel_keys.append(ETKernelKey(tuple(arg_list)))
return kernel_keys
def to_native_string(self) -> str:
if self.default:
return "default"
return (
"v"
+ str(KERNEL_KEY_VERSION)
+ "/"
+ "|".join([arg.to_native_string() for arg in self.arg_meta])
)
@dataclass(frozen=True)
class ETKernelIndex:
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
m = self.get_kernels(g)
return m is not None
def get_kernels(
self, g: NativeFunction | NativeFunctionsGroup
) -> dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
f = g.functional
else:
assert_never(g)
if f.func.name not in self.index:
return {}
return self.index[f.func.name]
@staticmethod
def grow_from_backend_indices(
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
for dk in backend_indices:
index = backend_indices[dk]
for op, backend_metadata in index.items():
if op in kernel_index:
kernel_index[op][ETKernelKey(default=True)] = backend_metadata
else:
kernel_index[op] = {ETKernelKey(default=True): backend_metadata}
@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
kernel_index: dict[
OperatorName, dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)
def grow(
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self
def _to_backend_index(self) -> BackendIndex:
"""
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
"""
index: dict[OperatorName, BackendMetadata] = {}
for op in self.index:
kernel_dict = self.index[op]
assert (
len(kernel_dict.values()) == 1
), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
index[op] = kernel_dict.get(
ETKernelKey(default=True),
BackendMetadata(kernel="", structured=False, cpp_namespace=""),
)
return BackendIndex(
dispatch_key=DispatchKey.CPU,
use_out_as_primary=False,
device_guard=False,
external=False,
index=index,
)
# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
@staticmethod
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
combined = defaultdict(dict, index_a.index.copy())
for op, entry in index_b.index.items():
for key, metadata in entry.items():
combined[op][key] = metadata
return ETKernelIndex(combined)
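# Illustrative sketch (not part of torchgen): ETKernelKey.gen_from_yaml expands type
# aliases into one kernel key per dtype combination using itertools.product. The toy
# expansion below shows that step alone for two hypothetical aliases.
import itertools


def _sketch_alias_permutations(type_alias_map: dict) -> list:
    alias_dtypes = [
        [(alias, dtype) for dtype in dtypes] for alias, dtypes in type_alias_map.items()
    ]
    return [dict(p) for p in itertools.product(*alias_dtypes)]


if __name__ == "__main__":
    # {T0: [Double, Int], T1: [Float]} -> two permutations
    print(_sketch_alias_permutations({"T0": ["Double", "Int"], "T1": ["Float"]}))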

View File

@ -0,0 +1,153 @@
from __future__ import annotations
from collections import defaultdict, namedtuple
from typing import Any
import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
BackendMetadata,
DispatchKey,
FunctionSchema,
NativeFunction,
OperatorName,
)
from torchgen.utils import NamespaceHelper
# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
# Fields in native_functions.yaml used to determine which kernels should be used
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
"""Given a loaded yaml representing kernel assignment information, extract the
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
Args:
ei: Dict keys {kernels, type_alias, dim_order_alias}
See ETKernelKey for description of arguments
"""
e = ei.copy()
if (kernels := e.pop("kernels", None)) is None:
return {}
type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
dim_order_alias.pop("__line__", None)
kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
for entry in kernels: # type: ignore[attr-defined]
arg_meta = entry.get("arg_meta")
if arg_meta is not None:
arg_meta.pop("__line__")
kernel_name = entry.get("kernel_name")
namespace_helper = NamespaceHelper.from_namespaced_entity(
kernel_name, max_level=3
)
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
backend_metadata = BackendMetadata(
kernel=namespace_helper.entity_name,
structured=False,
cpp_namespace=(kernel_namespace + "::native"),
)
kernel_keys = (
[ETKernelKey((), default=True)]
if arg_meta is None
else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias) # type: ignore[arg-type]
)
for kernel_key in kernel_keys:
assert kernel_key not in kernel_mapping, (
"Duplicate kernel key: " + str(kernel_key) + " " + str(e)
)
kernel_mapping[kernel_key] = backend_metadata
return kernel_mapping
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
"""Given a loaded yaml representing a list of operators, for each op extract the mapping
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
that should be used by the kernel key).
"""
indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
for ei in es: # type: ignore[attr-defined]
e = ei.copy()
funcs = e.pop("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
namespace_helper = NamespaceHelper.from_namespaced_entity(
namespaced_entity=funcs, max_level=1
)
opname = FunctionSchema.parse(namespace_helper.entity_name).name
assert opname not in indices, f"Duplicate func found in yaml: {opname} already"
if len(index := parse_from_yaml(e)) != 0:
indices[opname] = index
return ETKernelIndex(indices)
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
"""Given a loaded yaml representing a list of operators, extract the
kernel key related fields indexed by the operator name.
"""
fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
for ei in es: # type: ignore[attr-defined]
funcs = ei.get("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
namespace_helper = NamespaceHelper.from_namespaced_entity(
namespaced_entity=funcs, max_level=1
)
opname = FunctionSchema.parse(namespace_helper.entity_name).name
for field in ET_FIELDS:
if (value := ei.get(field)) is not None:
fields[opname][field] = value
return fields
def parse_et_yaml(
path: str,
tags_yaml_path: str,
ignore_keys: set[DispatchKey] | None = None,
skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml
"""
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)
et_kernel = extract_kernel_fields(es)
# Remove ET specific fields from entries for BC compatibility
strip_et_fields(es)
native_yaml = parse_native_yaml(
path,
tags_yaml_path,
ignore_keys,
skip_native_fns_gen=skip_native_fns_gen,
loaded_yaml=es,
)
return native_yaml.native_functions, et_kernel
def strip_et_fields(es: object) -> None:
"""Given a loaded yaml representing a list of operators,
remove ET specific fields from every entry for BC compatibility
"""
for entry in es: # type: ignore[attr-defined]
for field in ET_FIELDS:
entry.pop(field, None)

File diff suppressed because it is too large

View File

@ -0,0 +1,486 @@
from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
DispatchKey,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
OptionalType,
Type,
)
from torchgen.utils import mapMaybe
base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool
BaseTy.int: "int64_t",
BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt
BaseTy.Scalar: "double", # Use double to pass both integer and floating point
BaseTy.float: "double", # TODO: how about other floating point types?
BaseTy.str: "const char*",
BaseTy.DeviceIndex: "int32_t",
BaseTy.Layout: "int32_t", # Represent enum as int
BaseTy.MemoryFormat: "int32_t", # Represent enum as int
BaseTy.ScalarType: "int32_t", # Represent enum as int
BaseTy.Generator: "AtenGeneratorHandle",
}
base_type_to_aten_type = {
BaseTy.Tensor: "at::Tensor",
BaseTy.bool: "bool",
BaseTy.int: "int64_t",
BaseTy.SymInt: "c10::SymInt",
BaseTy.Scalar: "c10::Scalar",
BaseTy.float: "double",
BaseTy.str: "c10::string_view",
BaseTy.DeviceIndex: "c10::DeviceIndex",
BaseTy.Layout: "c10::Layout",
BaseTy.MemoryFormat: "c10::MemoryFormat",
BaseTy.ScalarType: "c10::ScalarType",
BaseTy.Generator: "at::Generator",
}
base_type_to_callsite_expr = {
BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
BaseTy.bool: "",
BaseTy.int: "",
BaseTy.SymInt: "",
BaseTy.Scalar: "",
BaseTy.float: "",
BaseTy.str: "",
BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
BaseTy.Layout: "static_cast<c10::Layout>",
BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
BaseTy.ScalarType: "static_cast<c10::ScalarType>",
BaseTy.Generator: "*generator_handle_to_generator_pointer",
}
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return]
if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type:
return (
[base_type_to_c_type[typ.name]],
[name],
[base_type_to_aten_type[typ.name]],
[
f"{base_type_to_callsite_expr[typ.name]}({name})"
if base_type_to_callsite_expr[typ.name]
else name
],
)
elif typ.name == BaseTy.Device:
return (
["int32_t", "int32_t"],
[name, name + "_index_"],
["c10::Device"],
[
f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
],
)
else:
# TODO: BaseTy.Dimname, etc.
raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
elif isinstance(typ, OptionalType):
c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
typ.elem, name
)
j = 0 # index for names
new_aten_types = []
new_callsite_exprs = []
for aten_type in aten_types:
# Use pointer to denote optional type
c_types[j] = c_types[j] + "*"
if aten_type.startswith("c10::ArrayRef<"):
# ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
new_aten_types.append(f"::std::optional<{aten_type}>")
base_type = aten_type[len("c10::ArrayRef<") : -1]
new_callsite_exprs.append(
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
)
j += 2
elif aten_type == "c10::Device":
# Device is passed as device_type + device_index
new_aten_types.append("::std::optional<c10::Device>")
new_callsite_exprs.append(
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
)
j += 2
else:
new_aten_types.append(f"::std::optional<{aten_type}>")
new_callsite_exprs.append(
f"pointer_to_optional<{aten_type}>({names[j]})"
)
j += 1
return (
c_types,
names,
new_aten_types,
new_callsite_exprs,
)
elif isinstance(typ, ListType):
# Need to explicitly pass the list as pointer + length
c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
# The list content should never be modified
c_types[0] = f"const {c_types[0]}*"
c_types.append("int64_t")
name = names[0]
names.append(name + "_len_")
atype = aten_types[0]
callsite_exprs = []
if atype == "bool":
# no converter from std::vector<bool> to c10::ArrayRef<bool>
# construct std::array<bool, N> instead
assert typ.size is not None
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
elif atype == "::std::optional<at::Tensor>":
# convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
callsite_exprs.append(
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
)
else:
callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
return (
c_types,
names,
aten_types,
callsite_exprs,
)
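# Illustrative sketch (hypothetical argument names, not part of the original file):
# for a plain `Tensor self` argument, convert_arg_type_and_name is expected to yield
#   c_types        == ["AtenTensorHandle"]
#   names          == ["self"]
#   aten_types     == ["at::Tensor"]
#   callsite_exprs == ["*tensor_handle_to_tensor_pointer(self)"]
# while a `Tensor[] tensors` argument is flattened into a pointer + length pair:
#   c_types        == ["const AtenTensorHandle*", "int64_t"]
#   names          == ["tensors", "tensors_len_"]
#   aten_types     == ["c10::ArrayRef<at::Tensor>"]
#   callsite_exprs == ["pointer_to_list<at::Tensor>(tensors, tensors_len_)"]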
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
return [typ + " " + name for typ, name in zip(types, names)]
# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
types = []
new_names = []
callsite_exprs = []
for arg in flat_arguments:
new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
arg.type, arg.name
)
types.extend(new_types)
new_names.extend(names)
callsite_exprs.extend(new_callsite_exprs)
return zip_type_and_name(types, new_names), callsite_exprs
# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
types = []
names = []
for idx, ret in enumerate(schema.returns):
names.append(f"ret{idx}")
if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
types.append(base_type_to_c_type[ret.type.name] + "*")
else:
raise NotImplementedError(
f"TODO: add support for return type {repr(ret.type)}"
)
def convert_return(typ: BaseType, val: str) -> str:
if typ.name == BaseTy.Tensor:
return f"new_tensor_handle(std::move({val}));"
elif typ.name == BaseTy.SymInt:
return f"{val}.expect_int()"
elif typ.name == BaseTy.Scalar:
return f"{val}.toDouble()"
else:
return val
ret_pointer_can_be_null = False
unambiguous_name = schema.name.unambiguous_name()
for name in [
"_scaled_dot_product_flash_attention",
"_scaled_dot_product_efficient_attention",
"_scaled_dot_product_cudnn_attention",
"convolution_backward",
]:
if name in unambiguous_name:
ret_pointer_can_be_null = True
break
callsite_exprs: list[str] = []
for idx, ret in enumerate(schema.returns):
tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
assert isinstance(ret.type, BaseType)
rval = convert_return(ret.type, tmp)
if ret_pointer_can_be_null:
callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
else:
callsite_exprs.append(f"*{names[idx]} = {rval};")
return zip_type_and_name(types, names), callsite_exprs
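# Illustrative sketch (not part of the original file): for a schema that returns
# (Tensor, Tensor), gen_returns yields the out-pointer declarations
#   ["AtenTensorHandle* ret0", "AtenTensorHandle* ret1"]
# plus callsite expressions that unpack std::get<0>(tmp_result) / std::get<1>(tmp_result)
# into *ret0 / *ret1 via new_tensor_handle(). Ops named in the nullable-return list
# above (e.g. convolution_backward) get null-checked assignments instead.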
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
def gen_declaration_and_definition(
schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
func_name = schema.name.unambiguous_name()
global declaration_definition_cache
if (func_name, device, backend_call) in declaration_definition_cache:
return declaration_definition_cache[(func_name, device, backend_call)]
if schema.is_out_fn():
# out_variant has out arguments in the front, and it's ok to ignore return values
# because C shim functions only return AOTITorchError
args, callsite_exprs = gen_arguments(
[*schema.arguments.out, *schema.arguments.flat_non_out]
)
ret_assignments: list[str] = []
else:
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
# ignore return values for inplace ops
ret_declarations, ret_assignments = (
([], []) if schema.name.name.inplace else gen_returns(schema)
)
args.extend(ret_declarations)
declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
tmp_result = "auto tmp_result = " if ret_assignments else ""
ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
definition = f"""
{declaration} {{
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
{tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), " ")}
);{textwrap.indent(ret_assignments_str, " ")}
}});
}}
"""
declaration_definition_cache[(func_name, device, backend_call)] = (
declaration,
definition,
)
return declaration, definition
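# Illustrative sketch (hypothetical op, not part of the original file): for an op
# "foo(Tensor self, int dim) -> Tensor" dispatched to CPU, the generated declaration
# has roughly the shape
#   AOTITorchError aoti_torch_cpu_foo(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0)
# and the definition wraps the backend call in AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE,
# storing the result into *ret0 via new_tensor_handle(std::move(tmp_result)).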
def gen_static_dispatch_backend_call_signature(
sig: CppSignature | DispatcherSignature,
f: NativeFunction,
) -> CppSignature:
sig = DispatcherSignature.from_schema(f.func)
cpp_sigs = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False
)
if sig.symint and f.func.has_symint():
cpp_sig = cpp_sigs.symint_signature
else:
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
return cpp_sig
def gen_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
sig = DispatcherSignature.from_schema(f.func)
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
def get_backend_index_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
backend_index = None
if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None
and func.structured_delegate in func_group_mapping
and backend_indices[dispatch_key].has_kernel(
func_group_mapping[func.structured_delegate]
)
):
backend_index = backend_indices[dispatch_key]
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
func
):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
return backend_index
def get_header_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
)
return (
None
if backend_index is None
else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
)
def get_fallback_op_name(func: NativeFunction) -> str:
return (
f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
if func.func.name.overload_name
else f"{func.namespace}.{func.func.name.name}.default"
)
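# Illustrative sketch (not part of the original file): the names produced here line up
# with the inductor_fallback_ops entries, e.g. an overload-less op becomes
# "aten.convolution.default" while the "Tensor" overload of index becomes
# "aten.index.Tensor".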
def gen_c_shim(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
) -> str | None:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
)
if backend_index is None:
return None
schema = func.func
device = dispatch_key.lower()
backend_call = gen_static_dispatch_backend_call(
func,
backend_index,
)
try:
if header:
declaration, _ = gen_declaration_and_definition(
schema, device, backend_call
)
return f"AOTI_TORCH_EXPORT {declaration};"
else:
_, definition = gen_declaration_and_definition(schema, device, backend_call)
return definition
except NotImplementedError:
return None
@dataclass(frozen=True)
class ShimGenerator:
func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
dispatch_key: DispatchKey
backend_indices: dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
@method_with_native_function
def __call__(
self,
func: NativeFunction,
) -> str | None:
result = gen_c_shim(
func,
self.func_group_mapping,
self.dispatch_key,
self.backend_indices,
self.header,
)
return result
def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
includes: str = "",
) -> str:
body = "\n".join(
list(
mapMaybe(
ShimGenerator(
func_group_mapping, dispatch_key, backend_indices, header
),
native_functions,
)
)
)
device = dispatch_key.lower()
warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
if header:
return f"""
{warning}
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {{
#endif
{body}
#ifdef __cplusplus
}} // extern "C"
#endif
"""
else:
return f"""
{warning}
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif
using namespace torch::aot_inductor;
{body}"""

View File

@ -0,0 +1,611 @@
from __future__ import annotations
import argparse
import os
import re
from collections import Counter, defaultdict, namedtuple
from pathlib import Path
from typing import Sequence
import yaml
import torchgen.api.dispatcher as dispatcher
import torchgen.dest as dest
from torchgen.api.types import DispatcherSignature
from torchgen.code_template import CodeTemplate
from torchgen.context import native_function_manager
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target
from torchgen.yaml_utils import YamlLoader
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
ParsedExternalYaml = namedtuple(
"ParsedExternalYaml",
["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"],
)
def parse_backend_yaml(
backend_yaml_path: str,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
native_functions_map: dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(
lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
grouped_native_functions,
)
}
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)
valid_keys = [
"backend",
"class_name",
"cpp_namespace",
"extra_headers",
"supported",
"autograd",
"full_codegen",
"non_native",
"ir_gen",
"symint",
]
backend = yaml_values.pop("backend", None)
assert backend is not None, 'You must provide a value for "backend"'
class_name = yaml_values.pop("class_name", None)
cpp_namespace = yaml_values.pop("cpp_namespace", None)
assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'
# Mostly just defaulting to false to stick with LazyTensor convention.
use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
assert isinstance(
use_out_as_primary, bool
), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"
use_device_guard = yaml_values.pop("device_guard", False)
assert isinstance(
use_device_guard, bool
), f"You must provide either True or False for device_guard. Provided: {use_device_guard}"
supported = yaml_values.pop("supported", [])
if supported is None:
supported = [] # Allow an empty list of supported ops
assert isinstance(
supported, list
), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
symint = yaml_values.pop("symint", [])
if symint is None:
symint = [] # Allow an empty list of symint ops
assert isinstance(
symint, list
), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
symint_set = set(symint)
supported_autograd = yaml_values.pop("autograd", [])
assert isinstance(
supported_autograd, list
), f'expected "autograd" to be a list, but got: {supported_autograd}'
# full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
full_codegen = yaml_values.pop("full_codegen", [])
supported.extend(full_codegen)
# non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
yaml_values.pop("non_native", {})
# ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
yaml_values.pop("ir_gen", {})
assert (
len(yaml_values.keys()) == 0
), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}'
def create_backend_index(
backend_ops: list[str],
symint_ops: set[str],
dispatch_key: DispatchKey,
*,
use_out_as_primary: bool,
use_device_guard: bool,
) -> BackendIndex:
metadata: dict[OperatorName, BackendMetadata] = {}
for op in backend_ops:
op_name = OperatorName.parse(op)
assert (
op_name in native_functions_map
), f"Found an invalid operator name: {op_name}"
# See Note [External Backends Follow Dispatcher API]
kernel_name = dispatcher.name(native_functions_map[op_name].func)
if op in symint_ops:
kernel_name += "_symint"
# TODO: allow structured external backends later.
m = BackendMetadata(
kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
)
metadata[op_name] = m
return BackendIndex(
dispatch_key=dispatch_key,
use_out_as_primary=use_out_as_primary,
external=True,
device_guard=use_device_guard,
index=metadata,
)
backend_key: DispatchKey | None = None
if len(supported) > 0:
with context(
lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
):
backend_key = DispatchKey.parse(backend)
backend_idx = create_backend_index(
supported,
symint_set,
backend_key,
use_out_as_primary=use_out_as_primary,
use_device_guard=use_device_guard,
)
assert backend_key not in backend_indices
backend_indices[backend_key] = backend_idx
autograd_key: DispatchKey | None = None
if len(supported_autograd) > 0:
with context(
lambda: f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
):
autograd_key = DispatchKey.parse(f"Autograd{backend}")
autograd_idx = create_backend_index(
supported_autograd,
symint_set,
autograd_key,
use_out_as_primary=use_out_as_primary,
use_device_guard=use_device_guard,
)
assert autograd_key not in backend_indices
backend_indices[autograd_key] = autograd_idx
for g in grouped_native_functions:
if isinstance(g, NativeFunction):
forward_kernels = (
[]
if backend_key is None
else [
m
for m in [backend_indices[backend_key].get_kernel(g)]
if m is not None
]
)
backward_kernels = (
[]
if autograd_key is None
else [
m
for m in [backend_indices[autograd_key].get_kernel(g)]
if m is not None
]
)
else:
forward_kernels = (
[]
if backend_key is None
else [
m
for m in [
backend_indices[backend_key].get_kernel(f)
for f in g.functions()
]
if m is not None
]
)
backward_kernels = (
[]
if autograd_key is None
else [
m
for m in [
backend_indices[autograd_key].get_kernel(f)
for f in g.functions()
]
if m is not None
]
)
forward_kernels = [f for f in forward_kernels if f is not None]
backward_kernels = [f for f in backward_kernels if f is not None]
assert (
len(forward_kernels) == 0 or len(backward_kernels) == 0
), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
autograd key. They cannot be mixed and matched. If this is something you need, feel free to create an issue! \
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'
return ParsedExternalYaml(
backend_key, autograd_key, class_name, cpp_namespace, backend_indices
)
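# Illustrative sketch (hypothetical backend, not part of the original file): a minimal
# backend yaml that this parser accepts specifies a DispatchKey, a C++ namespace, and
# lists of supported / autograd ops, roughly:
#
#   backend: XLA
#   cpp_namespace: torch_xla
#   supported:
#   - abs
#   - add.Tensor
#   autograd:
#   - max_pool2d
#
# "backend" must parse as a valid DispatchKey, every listed op must exist in
# native_functions.yaml, and forward/backward variants of the same op may not be
# split across "supported" and "autograd".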
def error_on_missing_kernels(
native_functions: Sequence[NativeFunction],
backend_indices: dict[DispatchKey, BackendIndex],
backend_key: DispatchKey,
autograd_key: DispatchKey | None,
class_name: str,
kernel_defn_file_path: str,
full_codegen: list[OperatorName] | None = None,
) -> None:
try:
with open(kernel_defn_file_path) as f:
backend_defns = f.read()
except OSError as e:
raise AssertionError(
f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
) from e
if full_codegen is None:
full_codegen = []
indices = [backend_indices[backend_key].index] + (
[] if autograd_key is None else [backend_indices[autograd_key].index]
)
# Quick mapping from each OperatorName used by the external backend
# to its backend kernel name
expected_backend_op_names: dict[OperatorName, str] = dict(
list(
concatMap(
lambda index: [
(op_name, metadata.kernel) for op_name, metadata in index.items()
],
indices,
)
)
)
expected_backend_native_funcs: list[NativeFunction] = [
f
for f in native_functions
if f.func.name in expected_backend_op_names.keys()
and f.func.name not in full_codegen
]
expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
list
)
for native_f in expected_backend_native_funcs:
expected_backend_kernel_name_counts[
expected_backend_op_names[native_f.func.name]
].append(native_f)
# This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
# It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
# here, then we get a nicer error message. If we miss it, you get a linker error.
kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
actual_backend_kernel_name_counts = Counter(
# A bit unwieldy (this could probably be moved into regex),
# but we don't want to include kernel names that come from function calls,
# like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
# Easy check is to ignore any lines with colons before the class name.
[
y
for (x, y) in re.findall(kernel_defn_regex, backend_defns)
if not x.endswith(":")
]
)
missing_kernels_err_msg = ""
for expected_name, funcs in expected_backend_kernel_name_counts.items():
expected_overload_count = len(funcs)
actual_overload_count = actual_backend_kernel_name_counts[expected_name]
if expected_overload_count != actual_overload_count:
def create_decl(f: NativeFunction) -> str:
with native_function_manager(f):
return DispatcherSignature.from_schema(f.func).decl()
expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
missing_kernels_err_msg += f"""
{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
{expected_schemas_str}
"""
assert missing_kernels_err_msg == "", missing_kernels_err_msg
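# Illustrative sketch (hypothetical class and op names, not part of the original file):
# the kernel_defn_regex above keeps direct definitions such as
# "at::Tensor XLANativeFunctions::add(" but drops call sites like
# "return torch_xla::XLANativeFunctions::add(", whose captured prefix ends in ':'.
def _example_kernel_defn_regex() -> None:
    defns = (
        "at::Tensor XLANativeFunctions::add(const at::Tensor& a) {\n"
        "  return torch_xla::XLANativeFunctions::add(a);\n"
        "}\n"
    )
    regex = r"(.*)XLANativeFunctions::\s*([\w\d]*)\("
    kept = [y for (x, y) in re.findall(regex, defns) if not x.endswith(":")]
    assert kept == ["add"]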
def main() -> None:
parser = argparse.ArgumentParser(description="Generate backend stub files")
parser.add_argument(
"-s",
"--source-yaml",
"--source_yaml",
help="path to source yaml file containing operator external definitions",
)
parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
parser.add_argument(
"--dry-run", "--dry_run", type=bool, default=False, help="output directory"
)
parser.add_argument(
"--impl-path",
"--impl_path",
type=str,
default=None,
help="path to the source C++ file containing kernel definitions",
)
options = parser.parse_args()
run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path)
def gen_dispatchkey_nativefunc_headers(
fm: FileManager,
class_name: str,
cpp_namespace: str,
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_dispatch_key: DispatchKey,
autograd_dispatch_key: DispatchKey | None,
backend_name: str = "",
) -> None:
assert class_name is not None
generated_comment = (
"Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
)
# Convert to a set first to remove duplicate kernel names.
# Backends are allowed to repeat kernel names; only generate the declaration once!
# Sort for deterministic output.
backend_declarations = sorted(
set(
concatMap(
lambda f: dest.compute_native_function_declaration(
f, backend_indices[backend_dispatch_key]
),
grouped_native_functions,
)
)
)
autograd_declarations = sorted(
set(
concatMap(
lambda f: []
if autograd_dispatch_key is None
else dest.compute_native_function_declaration(
f, backend_indices[autograd_dispatch_key]
),
grouped_native_functions,
)
)
)
ns_helper = NamespaceHelper(cpp_namespace)
fm.write_with_template(
f"{backend_dispatch_key}NativeFunctions.h",
"DispatchKeyNativeFunctions.h",
lambda: {
"generated_comment": generated_comment,
"namespace_prologue": ns_helper.prologue,
"class_name": class_name,
"namespace_epilogue": ns_helper.epilogue,
"dispatch_declarations": backend_declarations + autograd_declarations,
"BackendName": backend_name,
"DispatchKey": backend_dispatch_key,
},
)
def gen_dispatcher_registrations(
fm: FileManager,
output_dir: str,
class_name: str,
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_dispatch_key: DispatchKey,
dispatch_key: DispatchKey,
selector: SelectiveBuilder,
# build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
build_in_tree: bool = False,
per_operator_headers: bool = False,
backend_name: str = "",
eager_registration: bool = True,
) -> None:
headers = [
f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
]
if build_in_tree:
external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers)
else:
external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers)
assert class_name is not None
backend_index = backend_indices[dispatch_key]
dispatch_registrations_body = list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=False,
symint=True,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
)
newline = "\n"
ns_helper = NamespaceHelper(namespace_str="at")
deferred_dispatch_registrations = ""
static_init_dispatch_registrations = ""
if eager_registration:
static_template = CodeTemplate(
"""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
$dispatch_registrations_body
};"""
)
static_init_dispatch_registrations = static_template.substitute(
dispatch_key=dispatch_key,
dispatch_registrations_body=dispatch_registrations_body,
)
else:
deferred_template = CodeTemplate(
"""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions();
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
$dispatch_registrations_body
}"""
)
deferred_dispatch_registrations = deferred_template.substitute(
backend_name=backend_name,
dispatch_key=dispatch_key,
dispatch_registrations_body=dispatch_registrations_body,
)
fm.write_with_template(
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
lambda: {
"extra_cuda_headers": "",
"external_backend_headers": external_backend_headers_str,
"ops_headers": "#include <ATen/Functions.h>"
if not per_operator_headers
else "",
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers=per_operator_headers, rocm=False
),
"dispatch_definitions": fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",
lambda: {
"ns_prologue": ns_helper.prologue,
"ns_epilogue": ns_helper.epilogue,
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": deferred_dispatch_registrations,
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
symint=True,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
),
},
).split(newline),
},
)
def run(
source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
pytorch_root = Path(__file__).parent.parent.absolute()
template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
def make_file_manager(install_dir: str) -> FileManager:
return FileManager(
install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
)
fm = make_file_manager(output_dir)
native_yaml_path = os.path.join(
pytorch_root, "aten/src/ATen/native/native_functions.yaml"
)
tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml")
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions, backend_indices = (
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
)
grouped_native_functions = get_grouped_native_functions(native_functions)
parsed_backend_yaml = parse_backend_yaml(
source_yaml, grouped_native_functions, backend_indices
)
backend_key = parsed_backend_yaml.backend_key
autograd_key = parsed_backend_yaml.autograd_key
cpp_namespace = parsed_backend_yaml.cpp_namespace
class_name = parsed_backend_yaml.class_name
backend_indices = parsed_backend_yaml.backend_indices
selector = SelectiveBuilder.get_nop_selector()
if backend_key is None:
# This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet.
return
if class_name is None:
# class_name is an optional argument to backend yaml file.
# if specified it allows an external backend to override
# the name of the class that all generated kernel definitions live under.
# if not specified, its value is given as native_function_class_name.
class_name = backend_indices[backend_key].native_function_class_name()
assert class_name is not None
if impl_path is not None:
error_on_missing_kernels(
native_functions,
backend_indices,
backend_key,
autograd_key,
class_name,
impl_path,
)
gen_dispatchkey_nativefunc_headers(
fm,
class_name,
cpp_namespace,
backend_indices,
grouped_native_functions,
backend_key,
autograd_key,
)
for dispatch_key in (
[backend_key] if autograd_key is None else [backend_key, autograd_key]
):
gen_dispatcher_registrations(
fm,
output_dir,
class_name,
backend_indices,
grouped_native_functions,
backend_key,
dispatch_key,
selector,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,998 @@
from __future__ import annotations
import argparse
import os
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
import yaml
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
from torchgen import dest
from torchgen.api import cpp as aten_cpp
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
from torchgen.context import (
method_with_native_function,
method_with_nested_native_function,
with_native_function_and_index,
)
from torchgen.executorch.api import et_cpp
from torchgen.executorch.api.custom_ops import (
ComputeNativeFunctionStub,
gen_custom_ops_registration,
)
from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml
from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct
from torchgen.gen import (
get_custom_build_selector,
get_native_function_declarations,
get_native_function_declarations_from_ns_grouped_kernels,
get_native_function_schema_registrations,
LineLoader,
parse_native_yaml,
)
from torchgen.model import (
BackendIndex,
BackendMetadata,
DEFAULT_KERNEL_NAMESPACE,
DispatchKey,
FunctionSchema,
Location,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
Variant,
)
from torchgen.utils import (
context,
FileManager,
make_file_manager,
mapMaybe,
NamespaceHelper,
)
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
"""
A wrapper function to basically get `sig.decl(include_context=True)`.
For ATen kernel, the codegen has no idea about ET contextArg, so we
use this wrapper to add it.
"""
if isinstance(sig, ExecutorchCppSignature):
return sig.decl()
returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type()
cpp_args = [a.decl() for a in sig.arguments()]
cpp_args_str = ", ".join([contextArg.decl()] + cpp_args)
sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})"
return sig_decl
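# Illustrative sketch (hypothetical op, not part of the original file): for an ATen
# CppSignature such as "at::Tensor add_out(const at::Tensor& self, at::Tensor& out)",
# _sig_decl_wrapper prepends the Executorch context argument, producing roughly
# "at::Tensor add_out(<contextArg.decl()>, const at::Tensor& self, at::Tensor& out)";
# ExecutorchCppSignature instances already carry the context and are returned as-is.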
def static_dispatch(
sig: CppSignature | ExecutorchCppSignature,
f: NativeFunction,
backend_indices: list[BackendIndex],
) -> str:
"""
For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
native function exists, error out. A simplified version of register_dispatch_key.py
Arguments:
sig: A CppSignature for this native function we want to use.
f: NativeFunction to generate static dispatch.
backend_indices: All available backends.
Return:
C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
"""
if len(backend_indices) == 0 or f.manual_kernel_registration:
return ""
backends = [b for b in backend_indices if b.has_kernel(f)]
static_block = None
if len(backends) == 1:
backend_metadata = backends[0].get_kernel(f)
if backend_metadata:
args = ", ".join(a.name for a in sig.arguments())
# Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
else:
static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
"""
return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
{static_block}
}}
"""
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
static_dispatch_backend_indices: list[BackendIndex]
selector: SelectiveBuilder
use_aten_lib: bool
is_custom_op: Callable[[NativeFunction], bool]
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
is_method_variant = False
if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
return None
if Variant.function not in f.variants and Variant.method in f.variants:
is_method_variant = True
# the only remaining valid case is when only the function variant is in f.variants
elif not (Variant.function in f.variants and Variant.method not in f.variants):
raise Exception( # noqa: TRY002
f"Can't handle native function {f.func} with the following variant specification {f.variants}."
)
sig: CppSignature | ExecutorchCppSignature = (
CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding
).most_faithful_signature()
if self.use_aten_lib
else ExecutorchCppSignature.from_native_function(f)
)
if self.use_aten_lib and not self.is_custom_op(f):
comma = ", "
if is_method_variant:
return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])});
}}
"""
else:
return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""
else:
return static_dispatch(
sig,
f,
backend_indices=self.static_dispatch_backend_indices,
)
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
selector: SelectiveBuilder
use_aten_lib: bool
@method_with_nested_native_function
def __call__(
self,
unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
) -> str:
f: NativeFunction = unbox_kernel_entry[0]
kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]
op_name = f"{f.namespace}::{f.func.name}"
if not self.selector.is_root_operator(op_name):
return ""
if not isinstance(kernel_key, list):
kernel_key = [kernel_key]
used_kernel_keys = self.selector.et_get_selected_kernels(
op_name, [k.to_native_string() for k in kernel_key]
)
if not used_kernel_keys:
return ""
sig: CppSignature | ExecutorchCppSignature
argument_type_gen: Callable[..., NamedCType]
return_type_gen: Callable[..., CType]
if self.use_aten_lib:
sig = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding
).most_faithful_signature()
argument_type_gen = aten_cpp.argumenttype_type
return_type_gen = aten_cpp.returns_type
arguments = sig.arguments()
kernel_call = f"torch::executor::{f.namespace}::{sig.name()}"
else:
sig = ExecutorchCppSignature.from_native_function(f)
argument_type_gen = et_cpp.argumenttype_type
return_type_gen = et_cpp.returns_type
arguments = sig.arguments(include_context=False)
kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}"
# parse arguments into C++ code
binding_list, code_list = Unboxing(
argument_type_gen=argument_type_gen
).convert_arguments(arguments)
# for each C++ argument, generate the conversion code
code_connector = "\n\t"
arg_connector = ", "
args_str = f"{arg_connector.join(e.name for e in binding_list)}"
event_tracer_output_logging = ""
output_ids = []
if len(f.func.returns) == 0:
if len(f.func.arguments.out) == 0:
raise Exception( # noqa: TRY002
f"Can't handle native function {f.func} with no returns and no out yet."
)
out = f.func.arguments.out[0]
return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
ret_prefix = ""
output_ids = [len(binding_list)]
else:
if len(f.func.arguments.out) == 0:
return_assignment = (
f"""*stack[{len(binding_list)}] = EValue(result_);"""
)
ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
output_ids = [len(binding_list)]
else:
return_assignment = ""
ret_prefix = ""
output_ids = [
len(binding_list) - (i + 1)
for i in reversed(range(len(f.func.arguments.out)))
]
for output_id in output_ids:
event_tracer_output_logging += (
f"internal::event_tracer_log_evalue("
f"context.internal_event_tracer(), "
f"*stack[{output_id}]);\n"
)
newline = "\n "
return "\n".join(
[
f"""
Kernel(
"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''}
[]({contextArg.defn()}, EValue** stack) {{
{code_connector.join(code_list)}
internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_{f.func.name}");
EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
{ret_prefix}{kernel_call}(context, {args_str});
{event_tracer_output_logging}
{return_assignment}
}}
),
"""
for k in used_kernel_keys
]
)
def gen_unboxing(
*,
native_functions: Sequence[NativeFunction],
cpu_fm: FileManager,
selector: SelectiveBuilder,
use_aten_lib: bool,
kernel_index: ETKernelIndex,
manual_registration: bool,
) -> None:
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
def key_func(
item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
) -> str:
return item[0].root_name + ":" + item[1][0].to_native_string()
items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
(native_function, (kernel_key, metadata))
for native_function in native_functions
for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
]
header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"]
filename = (
"RegisterKernels.cpp"
if manual_registration
else "RegisterCodegenUnboxedKernels.cpp"
)
cpu_fm.write_sharded(
filename,
items,
key_fn=key_func,
env_callable=lambda unbox_kernel_entry: {
"unboxed_kernels": [
ComputeCodegenUnboxedKernels(selector, use_aten_lib)(unbox_kernel_entry)
],
"fn_header": header
if unbox_kernel_entry == items[0]
else [], # Only write header once
},
num_shards=1,
sharded_keys={"unboxed_kernels", "fn_header"},
)
@with_native_function_and_index # type: ignore[arg-type]
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
assert isinstance(g, NativeFunction)
sig = ExecutorchCppSignature.from_native_function(f=g)
metadata_list = kernel_index.get_kernels(g).values()
if metadata_list is None:
return []
# for kernels in lean mode, we declare two versions, one with context and one without.
# In the end we will clean up the unused one.
def gen_decl(metadata: BackendMetadata, include_context: bool) -> str:
return f"{sig.decl(name=metadata.kernel, include_context=include_context)};"
return [
gen_decl(metadata, include_context)
for include_context in [False, True]
for metadata in metadata_list
]
def gen_functions_declarations(
*,
native_functions: Sequence[NativeFunction],
kernel_index: ETKernelIndex,
selector: SelectiveBuilder,
use_aten_lib: bool,
custom_ops_native_functions: Sequence[NativeFunction] | None = None,
) -> str:
"""
Generates namespace separated C++ function API inline declaration/definitions.
Native functions are grouped by namespaces and the generated code is wrapped inside
namespace blocks.
E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol
in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when
the other `custom_2::foo.out` is available.
"""
# convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
# TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
backend_index = kernel_index._to_backend_index()
ns_grouped_functions = defaultdict(list)
for native_function in native_functions:
ns_grouped_functions[native_function.namespace].append(native_function)
functions_declarations = ""
newline = "\n"
for namespace in ns_grouped_functions:
ns_helper = NamespaceHelper(
namespace_str=namespace,
entity_name="",
max_level=3,
)
declarations = list(
mapMaybe(
ComputeFunction(
static_dispatch_backend_indices=[backend_index],
selector=selector,
use_aten_lib=use_aten_lib,
is_custom_op=lambda f: custom_ops_native_functions is not None
and f in custom_ops_native_functions,
),
ns_grouped_functions[namespace],
)
)
functions_declarations += f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
"""
return functions_declarations
def get_ns_grouped_kernels(
*,
native_functions: Sequence[NativeFunction],
kernel_index: ETKernelIndex,
native_function_decl_gen: Callable[
[
NativeFunctionsGroup | NativeFunction,
ETKernelIndex,
],
list[str],
],
) -> dict[str, list[str]]:
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
for f in native_functions:
native_function_namespaces = set()
op_kernels = kernel_index.get_kernels(f)
for backend_metadata in op_kernels.values():
if backend_metadata:
namespace = backend_metadata.cpp_namespace
native_function_namespaces.add(namespace)
else:
namespace = DEFAULT_KERNEL_NAMESPACE
assert (
len(native_function_namespaces) <= 1
), f"Codegen only supports one namespace per operator, got {native_function_namespaces}"
ns_grouped_kernels[namespace].extend(
native_function_decl_gen(f, kernel_index)
)
return ns_grouped_kernels
def gen_headers(
*,
native_functions: Sequence[NativeFunction],
gen_custom_ops_header: bool,
custom_ops_native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
cpu_fm: FileManager,
use_aten_lib: bool,
) -> None:
"""Generate headers.
Args:
native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
kernel_index (ETKernelIndex): kernel collection
cpu_fm (FileManager): file manager manages output stream
use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
"""
aten_headers = ["#include <ATen/Functions.h>"]
backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()}
if gen_custom_ops_header:
cpu_fm.write_with_template(
"CustomOpsNativeFunctions.h",
"NativeFunctions.h",
lambda: {
"nativeFunctions_declarations": get_native_function_declarations(
grouped_native_functions=custom_ops_native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
"headers": [
"#include <ATen/ATen.h>",
"#include <torch/torch.h>",
],
},
)
aten_headers.append('#include "CustomOpsNativeFunctions.h"')
cpu_fm.write(
"Functions.h",
lambda: {
"static_dispatch_extra_headers": aten_headers
if use_aten_lib
else ['#include "NativeFunctions.h"'],
"Functions_declarations": gen_functions_declarations(
native_functions=native_functions,
kernel_index=kernel_index,
selector=selector,
use_aten_lib=use_aten_lib,
custom_ops_native_functions=custom_ops_native_functions,
),
},
)
cpu_fm.write(
"RegisterKernels.h",
lambda: {
"generated_comment": "@" + "generated by torchgen/gen_executorch.py",
},
)
headers = {
"headers": [
"#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
"#include <executorch/runtime/kernel/kernel_runtime_context.h>",
],
}
if use_aten_lib:
headers["headers"].append("#include <executorch/codegen/macros.h> // TORCH_API")
cpu_fm.write(
"NativeFunctions.h",
lambda: dict(
{
"nativeFunctions_declarations": get_native_function_declarations(
grouped_native_functions=native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
),
},
**headers,
),
)
else:
ns_grouped_kernels = get_ns_grouped_kernels(
native_functions=native_functions,
kernel_index=kernel_index,
native_function_decl_gen=compute_native_function_declaration, # type: ignore[arg-type]
)
cpu_fm.write(
"NativeFunctions.h",
lambda: dict(
{
"nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
ns_grouped_kernels=ns_grouped_kernels,
),
},
**headers,
),
)
def gen_custom_ops(
*,
native_functions: Sequence[NativeFunction],
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
cpu_fm: FileManager,
rocm: bool,
) -> None:
dispatch_key = DispatchKey.CPU
(
anonymous_definition,
static_init_dispatch_registrations,
) = gen_custom_ops_registration(
native_functions=native_functions,
selector=selector,
kernel_index=kernel_index,
rocm=rocm,
)
cpu_fm.write_with_template(
f"Register{dispatch_key}CustomOps.cpp",
"RegisterDispatchKeyCustomOps.cpp",
lambda: {
"ops_headers": '#include "CustomOpsNativeFunctions.h"',
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": anonymous_definition,
"static_init_dispatch_registrations": static_init_dispatch_registrations,
},
)
cpu_fm.write_with_template(
f"Register{dispatch_key}Stub.cpp",
"RegisterDispatchKeyCustomOps.cpp",
lambda: {
"ops_headers": "",
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
mapMaybe(ComputeNativeFunctionStub(), native_functions)
),
"static_init_dispatch_registrations": static_init_dispatch_registrations,
},
)
(
aten_schema_registrations,
schema_registrations,
) = get_native_function_schema_registrations(
native_functions=native_functions,
schema_selector=selector,
)
cpu_fm.write(
"RegisterSchema.cpp",
lambda: {
"schema_registrations": schema_registrations,
"aten_schema_registrations": aten_schema_registrations,
},
)
def translate_native_yaml(
tags_yaml_path: str,
aten_yaml_path: str,
native_yaml_path: str | None,
use_aten_lib: bool,
out_file: TextIO,
) -> None:
"""Translates Executorch DSL dialect to use the same syntax as
native_functions.yaml. The major difference is that Executorch DSL dialect
supports "op" key, where it refers to the operator name in native_functions.yaml.
For example, a functions.yaml may have the following entry:
- op: add.out
...
It needs to be translated to the following:
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
...
We go in aten_yaml_path and find the operator schema for "add.out" and add it
to the original functions.yaml. We also add required field "variants", where for
Executorch it will always be "function".
For ATen mode we don't have to do the translation because native_yaml_path is
the same as native_functions.yaml.
Args:
tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
It is not optional.
aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
native_yaml_path: Path to a functions.yaml file to parse.
If the path does not exist in the filesystem, it is treated as an
empty file. If `custom_ops_yaml_path` exists, the contents of that
file are appended to the yaml input to be parsed.
use_aten_lib: We use this flag to determine if we want to generate native
functions. In ATen mode we should generate out= variants.
out_file: The IO object that we are writing into.
Returns:
None
"""
if use_aten_lib:
with open(aten_yaml_path) as aten_yaml:
out_file.writelines(aten_yaml.readlines())
return
native_functions, persisted_fields = parse_et_yaml(
aten_yaml_path,
tags_yaml_path,
None,
skip_native_fns_gen=False,
)
func_to_scoped_name: dict[FunctionSchema, str] = {
f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
}
op_to_scoped_name: dict[OperatorName, str] = {
func.name: name for func, name in func_to_scoped_name.items()
}
schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
kernel_persist_dict: dict[str, dict[str, Any]] = {
op_to_scoped_name[op]: v for op, v in persisted_fields.items()
}
if (
not native_yaml_path
or not os.path.exists(native_yaml_path)
or os.stat(native_yaml_path).st_size == 0
):
return
with open(native_yaml_path) as native_yaml:
native_es = yaml.load(native_yaml, Loader=LineLoader)
if not native_es:
return
for e in native_es:
assert isinstance(e.get("__line__"), int), e
loc = Location(native_yaml_path, e.pop("__line__"))
with context(lambda: f"in {loc}:\n "):
if "variants" not in e:
e["variants"] = "function"
if "func" in e:
continue
assert isinstance(e.get("op"), str), e
opname = e.pop("op")
if "::" not in opname:
opname = "aten::" + opname
assert opname in schema_dict
e["func"] = schema_dict.get(opname)
# Write out persisted kernel information
if opname in kernel_persist_dict:
for k, v in kernel_persist_dict[opname].items():
e[k] = v
yaml.dump(native_es, out_file, width=1000)
def parse_yaml(
path: str | None,
tags_yaml_path: str,
function_filter: Callable[[NativeFunction], bool],
skip_native_fns_gen: bool = False,
) -> tuple[
list[NativeFunction],
dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
]:
if path and os.path.exists(path) and os.stat(path).st_size > 0:
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)
# Check for kernel index structure
kernel_index = (
parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None
)
# Remove ET specific fields from entries for BC compatibility
for entry in es:
for field in ET_FIELDS:
entry.pop(field, None)
parsed_yaml = parse_native_yaml(
path,
tags_yaml_path,
None,
skip_native_fns_gen=skip_native_fns_gen,
loaded_yaml=es,
)
native_functions = list(filter(function_filter, parsed_yaml.native_functions))
op_names = [f.func.name for f in native_functions]
# (1) Return ETKernelIndex if kernel index is present
if kernel_index is not None:
filtered_index = {
op_name: kernel_mapping
for op_name, kernel_mapping in kernel_index.index.items()
if op_name in op_names
}
return native_functions, ETKernelIndex(index=filtered_index)
# (2) Return BackendIndices if kernel index is absent
def map_index(
m: dict[OperatorName, BackendMetadata]
) -> dict[OperatorName, BackendMetadata]:
return {op: m[op] for op in m if op in op_names}
backend_indices = {
k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items()
}
return native_functions, backend_indices
else:
return [], {}
def parse_yaml_files(
tags_yaml_path: str,
aten_yaml_path: str,
native_yaml_path: str | None,
custom_ops_yaml_path: str | None,
selector: SelectiveBuilder,
use_aten_lib: bool,
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
"""Parses functions.yaml and custom_ops.yaml files.
Args:
tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
It is not optional.
aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
native_yaml_path: Path to a functions.yaml file to parse.
If the path does not exist in the filesystem, it is treated as an
empty file. If `custom_ops_yaml_path` exists, the contents of that
file are appended to the yaml input to be parsed.
custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
the path does not exist in the filesystem, it is ignored.
selector: For selective build.
use_aten_lib: We use this flag to determine if we want to generate native
functions. In ATen mode we should generate out= variants.
Returns:
A tuple with two elements:
[0]: The parsed results of concatenating the contents of
`native_yaml_path` and `custom_ops_yaml_path`.
[1]: The parsed results of the contents of `custom_ops_yaml_path`, if
present. If not present, None.
"""
import tempfile
# only include selected ops; this avoids generating code for operators that were not selected
def function_filter(f: NativeFunction) -> bool:
return selector.is_native_function_selected(f)
with tempfile.TemporaryDirectory() as tmpdirname:
translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
with open(translated_yaml_path, "w") as translated:
translate_native_yaml(
tags_yaml_path,
aten_yaml_path,
native_yaml_path,
use_aten_lib,
translated,
)
translated_functions, translated_indices = parse_yaml(
translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
)
custom_ops_functions, custom_ops_indices = parse_yaml(
custom_ops_yaml_path, tags_yaml_path, function_filter, True
)
# Convert BackendIndices to ETKernelIndex
if not isinstance(translated_indices, ETKernelIndex):
translated_indices = ETKernelIndex.from_backend_indices(translated_indices)
if not isinstance(custom_ops_indices, ETKernelIndex):
custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices)
combined_functions = translated_functions + custom_ops_functions
combined_kernel_index = ETKernelIndex.merge_indices(
translated_indices, custom_ops_indices
)
combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index)
custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices)
return combined_yaml, custom_ops_parsed_yaml
def main() -> None:
parser = argparse.ArgumentParser(description="Generate operator source files")
# Although we don't refer to --source-path directly, make_file_manager()
# expects it to point to a directory that contains a templates/ subdirectory
# containing the file templates.
parser.add_argument(
"-s",
"--source-path",
help="path to source directory for kernel templates",
)
parser.add_argument(
"--functions-yaml-path",
"--functions_yaml_path",
help="path to the functions.yaml file to use. Optional, but at least "
"one of --functions-yaml-path and --custom-ops-yaml-path must be "
"specified.",
)
parser.add_argument(
"--custom-ops-yaml-path",
"--custom_ops_yaml_path",
help="path to the custom_ops.yaml file to use. Optional, but at least "
"one of --functions-yaml-path and --custom-ops-yaml-path must be "
"specified.",
)
parser.add_argument(
"--aten-yaml-path",
"--aten_yaml_path",
help="path to native_functions.yaml file.",
)
# Note that make_file_manager() also looks at --install-dir.
parser.add_argument(
"-d",
"--install-dir",
"--install_dir",
help="output directory",
default="build/generated",
)
parser.add_argument(
"-o",
"--output-dependencies",
help="output a list of dependencies into the given file and exit",
)
# Although we don't refer to --dry-run directly, make_file_manager() looks
# for it.
parser.add_argument(
"--dry-run",
action="store_true",
help="run without writing any files (still updates outputs)",
)
parser.add_argument(
"--static-dispatch-backend",
"--static_dispatch_backend",
nargs="*",
help="generate static dispatch code for the specific backend (if set)",
)
parser.add_argument(
"--op-registration-whitelist",
"--op_registration_whitelist",
nargs="*",
help="filter op registrations by the whitelist (if set); "
"each item is `namespace`::`operator name` without overload name; "
"e.g.: aten::empty aten::conv2d ...",
)
parser.add_argument(
"--op-selection-yaml-path",
"--op_selection_yaml_path",
help="Provide a path to the operator selection (for custom build) YAML "
"that contains the information about the set of selected operators "
"and their categories (training, ...). Each operator is either a "
"full operator name with overload or just a bare operator name. "
"The operator names also contain the namespace prefix (e.g. aten::)",
)
parser.add_argument(
"--tags-path",
help="Path to tags.yaml. Required by yaml parsing in codegen system.",
)
parser.add_argument(
"--rocm",
action="store_true",
help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
)
parser.add_argument(
"--use-aten-lib",
"--use_aten_lib",
action="store_true",
help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per "
"operator",
)
parser.add_argument(
"--manual_registration",
"--manual-registration",
action="store_true",
help="a boolean flag to indicate whether we want to manually call"
"register_kernels() or rely on static init. ",
)
parser.add_argument(
"--generate",
type=str,
nargs="*",
choices=["headers", "sources"],
default=["headers", "sources"],
help="Generate only a subset of files",
)
options = parser.parse_args()
assert options.tags_path, "tags.yaml is required by codegen yaml parsing."
selector = get_custom_build_selector(
options.op_registration_whitelist,
options.op_selection_yaml_path,
)
parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
aten_yaml_path=options.aten_yaml_path,
tags_yaml_path=options.tags_path,
native_yaml_path=options.functions_yaml_path,
custom_ops_yaml_path=options.custom_ops_yaml_path,
selector=selector,
use_aten_lib=options.use_aten_lib,
)
native_functions, kernel_index = (
parsed_yaml.native_functions,
parsed_yaml.kernel_index,
)
custom_ops_native_functions = (
custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
)
cpu_fm = make_file_manager(options=options)
if "headers" in options.generate:
# generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system.
gen_headers(
native_functions=native_functions,
gen_custom_ops_header=options.custom_ops_yaml_path,
custom_ops_native_functions=custom_ops_native_functions,
selector=selector,
kernel_index=kernel_index,
cpu_fm=cpu_fm,
use_aten_lib=options.use_aten_lib,
)
if "sources" in options.generate:
gen_unboxing(
native_functions=native_functions,
cpu_fm=cpu_fm,
selector=selector,
use_aten_lib=options.use_aten_lib,
kernel_index=kernel_index,
manual_registration=options.manual_registration,
)
if custom_ops_native_functions:
gen_custom_ops(
native_functions=custom_ops_native_functions,
selector=selector,
kernel_index=kernel_index,
cpu_fm=cpu_fm,
rocm=options.rocm,
)
if options.output_dependencies:
depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem
for fm, prefix in [
(cpu_fm, ""),
]:
varname = prefix + depfile_stem
path = depfile_path.parent / (prefix + depfile_name)
fm.write_outputs(varname, str(path))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,882 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING
from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
DispatcherSignature,
FunctionalizationLambda,
iTensorListRefT,
NativeSignature,
OptionalCType,
optionalSymIntArrayRefT,
symIntArrayRefT,
SymIntT,
tensorListT,
tensorT,
VectorCType,
ViewInverseSignature,
)
from torchgen.context import (
method_with_native_function,
native_function_manager,
with_native_function,
with_native_function_and,
)
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
Return,
SchemaKind,
SelfArgument,
TensorOptionsArguments,
)
from torchgen.native_function_generation import (
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import dataclass_repr
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+ [
# It will be BC-breaking, but we should fix their schemas.
# should be inplace?
"record_stream",
# See Note [resize_ in Functionalization]
"resize_",
"resize_as_",
# This function is used for testing purposes only.
"_fill_mem_eff_dropout_mask_",
]
)
# This file contains codegen that relates to the functionalization pass.
# It includes:
# - gen_functionalization_definition
# Generates dispatcher kernel definitions for the functionalization pass.
# - gen_functionalization_registration
# Generates dispatcher kernel registrations for the functionalization pass.
# - gen_functionalization_view_inverse_declaration
# Generates a declaration for an "inverse view", for every view op
# that is needed in functionalization. We manually implement their definitions.
# - gen_composite_view_copy_kernel
# Generates view_copy() composite kernels for all view_copy operators.
# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
# See Note [view_copy NativeFunctions]
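# As a rough illustration of the naming scheme handled below (schemas paraphrased, not
# quoted verbatim from native_functions.yaml):
#   view op:       transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
#   view_copy op:  transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
# The generated composite kernel for the _copy variant calls the view and clones the result.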
@dataclass(frozen=True)
class GenCompositeViewCopyKernel:
backend_index: BackendIndex
@method_with_native_function
def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
if g.view_copy is None:
return None
elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
# If the view_copy doesn't match the standard naming scheme of <op>_copy,
# assume it already exists and doesn't need to be generated.
# Example: slice_inverse() with the copy variant named slice_scatter()
# instead of slice_inverse_copy()
return None
metadata = self.backend_index.get_kernel(g.view_copy)
assert metadata is not None
# We can make view_copy work in more cases by using reshape()
# when a normal view call would ordinarily fail.
# This also makes LTC more efficient, because they don't need to include
# clone() calls in their graph (which is normally needed by reshape).
if str(g.view_copy.func.name) == "view_copy":
assert metadata.kernel == "view_copy_symint"
return """\
at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
return self.reshape_symint(size);
} else {
auto output = at::_ops::view::call(self, size);
return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
}
}
"""
# view_copy is a native signature, since we're generating an at::native:: kernel
# Functionalization always operates on symints though
view_copy_sig = NativeSignature(
g.view_copy.func, symint=metadata.supports_symint()
)
# view is a dispatcher signature, since we're calling into the at::_ops API
view_sig = DispatcherSignature(g.view.func)
view_api_name = g.view.func.name.unambiguous_name()
exprs = ", ".join(
[e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
)
# view ops today always return either a Tensor or a list of Tensors
assert len(g.view.func.returns) == 1
assert g.view.func.returns[0].type == BaseType(
BaseTy.Tensor
) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)
if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
return_cloned_output = """\
return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
else:
# If the return type is a list, we need to clone each tensor in the list.
return_cloned_output = f"""\
{view_copy_sig.returns_type().cpp_type()} out_clone;
for (const auto i : c10::irange(output.size())) {{
out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
}}
return out_clone;"""
# The default generated composite kernel for {view}_copy() operators just clones
# the input tensor, and runs the underlying view on the clone.
return f"""
{view_copy_sig.defn(name=metadata.kernel)} {{
auto output = at::_ops::{view_api_name}::call({exprs});
{return_cloned_output}
}}
"""
def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
elif len(rets) == 1:
return f"return {names[0]};"
else:
return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
def modifies_arguments(f: NativeFunction) -> bool:
return any(
a.annotation is not None and a.annotation.is_write
for a in f.func.arguments.flat_all
)
def wrapper_name(func: FunctionSchema) -> str:
if func.name.overload_name:
return f"{cpp.name(func)}_{func.name.overload_name}"
else:
return cpp.name(func)
def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
return isinstance(a, SelfArgument) or (
isinstance(a, Argument) and a.type.is_tensor_like()
)
# We need to wrap / unwrap various arguments from the op in the functionalization kernels.
# Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type!
# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
if t == BaseCType(tensorListT):
return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
if t == BaseCType(iTensorListRefT):
return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
# There are technically other non-owning types out there (like IntArrayRef),
# but functionalization only actually cares about the ones involving tensors.
return t, lambda x: x
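# For example, with the CTypes above, get_owning_type(BaseCType(tensorListT)) returns
# (VectorCType(BaseCType(tensorT)), fn) where fn("xs") == "xs.vec()", i.e. a TensorList
# argument named "xs" is materialized as an owning std::vector<at::Tensor> via "xs.vec()".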
# unwraps all tensor-like arguments, returning:
# (1) a string containing all of the logic that does the unwrapping
# (2) a context, to be used by translate(), with all of the relevant bindings.
def unwrap_tensor_args(
sig: DispatcherSignature, *, is_view_op: bool
) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
unwrapped_name = f"{arg.name}_"
# For most ops, the functionalization needs to sync any pending updates on the input tensors
# before calling the operator, since otherwise the operator will act on stale data.
# For view ops though, we can continue to defer syncing until the tensor is used by
# a non-view operator.
maybe_sync_input = (
"" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
)
unwrapped_type, conversion_fn = get_owning_type(
arg.nctype.remove_const_ref().type
)
unwrapped_tensor_args.append(
f"""
{unwrapped_type.cpp_type()} {unwrapped_name};
if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
{maybe_sync_input}
{unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
}} else {{
{unwrapped_name} = {conversion_fn(arg.name)};
}}"""
)
context.append(arg.with_name(unwrapped_name))
else:
# for non-tensor inputs, we want to pass them directly into the redispatch calls.
context.append(arg)
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
return unwrap_tensor_args_str, context
# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
a_ = arg.name
unwrapped_name = f"{arg.name}_meta"
unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
context.append(arg.with_name(unwrapped_name))
else:
# for non-tensor inputs, we want to pass them directly into the redispatch calls.
context.append(arg)
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
return unwrap_tensor_args_str, context
# The functionalization codegen currently expects view op schemas to have this form:
# foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
# foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
def assert_view_op_properties(func: FunctionSchema) -> None:
def is_alias(a: Argument) -> bool:
return a.annotation is not None
args = func.arguments.flat_non_out
# The first argument is a tensor with alias semantics (annotations)
assert len(args) > 0 and args[0].type == BaseType(
BaseTy.Tensor
), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
# No other arguments have aliasing semantics
assert is_alias(args[0]) and not any(
is_alias(a) for a in args[1:]
), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint."""
# One-liner expression for checking if an expression expr of type type has any
# symbolic values.
def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
if type == BaseCType(SymIntT):
return f"{expr}.is_symbolic()"
if isinstance(type, OptionalCType):
innerexpr = f"(*{expr})"
return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"
if type == BaseCType(optionalSymIntArrayRefT):
return emit_expr_has_symbolic_values(
expr, OptionalCType(BaseCType(symIntArrayRefT))
)
if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
argname = "arg"
lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
return (
"std::any_of("
f"{expr}.begin(), {expr}.end(), "
f"[=](auto& {argname}) {{ return {lambda_check}; }})"
)
raise ValueError(
"unsupported type for has_symbolic_values check. "
"It should be a SymInt or a collection of those. "
f"Got: {type.cpp_type()}"
)
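# Sketch of the emitted C++ check for a hypothetical SymIntArrayRef argument named "size":
#   std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); })
# and for a hypothetical optional SymInt named "storage_offset":
#   storage_offset.has_value() ? (*storage_offset).is_symbolic() : false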
# Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta.
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
name = "has_symbolic_inputs"
statements = [
f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
for binding in sig.arguments()
if (
isinstance(binding.argument, Argument)
and binding.argument.type.is_symint_like()
)
]
body = "\n ".join(statements)
return (
name,
f"""
bool {name} = false;
{body}""",
)
# Generates the Functionalization kernel for:
# - ops that create aliases (e.g. transpose())
# - ops that are views AND mutations (e.g. transpose_())
def emit_view_functionalization_body(
g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
if view_inplace:
# This op is both an inplace op AND a view op.
# See Note [Functionalization Pass - Inplace View Ops] for details.
# I currently have the view meta call into the out-of-place variant of the view, to avoid
# having to define an extra ~20 inplace {view}_inverse_ functions.
# Most view ops don't belong to a NativeFunctionsGroup, because we don't define out= variants for view ops.
# I'm assuming that every inplace-view op has a corresponding out-of-place view op,
# with the same name but the trailing underscore removed.
# This is currently asserted at parse time in gen.py (see error_check_native_functions).
assert g.view_inplace is not None
f = g.view_inplace
else:
f = g.view
assert g.view_copy is not None
with native_function_manager(f):
call_sig = DispatcherSignature.from_schema(g.view_copy.func)
# the "view_copy" op name that the functionalization kernels need to call
api_name = g.view_copy.func.name.unambiguous_name()
# Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
# "no-op"ing in this context is just redispatching to the original op.
noop_api_name = f.func.name.unambiguous_name()
dispatcher_sig = DispatcherSignature.from_schema(f.func)
assert_view_op_properties(f.func)
view_tensor_name = dispatcher_sig.arguments()[0].name
return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()
unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
dispatcher_sig, is_view_op=True
)
view_redispatch_args = [
e.expr
for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
]
forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
# The meta API call should use the same arguments, but convert all tensors to meta tensors first.
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
]
(
symbolic_inputs_varname,
symbolic_inputs_check,
) = emit_has_symbolic_inputs(call_sig)
if "inplace_view" in f.tags:
# See Note [Functionalization Pass - Inplace View Ops] for more details
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
// functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
{unwrap_tensor_args_str}
at::AutoDispatchSkipFunctionalize guard;
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto inverse_return_mode = (
reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
: at::functionalization::InverseReturnMode::NeverView
);
{symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname}
);
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
if (compute_reference_meta) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
// This function adds the above view meta to the current tensor and replays them off the base,
// mutating the size/stride info of the current FunctionalTensorWrapper.
// Because of this, we need to make sure to run the reference shape function above,
// BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]
// XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
// on a reference implementation here (instead of relying on the output from the forward lambda
// having the correct stride info)
if (compute_reference_meta) {{
at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
}}
return {view_tensor_name};
}}
"""
else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
{unwrap_tensor_args_str}
if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
// functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
at::AutoDispatchSkipFunctionalize guard;
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto inverse_return_mode = (
reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
: at::functionalization::InverseReturnMode::NeverView
);
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
if (compute_reference_meta) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
}}
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
if (reapply_views) {{
tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}} else {{
tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
}}
}}
{symbolic_inputs_check}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
return {forward_lambda.inner_call(reapply_views=True)}
}} else {{
return {forward_lambda.inner_call(reapply_views=False)}
}}
}},
{reverse_lambda.decl()} {{
return {reverse_lambda.inner_call()}
}},
/*has_symbolic_inputs=*/{symbolic_inputs_varname},
/*is_multi_output=*/{str(is_multi_output_view).lower()},
/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]
if (compute_reference_meta) {{
at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
}}
return out;
}}
"""
def maybe_create_output(f: NativeFunction, var_name: str) -> str:
if len(f.func.returns) == 0:
return ""
return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
return f"{return_type} {var_name} = "
# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
# this returns two lists of names, consisting of:
# - the names of returns corresponding to the original (mutable) inputs of the outer function
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str
) -> tuple[list[str], list[str]]:
aliased_returns = []
non_aliased_returns = []
for i, name in enumerate(f.func.aliased_return_names()):
if name is not None:
aliased_returns.append(name)
else:
non_aliased_returns.append(
inner_return_var
if len(f.func.returns) == 1
else f"std::get<{i}>({inner_return_var})"
)
return aliased_returns, non_aliased_returns
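# Hedged example: for an inplace op whose schema looks roughly like
# add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!),
# the single return aliases "self", so this yields (["self"], []); for its functional
# counterpart add.Tensor(...) -> Tensor nothing aliases an input, so it yields ([], [inner_return_var]).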
# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
# - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
def return_from_mutable_noop_redispatch(
f: NativeFunction, inner_return_var: str
) -> str:
aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
# Just get all of the return names, and immediately return them
return return_str(f.func.returns, aliased + non_aliased)
def wrap_propagate_mutations_and_return(
f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
) -> str:
mutable_arg_names = f.func.arguments.mutable_arg_names()
(
aliased_outer_rets,
non_aliased_outer_rets,
) = get_mutable_redispatch_return_names(f, inner_return_var)
_, non_aliased_inner_rets = get_mutable_redispatch_return_names(
functional_op, inner_return_var
)
# The outer function may have a mix of aliased and non-aliased outputs,
# but the inner functional op that we're transforming to should only have non-aliased outputs
assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
non_aliased_inner_rets
)
# First, take all of the newly created outputs from the inner call and wrap them into functional tensors
updates = []
non_aliased_wrapped_ret_names = []
for i, inner_ret in enumerate(
non_aliased_inner_rets[: len(non_aliased_outer_rets)]
):
ret_name = f"output_{i}"
updates.append(
f"""\
auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
)
non_aliased_wrapped_ret_names.append(ret_name)
# Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
# and propagate the mutations
for outer_arg, inner_ret in zip(
mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
):
updates.append(
f"""\
auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});
at::functionalization::impl::sync({outer_arg});
auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
)
# Finally, we return:
# - Any mutable arguments that are also returned
# - Any immutable returns that were created wrapping the output from the inner call
returns_str = return_str(
f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
)
updates_str = "\n".join(updates)
return f"""\
{updates_str}
{returns_str}"""
# Generates the Functionalization kernel for:
# - mutation ops (inplace and out= ops)
@with_native_function_and
def emit_inplace_functionalization_body(
f: NativeFunction, g: NativeFunctionsGroup
) -> str:
# mutation case
assert modifies_arguments(f)
dispatcher_sig = DispatcherSignature.from_schema(f.func)
unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
dispatcher_sig, is_view_op=False
)
mutated_names = [
a.name
for a in f.func.arguments.flat_all
if a.type.is_tensor_like() and a.annotation is not None
]
non_mutated_names = [
a.name
for a in f.func.arguments.flat_all
if a.type.is_tensor_like() and a.annotation is None
]
non_mutated_tensor_names = [
a.name
for a in f.func.arguments.flat_all
if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
]
# all mutable inputs must be functional tensors in order to participate in functionalization
check_all_mutated_args_are_functional = " && ".join(
["true"]
+ [
f"at::functionalization::impl::isFunctionalTensor({a})"
for a in mutated_names
]
)
check_any_non_mutated_args_are_functional = " || ".join(
["false"]
+ [
f"at::functionalization::impl::isFunctionalTensor({a})"
for a in non_mutated_names
]
)
check_any_non_mutated_tensors_are_xla = " || ".join(
["false"]
+ [
f"{a}.device().type() == c10::DeviceType::XLA"
for a in non_mutated_tensor_names
]
)
# These are used in the cases where we don't functionalize and redispatch to the inplace op
# case 1: we hit an inplace op that doesn't have an out-of-place equivalent
# case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
inplace_exprs = [
e.expr
for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
]
# call the out-of-place variant of the op
return_type = (
dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
)
functional_sig = DispatcherSignature.from_schema(g.functional.func)
functional_exprs = [
e.expr
for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
]
if f.func.is_out_fn():
mutable_input_post_processing = "\n".join(
[
f"""
at::functionalization::impl::replace_(
{a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
at::functionalization::impl::commit_update({a.name});"""
for (i, a) in enumerate(f.func.arguments.out)
if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
]
)
else:
mutable_input_post_processing = "\n".join(
[
f"""
at::functionalization::impl::replace_({a.name}, tmp_output);
at::functionalization::impl::commit_update({a.name});"""
for a in f.func.arguments.flat_all
if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
]
)
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
# We don't want to run the inplace meta func for ops like .set_(), because:
# (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
# where broadcasting will work for the out-of-place case but should fail on the inplace call
# (2) They'll also fail without adding extra infra: we'd need to convert the input storage argument
# into a meta storage
any_storage_args = any(
a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
)
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technically all support meta tensors).
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
}}
{unwrap_tensor_args_str}
if (!({check_all_mutated_args_are_functional})) {{
// We want to disable this check if there are any XLA tensors.
// cpu_tensor.copy_(xla_tensor) is valid code.
if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
// case 1: trying to mutate a non functional tensor with a functional tensor is an error
TORCH_INTERNAL_ASSERT(false,
"mutating a non-functional tensor with a functional tensor is not allowed.",
" Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
}} else {{
// case 2: arguments are not functional tensors, so we no-op and redispatch.
at::AutoDispatchSkipFunctionalize guard;
{maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
{return_from_mutable_noop_redispatch(f, 'tmp_output')}
}}
}} else {{
{return_type} tmp_output;
{{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
}}
{wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
}}
}}"""
# The below functions generate RegisterFunctionalization.cpp
# These files provide the kernels that run the functionalization pass, which can be opted into
# per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).
# See Note [Functionalization Pass: View Inverses].
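# A registration emitted by the helper below looks roughly like (op name hypothetical):
#   m.impl("my_view_", TORCH_FN(functionalization::my_view_));
# where the wrapper name comes from wrapper_name() and includes the overload name, if any.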
def gen_functionalization_view_inverse_declaration(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> str | None:
# For every (non-composite) view op, we need a corresponding "inverse view" function.
# This generates the declarations so we get a good compiler error when someone adds a new view.
@with_native_function
def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
if g.view.has_composite_implicit_autograd_kernel:
return None
view_inverse_sig = ViewInverseSignature(g)
return view_inverse_sig.decl()
return emit_decl_helper(g)
def gen_functionalization_registration(
selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
composite_implicit_autograd_index: BackendIndex,
) -> list[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'
# Don't generate kernels in mobile build
if not selector.include_all_operators:
return []
if isinstance(g, NativeFunctionsViewGroup):
# functionalization needs to register kernels for view + view_inplace ops
# See Note [Functionalization <> torch.Tensor constructor]
if str(g.view.func.name) == "lift_fresh":
return []
view_str = []
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str
elif isinstance(g, NativeFunctionsGroup):
# Gets a hand-written functionalization kernel
if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
fns = []
else:
fns = list(g.functions())
else:
if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
return []
fns = [g]
registrations = []
for f in fns:
if f.has_composite_implicit_autograd_kernel:
continue
if str(f.func.name) == "lift":
# See Note [Functionalization <> torch.Tensor constructor]
return []
if str(f.func.name) == "resize_":
# See Note [resize_ in Functionalization]
return []
if str(f.func.name.name) != "set_":
assert not f.is_view_op
# functionalization needs to generate and register kernels for inplace ops.
# We *also* need to directly register CompositeImplicitAutograd kernels
# so that they decompose properly before functionalization.
if modifies_arguments(f):
registrations.append(emit_registration_helper(f))
return registrations
def gen_functionalization_definition(
selector: SelectiveBuilder,
# Note: Ideally this code should never have to look at NativeFunction
# (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is because we need to emit direct dispatch registrations
# For CompositeImplicitAutograd operators, which are potentially ungrouped.
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
# Don't generate kernels in mobile build
if not selector.include_all_operators:
return []
if isinstance(g, NativeFunctionsViewGroup):
# Case 1: emit view -> view_copy kernels for the functionalization pass
view_defs = []
if not g.composite:
# invariant: NativeFunctionsViewGroup's always have a view_copy operator
# if the view is not composite (implicit autograd)
assert g.view_copy is not None, dataclass_repr(g, indent=1)
view_defs.append(emit_view_functionalization_body(g, view_inplace=False))
if g.view_inplace is not None:
view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
return view_defs
elif isinstance(g, NativeFunction):
# Invariant: all mutable operators that we need to handle in functionalization
# should have been properly grouped up.
# TODO: The below ops all have "problematic" schemas that prevent them from
# getting functionalized. Instead of bending over backwards to get things to work,
# I think we should either:
# (1) fix their schemas (BC-breaking)
# (2) hand-write their functionalization kernels
if (
str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
):
assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
return []
else:
# Case 2: emit inplace -> out-of-place kernels for the functionalization pass
mutation_defs = []
mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
if g.inplace is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
if g.mutable is not None:
mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
return mutation_defs
return []

View File

@ -0,0 +1,581 @@
from __future__ import annotations
import argparse
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Sequence
import yaml
import torchgen.dest as dest
from torchgen.api.lazy import setValueT
from torchgen.api.types import BaseCppType
from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.gen_backend_stubs import (
error_on_missing_kernels,
gen_dispatcher_registrations,
gen_dispatchkey_nativefunc_headers,
parse_backend_yaml,
)
from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager, NamespaceHelper
from torchgen.yaml_utils import YamlLoader
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Lazy Tensor Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Overview
# ~~~~~~~~
#
# This codegen script builds on existing data models and helpers used
# by all ATen backends, and adds new functionality specific to lazy
# tensor backends.
#
# Inputs:
# - <backend>_native_functions.yaml: controls which operators are
# supported by the backend.
#
# Outputs:
# (for all backends)
# <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
# - opt-in: also generate 'lowering' methods for the TorchScript backend only
# <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
# - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
# <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
# ops
#
# Register<DispatchKey>.cpp registers all op implementations with the dispatcher
# RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
#
# Validation Helpers:
# - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
# implementations in torch/csrc/lazy/core/shape_inference.*
# - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
# (non-codegen) implementation file
#
#
# About the Data Model
# ~~~~~~~~~~~~~~~~~~~~
#
# Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
# we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
# (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
# Backends can list ops in two categories:
# - `supported` ops require hand-implementations but still get codegenned declarations and registrations
# - `full_codegen` ops get implementations (and IR classes) generated too
#
# Each native function is modeled as an object with a schema, and each schema has objects representing their
# arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
# backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
# types (stringref) with actual string objects, and this is done by manipulating the data model objects.
# - see api/lazy.py for the lazy data model
#
# Once the data model is set up, the rest of this script processes a number of templates for output CPP file
# and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`. These
# helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
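# A minimal, illustrative <backend>_native_functions.yaml (the keys are the ones consumed
# here and in gen_backend_stubs; this is a sketch, not copied from any real backend yaml):
#   backend: XLA
#   cpp_namespace: torch_xla
#   full_codegen:
#     - add.Tensor
#   supported:
#     - empty.memory_format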
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
ParsedExternalYaml = namedtuple(
"ParsedExternalYaml",
["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
)
def parse_native_functions_keys(
backend_yaml_path: str,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)
full_codegen = yaml_values.pop("full_codegen", [])
non_native = yaml_values.pop("non_native", [])
ir_gen = yaml_values.pop("ir_gen", [])
assert isinstance(full_codegen, list)
assert isinstance(non_native, list)
assert isinstance(ir_gen, list)
full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
return full_codegen_opnames, non_native, ir_gen_opnames
def validate_shape_inference_header(
shape_inference_hdr: str, expected_shape_infr_decls: list[str]
) -> None:
try:
with open(shape_inference_hdr) as f:
shape_infr_decls = f.read()
shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
except OSError as e:
raise AssertionError(
f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
) from e
# TODO(whc) add a check for shape inference functions that have meta kernels implemented and should be retired.
missing_decls = [
decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
]
if missing_decls:
raise Exception( # noqa: TRY002
f"""Missing shape inference function.\n
Please declare this function in {shape_inference_hdr}:\n
and implement it in the corresponding shape_inference.cpp file.\n
{os.linesep.join(missing_decls)}"""
)
# Some helper functions for the codegen.
def get_ltc_helper_fns() -> str:
return """\
at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
if (!tensor.defined()) return tensor;
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
// needs to handle wrapped numbers, so dtype promotion works properly.
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
out.unsafeGetTensorImpl()->set_wrapped_number(true);
}
return out;
}
std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
if (tensor.has_value()) {
return to_meta(*tensor);
}
return std::nullopt;
}
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
std::vector<at::Tensor> outs;
outs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outs.push_back(to_meta(tensor));
}
return outs;
}
"""
class default_args:
node_base: str = "Node"
node_base_hdr: str | None = None
shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
tensor_class: str = "torch::lazy::LazyTensor"
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
lazy_ir_generator: type[GenLazyIR] = GenLazyIR
native_func_definition_generator: type[
GenLazyNativeFuncDefinition
] = GenLazyNativeFuncDefinition
backend_name: str = "TorchScript"
def main() -> None:
parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
parser.add_argument(
"-s",
"--source-yaml",
"--source_yaml",
help="path to source yaml file containing operator external definitions",
)
parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
parser.add_argument(
"--dry-run", "--dry_run", type=bool, default=False, help="output directory"
)
parser.add_argument(
"--impl-path",
"--impl_path",
type=str,
default=None,
help="path to the source C++ file containing kernel definitions",
)
parser.add_argument(
"--gen-ts-lowerings",
"--gen_ts_lowerings",
action="store_true",
help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
)
parser.add_argument(
"--node-base",
"--node_base",
type=str,
default=default_args.node_base,
help="Name of backend specific custom Lazy IR Node base class",
)
parser.add_argument(
"--node-base-hdr",
"--node_base_hdr",
type=str,
default=default_args.node_base_hdr,
help="Path to header file defining custom Lazy IR Node base class",
)
parser.add_argument(
"--shape-inference-hdr",
"--shape_inference_hdr",
type=str,
default=default_args.shape_inference_hdr,
help="Path to header file defining custom Lazy shape inference functions",
)
parser.add_argument(
"--tensor-class",
"--tensor_class",
type=str,
default=default_args.tensor_class,
help="Name of backend specific custom Lazy Tensor class",
)
parser.add_argument(
"--tensor-class-hdr",
"--tensor_class_hdr",
type=str,
default=default_args.tensor_class_hdr,
help="Path to header file defining custom Lazy Tensor class",
)
parser.add_argument(
"--backend-name",
"--backend_name",
type=str,
default=default_args.backend_name,
help="Name of the backend to generate",
)
options = parser.parse_args()
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
torch_root = Path(__file__).parent.parent.parent.absolute()
aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings:
lazy_ir_generator = GenTSLazyIR
native_func_definition_generator: type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator
run_gen_lazy_tensor(
aten_path,
options.source_yaml,
options.output_dir,
options.dry_run,
options.impl_path,
options.node_base,
options.node_base_hdr,
options.tensor_class,
options.tensor_class_hdr,
options.shape_inference_hdr,
lazy_ir_generator,
native_func_definition_generator,
options.backend_name,
)
def run_gen_lazy_tensor(
aten_path: str,
source_yaml: str,
output_dir: str,
dry_run: bool,
impl_path: str | None,
node_base: str = default_args.node_base,
node_base_hdr: str | None = default_args.node_base_hdr,
tensor_class: str = default_args.tensor_class,
tensor_class_hdr: str = default_args.tensor_class_hdr,
shape_inference_hdr: str = default_args.shape_inference_hdr,
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator,
# build_in_tree is true for TS backend and affects include paths
build_in_tree: bool = False,
# per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
# it must match how ATen was built
per_operator_headers: bool = False,
backend_name: str = default_args.backend_name,
gen_forced_fallback_code: bool = False,
use_lazy_shape: bool = True,
# the following arguments are temporary customization points for xla backend migration.
# do not rely on them otherwise, they should be removed once migration is complete
backend_namespace: str = "torch::lazy",
get_tensorlist: str = "GetTensorList",
get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
try_get_tensor: str = "TryGetLtcTensor",
metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
create_tensor: str = "LazyTensor::Create",
create_from_first_tensor: bool = False,
create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
lazy_value_class: str = "torch::lazy::Value",
lazy_tensor_ptr: str = "LazyTensorPtr",
get_device_fn: str = "torch::lazy::GetBackendDevice",
) -> None:
lv_tokens = lazy_value_class.split("::")
lv_class = lv_tokens[-1]
lv_ns = "::".join(lv_tokens[:-1])
setValueT(BaseCppType(lv_ns, lv_class))
template_dir = os.path.join(aten_path, "templates")
def make_file_manager(install_dir: str) -> FileManager:
return FileManager(
install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
)
fm = make_file_manager(output_dir)
native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions, backend_indices = (
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
)
grouped_native_functions = get_grouped_native_functions(native_functions)
def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
"""
We sort the native function because of the note in concat_map_codegen.
TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
"""
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
return str(func.name.name)
grouped_native_functions = sorted(
grouped_native_functions, key=sort_native_function
)
parsed_backend_yaml = parse_backend_yaml(
source_yaml, grouped_native_functions, backend_indices
)
backend_key = parsed_backend_yaml.backend_key
autograd_key = parsed_backend_yaml.autograd_key
cpp_namespace = parsed_backend_yaml.cpp_namespace
backend_indices = parsed_backend_yaml.backend_indices
# the following 3 keys are all processed differently
# for full_codegen, we generate IR, kernels, etc
# for ir_gen, we generate only IR
# non_native is used to register kernels not declared in
# native_functions.yaml
full_codegen, non_native, ir_gen = parse_native_functions_keys(
source_yaml, grouped_native_functions
)
def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[NativeFunctionsGroup | NativeFunction],
ops_list: list[OperatorName] = full_codegen,
) -> Iterator[str]:
"""
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
only code-gen additional entries for the inplace variant for the native functions.
"""
for x in xs:
fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
for f in fs:
if f.func.name in ops_list:
yield from func(f)
selector = SelectiveBuilder.get_nop_selector()
assert backend_key is not None
class_name = backend_indices[backend_key].native_function_class_name()
if impl_path is not None:
error_on_missing_kernels(
native_functions,
backend_indices,
backend_key,
autograd_key,
class_name,
impl_path,
full_codegen,
)
""" Validate Shape Inference Definitions
Generated lazy native functions all perform shape inference, by first using a meta:: kernel
if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
the expected signature which can be copy-pasted into shape_inference.h.
compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
to structured kernels.
See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
"""
if shape_inference_hdr is not None:
expected_shape_infr_decls = list(
concat_map_codegen(
dest.GenLazyShapeInferenceDefinition(
backend_indices[backend_key], tensor_class
),
grouped_native_functions,
)
)
validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
assert class_name is not None
# Generate nativefunction declarations
# Note, eager registration is set to False for the lazy TS backend as another LTC backend
# may want to register their own lazy kernels instead of registering the TS ones.
# The registration will lazily happen when init_ts_backend is called.
gen_dispatchkey_nativefunc_headers(
fm,
class_name,
cpp_namespace,
backend_indices,
grouped_native_functions,
backend_key,
autograd_key,
backend_name,
)
# Generate Dispatcher registrations which hook up the nativefunctions
for dispatch_key in (
[backend_key] if autograd_key is None else [backend_key, autograd_key]
):
gen_dispatcher_registrations(
fm,
output_dir,
class_name,
backend_indices,
grouped_native_functions,
backend_key,
dispatch_key,
selector,
build_in_tree=build_in_tree,
per_operator_headers=per_operator_headers,
backend_name=backend_name,
eager_registration=False,
)
# Generate native function impls that build IR nodes
ns_helper = NamespaceHelper(cpp_namespace)
fm.write_with_template(
f"{backend_key}NativeFunctions.cpp",
"DispatchKeyNativeFunctions.cpp",
lambda: {
"includes": [
f"#include <{path}>"
for path in [
tensor_class_hdr,
shape_inference_hdr,
"ATen/Functions.h",
"ATen/native/TensorConversions.h",
"ATen/NativeFunctions.h",
"ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
"ATen/MetaFunctions.h",
"ATen/Operators.h",
"ATen/native/CPUFallback.h",
"torch/csrc/lazy/core/ir_builder.h",
"torch/csrc/lazy/core/lazy_graph_executor.h",
"torch/csrc/lazy/core/metrics.h",
"torch/csrc/lazy/core/shape.h",
f"{output_dir}/{backend_key}NativeFunctions.h",
f"{output_dir}/LazyIr.h",
]
+ (
["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
if gen_forced_fallback_code
else []
)
],
"helper_fns": get_ltc_helper_fns(),
"native_functions_include": "",
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
"native_function_definitions": list(
concat_map_codegen(
native_func_definition_generator(
f"{backend_key}NativeFunctions",
backend_indices[backend_key],
tensor_class,
gen_forced_fallback_code,
backend_namespace,
get_tensorlist,
get_tensor_or_wrap_number,
try_get_tensor,
metrics_counter,
create_tensor,
create_from_first_tensor,
create_aten_from_ltc_tensor,
tuple_aten_from_ltc_tensors,
lazy_tensor_ptr,
get_device_fn,
),
grouped_native_functions,
)
),
},
)
# Generate IR node classes
lazy_ir_obj = lazy_ir_generator(
backend_indices[backend_key], backend_name, node_base, use_lazy_shape
)
fm.write_with_template(
"LazyIr.h",
"LazyIr.h",
lambda: {
"lazy_ir_sysinc": [
f"#include <{path}>"
for path in [
"ATen/core/Formatting.h",
"c10/core/ScalarType.h",
"torch/csrc/lazy/core/hash.h",
"torch/csrc/lazy/core/ir.h",
"torch/csrc/lazy/core/shape.h",
"optional",
"vector",
]
],
"lazy_ir_inc": [f'#include "{node_base_hdr}"']
if node_base_hdr is not None
else [],
"ir_declarations": list(
concat_map_codegen(
lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
)
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
},
)
# Generate Non Native IR Node classes
fm.write_with_template(
"LazyNonNativeIr.h",
"LazyNonNativeIr.h",
lambda: {
"lazy_non_native_ir_inc": [
f"#include <{path}>"
for path in [
"torch/csrc/lazy/core/ir.h",
"torch/csrc/lazy/core/ir_builder.h",
"torch/csrc/lazy/core/internal_ops/ltc_ops.h",
"torch/csrc/lazy/core/shape_inference.h",
]
+ ([node_base_hdr] if node_base_hdr else [])
if path
],
"non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
non_native, lazy_ir_obj
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
},
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,97 @@
from typing import Any, Optional, Tuple, Union
from torchgen.model import (
Annotation,
Argument,
Arguments,
BaseOperatorName,
BaseTy,
BaseType,
CustomClassType,
FunctionSchema,
ListType,
OperatorName,
Return,
)
# Note: These aren't actually used in torchgen, they're some utilities for generating a schema
# from real arguments. For example, this is used to generate HigherOrderOperators' schema since
# their schemas can vary for different instances of the same HOP.
class TypeGen:
convert_to_base_ty = {
int: BaseTy.int,
float: BaseTy.float,
str: BaseTy.str,
bool: BaseTy.bool,
}
@staticmethod
def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
import torch
if isinstance(obj, torch.fx.GraphModule):
return BaseType(BaseTy.GraphModule)
elif isinstance(obj, torch.Tensor):
return BaseType(BaseTy.Tensor)
elif isinstance(obj, torch.SymInt):
return BaseType(BaseTy.SymInt)
elif isinstance(obj, torch.SymBool):
return BaseType(BaseTy.SymBool)
elif isinstance(obj, torch.ScriptObject):
return CustomClassType(obj._type().name()) # type: ignore[attr-defined]
elif isinstance(obj, (list, tuple)):
assert len(obj) > 0
all_base_tys = [TypeGen.from_example(x) for x in obj]
if len(set(all_base_tys)) > 1:
raise RuntimeError(
f"Cannot generate schema for a seqeunce of args of heterogeneous types: {all_base_tys}. "
"Consider unpacking the argument and give proper names to them if possible "
"instead of using *args."
)
return ListType(all_base_tys[0], len(obj))
tp = type(obj)
if tp not in TypeGen.convert_to_base_ty:
raise RuntimeError(f"unsupported type {tp}")
return BaseType(TypeGen.convert_to_base_ty[tp])
class ReturnGen:
@staticmethod
def from_example(
name: Optional[str], obj: Any, annotation: Optional[Annotation]
) -> Return:
return Return(name, TypeGen.from_example(obj), annotation)
class ArgumentGen:
@staticmethod
def from_example(
name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
) -> Argument:
return Argument(
name, TypeGen.from_example(obj), default=default, annotation=annotation
)
class FunctionSchemaGen:
@staticmethod
def from_example(
op_name: str,
example_inputs: Tuple[Tuple[str, Any], ...],
example_outputs: Tuple[Any, ...],
) -> FunctionSchema:
args = []
for name, inp in example_inputs:
args.append(ArgumentGen.from_example(name, inp, None, None))
# ignore the annotations and other attributes for now, we could add more when needed.
arguments = Arguments(
tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
)
returns = tuple(
ReturnGen.from_example(None, out, None) for out in example_outputs
)
op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
return FunctionSchema(op_name, arguments, returns)
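# Hedged usage sketch (op name and example values are made up): calling
#   FunctionSchemaGen.from_example("my_hop", (("x", torch.randn(3)), ("n", 1)), (torch.randn(3),))
# produces a FunctionSchema that prints roughly as "my_hop(Tensor x, int n) -> Tensor".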

View File

@ -0,0 +1,271 @@
from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
from torchgen.context import method_with_native_function
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
OptionalType,
Return,
SchemaKind,
Type,
)
from torchgen.utils import mapMaybe
def is_tensor(typ: Type) -> bool:
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
def is_optional_tensor(typ: Type) -> bool:
return isinstance(typ, OptionalType) and is_tensor(typ.elem)
def is_tensor_list(typ: Type) -> bool:
return isinstance(typ, ListType) and is_tensor(typ.elem)
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\
auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
return textwrap.dedent(result).split("\n")
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\
std::optional<Tensor> {name}_value;
std::optional<int64_t> {name}_bdim;
if ({name}) {{
std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
}}"""
return textwrap.dedent(result).split("\n")
def gen_unwraps(
flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
arg_names = [a.name for a in flat_arguments]
arg_types = [a.type for a in flat_arguments]
tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
optional_tensors = [
name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
]
unwraps = []
for tensor in tensors:
unwraps += unwrap_tensor(tensor, cur_level_var)
for opt_tensor in optional_tensors:
unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
unwrap_code = "\n".join(unwraps)
unwrapped_arg_list = []
for arg in arg_names:
if arg in tensors or arg in optional_tensors:
unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
else:
unwrapped_arg_list.append(arg)
return unwrap_code, unwrapped_arg_list
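# Illustrative sketch: for a hypothetical flat argument list
# "Tensor self, Tensor? weight, int dim", gen_unwraps(args, "cur_level")
# produces roughly
#
#   auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
#   std::optional<Tensor> weight_value;
#   std::optional<int64_t> weight_bdim;
#   if (weight) {
#     std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level);
#   }
#
# together with the unwrapped argument list
#   ["self_value", "self_bdim", "weight_value", "weight_bdim", "dim"].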
def gen_case_where_all_bdims_are_none(
outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
conditions = []
flat_args = schema.arguments.flat_all
for arg in flat_args:
if not arg.type.is_tensor_like():
continue
conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")
sig = DispatcherSignature.from_schema(schema)
translated_args = ", ".join(
e.expr for e in translate(outer_sig.arguments(), sig.arguments())
)
return f"""\
if ({' && '.join(conditions)}) {{
return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""
def gen_returns(
returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
idx = 0
wrapped_returns = []
for ret in returns:
if is_tensor(ret.type):
wrapped_returns.append(
f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
)
idx += 2
elif is_tensor_list(ret.type):
wrapped_returns.append(
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})"
)
idx += 2
else:
wrapped_returns.append(f"std::get<{idx}>({results_var})")
idx += 1
if len(wrapped_returns) == 1:
result = f"return {wrapped_returns[0]};"
else:
result = f'return std::make_tuple({", ".join(wrapped_returns)});'
return result
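# Illustrative sketch: for a hypothetical return list "(Tensor, Tensor[])",
# with results_var "results" and cur_level_var "cur_level", gen_returns
# produces roughly
#
#   return std::make_tuple(
#       makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
#       makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level));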
def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
def is_mutated_arg(argument: Argument) -> bool:
return argument.annotation is not None and argument.annotation.is_write
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
# Assumptions:
# - only one argument is being modified in-place
# - the argument that is being modified in-place is the first argument
# - all returns are either Tensor, tuple of Tensor, or TensorList
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
returns = schema.returns
# Check assumptions. If these are invalid we return None
# and punt the work of handling them to the future.
assert schema.kind() == SchemaKind.inplace
if not is_mutated_arg(schema.arguments.flat_all[0]):
return None
if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
return None
# Only support cases where all returns are Tensors or vector<Tensor>
if len(returns) == 0:
return None
if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
return None
if not accepts_at_least_one_tensor_input(schema):
return None
cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, " ")}
{textwrap.indent(unwraps, " ")}
batch_rule({', '.join(unwrapped_arg_list)});
return {schema.arguments.flat_all[0].name};
}}"""
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, " ")}
{textwrap.indent(unwraps, " ")}
batch_rule({', '.join(unwrapped_arg_list)});
}}"""
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
returns = schema.returns
# Only support cases where all returns are Tensors or vector<Tensor>
if not accepts_at_least_one_tensor_input(schema):
return None
if len(returns) == 0:
return gen_vmap_plumbing_no_returns(native_function)
return_symint_overrides = [
"_scaled_dot_product_flash_attention",
"_scaled_dot_product_cudnn_attention",
]
if (
not all(ret.type.is_tensor_like() for ret in returns)
and schema.name.unambiguous_name() not in return_symint_overrides
):
return None
# in-place views need special handling
if "inplace_view" in native_function.tags:
return None
if schema.kind() == SchemaKind.inplace:
return gen_vmap_inplace_plumbing(native_function)
# Don't support these (mutable, out, scratch)
if schema.kind() != SchemaKind.functional:
return None
results_var = "results"
cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
wrapped_returns = gen_returns(returns, cur_level_var, results_var)
return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, " ")}
{textwrap.indent(unwraps, " ")}
auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
{wrapped_returns}
}}"""
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
result = gen_vmap_plumbing(f)
return result
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
return f"""
#pragma once
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
namespace at {{ namespace functorch {{
{body}
}}}} // namespace at::functorch
"""

View File

@ -0,0 +1,59 @@
from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes
# from Racket.
#
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
# why you need to add a toggle to the global behavior of code
# generation. The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites, and are slated to be eliminated once all call
# sites are eliminated. If you don't have a plan for how to get there,
# DON'T add a new entry here.
class Locals(threading.local):
use_const_ref_for_mutable_tensors: bool | None = None
use_ilistref_for_tensor_lists: bool | None = None
_locals = Locals()
def use_const_ref_for_mutable_tensors() -> bool:
assert _locals.use_const_ref_for_mutable_tensors is not None, (
"need to initialize local.use_const_ref_for_mutable_tensors with "
"local.parametrize"
)
return _locals.use_const_ref_for_mutable_tensors
def use_ilistref_for_tensor_lists() -> bool:
assert _locals.use_ilistref_for_tensor_lists is not None, (
"need to initialize local.use_ilistref_for_tensor_lists with "
"local.parametrize"
)
return _locals.use_ilistref_for_tensor_lists
@contextmanager
def parametrize(
*, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
) -> Iterator[None]:
old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
try:
_locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
_locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
yield
finally:
_locals.use_const_ref_for_mutable_tensors = (
old_use_const_ref_for_mutable_tensors
)
_locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists
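# A minimal usage sketch of the parametrize() context manager; the chosen
# values are arbitrary.
if __name__ == "__main__":
    with parametrize(
        use_const_ref_for_mutable_tensors=False,
        use_ilistref_for_tensor_lists=True,
    ):
        assert use_const_ref_for_mutable_tensors() is False
        assert use_ilistref_for_tensor_lists() is True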

File diff suppressed because it is too large

View File

@ -0,0 +1,646 @@
from __future__ import annotations
from collections import defaultdict
from typing import Sequence
import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import Binding, DispatcherSignature, Expr
from torchgen.context import with_native_function
from torchgen.model import (
Annotation,
Argument,
BackendIndex,
BackendMetadata,
BaseOperatorName,
BaseTy,
BaseType,
DEFAULT_KERNEL_NAMESPACE,
DeviceCheckType,
DispatchKey,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
Return,
SchemaKind,
Variant,
)
from torchgen.utils import concatMap
# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# This has a functional variant, but it's currently marked private.
# This function should be marked private as well (*_backward ops aren't exposed to python anyway).
"adaptive_avg_pool3d_backward.grad_input",
# There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
# Maybe we can kill this operator in favor of convolution_backward?
"_slow_conv2d_backward.grad_input",
]
# See Note: [Mutable ops that cannot get an out variant]
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
# should be out=?
"_cummax_helper",
# should be out=?
"_cummin_helper",
]
# None of these operators have any tensor-like returns
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_assert_async", # no return
"_assert_async.msg", # no return
"_cslt_sparse_mm_search", # returns an int
"_assert_scalar", # no return
"_dimI", # returns an int
"_dimV", # returns an int
"_has_same_storage_numel", # returns a boolean
"_linalg_check_errors", # no return
"_local_scalar_dense", # returns a Scalar
"_nested_tensor_from_mask_left_aligned", # returns a boolean
"_nnz", # returns an int
"_use_cudnn_ctc_loss", # returns a boolean
"_use_cudnn_ctc_loss.Tensor", # returns a boolean
"_validate_compressed_sparse_indices", # no return
"allclose", # returns a boolean
"dense_dim", # returns an int
"equal", # returns a boolean
"is_coalesced", # returns an boolean
"is_pinned", # returns a boolean
"is_same_size", # returns a boolean
"is_set_to", # returns a boolean
"q_per_channel_axis", # returns an int
"q_scale", # returns a float
"q_zero_point", # returns an int
"qscheme", # returns a QScheme
"record_stream", # no return
"sparse_dim", # returns an int
"sym_constrain_range", # no return
"sym_constrain_range_for_size", # no return
"_nested_tensor_storage_offsets", # returns a vector of ints
"_chunk_grad_outputs_efficient_attention", # returns a bool
"_fused_sdp_choice", # returns an int
"_print", # no return
"_sink_tokens", # no return
"_nested_get_ragged_idx", # returns an int
]
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# polygamma and polygamma.out both exist, but have a
# pre-self arg (while polygamma_ does not)
# We should either fix this schema so it can be grouped properly,
# or allow the codegen to generate new functional/out= NativeFunctions for this op
# (which would require changing its overload name to prevent overload ambiguity).
"polygamma_"
]
# Groups "similar" NativeFunctions together
# e.g. add.Tensor, add_.Tensor, add.out
# "similar" NativeFunctions are all expected to have an identical `signature()`,
# but have differing SchemaKinds.
def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: dict[
FunctionSchema, dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
assert f.func.kind() not in d
d[f.func.kind()] = f
return pre_grouped_native_functions
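# Illustrative sketch: after grouping, the entry keyed by the shared
# signature() of e.g. the "add" overload family maps roughly to
#   {SchemaKind.functional: add.Tensor,
#    SchemaKind.inplace:    add_.Tensor,
#    SchemaKind.out:        add.out}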
# Returns the out variant overload name given a base function overload name
def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
return "out" if not overload_name else f"{overload_name}_out"
# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
# Example before:
# _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
# Example after:
# _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Generating an out= schema from an inplace schema.
assert func.kind() == SchemaKind.inplace
assert func.arguments.self_arg is not None
# The new out= schema has:
# - a new out argument with the same type as "func" (but with a mutable annotation)
# - The returns (if any) now alias the out= argument instead of "func"
# - an "out" overload name
return FunctionSchema(
name=func.name.remove_inplace().with_overload(
get_expected_out_variant_overload_name(func.name.overload_name)
),
arguments=func.arguments.remove_self_annotation().with_out_args(
[
Argument(
name="out",
type=func.arguments.self_arg.argument.type,
default=None,
annotation=func.arguments.self_arg.argument.annotation,
)
]
),
returns=func.returns,
)
# Helper function: given a functional FunctionSchema, generate its corresponding out= variant
# Example before:
# _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
# Example after:
# _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None,
# Tensor(a!) out) -> Tensor(a!)
def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Generating an out= schema from a functional schema.
assert func.kind() == SchemaKind.functional
new_returns, new_out_args = generate_out_args_from_schema(func)
# The new out= schema has:
# - one or more new out argument(s) with the same type as returns (but with a mutable annotation)
# - The returns now alias the out= arguments
# - an "_out" overload name
return FunctionSchema(
name=func.name.with_overload(
get_expected_out_variant_overload_name(func.name.overload_name)
),
arguments=func.arguments.signature().with_out_args(
new_out_args,
),
returns=tuple(new_returns),
)
# Helper function: given a function schema, generate the corresponding out arguments, as well as the updated return annotations.
def generate_out_args_from_schema(
func: FunctionSchema,
) -> tuple[list[Return], list[Argument]]:
# More of a sanity check - our existing restrictions on schemas should enforce that
# mutable schema kinds never return their mutable arguments.
assert not any(
r.annotation is not None and r.annotation.is_write for r in func.returns
)
tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
assert len(tensorlike_rets) > 0
used_annotations = concatMap(
lambda a: [] if a.annotation is None else a.annotation.alias_set,
func.arguments.flat_all,
)
valid_annotations = [
x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
]
all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
new_out_args: list[Argument] = []
# The end result of new_returns is that:
# - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
# - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
new_returns: list[Return] = []
for i, r in enumerate(func.returns):
if r.type.is_tensor_like():
new_out = Argument(
name="out" if len(func.returns) == 1 else f"out{i}",
type=r.type,
default=None,
annotation=Annotation.parse(f"{valid_annotations[i]}!"),
)
new_out_args.append(new_out)
if all_rets_are_tensors:
# The convention for out= schemas is that they only return their out arguments
# if the return is a plain Tensor (or if it's a tuple of plain Tensors)
new_ret = Return(
name=None, type=new_out.type, annotation=new_out.annotation
)
new_returns.append(new_ret)
else:
new_returns.append(r)
return new_returns, new_out_args
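# Illustrative sketch: for a hypothetical functional schema
#   my_minmax(Tensor self, int dim) -> (Tensor values, Tensor indices)
# this helper yields out arguments roughly like
#   Tensor(a!) out0, Tensor(b!) out1
# and, since every return is a plain Tensor, new returns that alias them:
#   (Tensor(a!), Tensor(b!))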
# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
# Example before:
# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
# Example after:
# _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Generating an out= schema from a mutable schema.
assert func.kind() == SchemaKind.mutable
# The new out= schema has:
# - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
# (if the argument is a tensor then we also return it for method chaining,
# otherwise we return nothing)
# - an "out" overload name
#
# Note that:
# (1) This also means that we can *only* generate an out= variant from a mutable schema
# if the mutable schema has at least one tensor-like non-aliasing return.
# (2) The generated out= variant still has mutable positional arguments,
# but if necessary we could probably add another out= variant that also
# functionalizes the mutable arguments (a functional_out variant)
new_returns, new_out_args = generate_out_args_from_schema(func)
return FunctionSchema(
name=func.name.remove_inplace().with_overload(
get_expected_out_variant_overload_name(func.name.overload_name)
),
arguments=func.arguments.with_out_args(new_out_args),
returns=tuple(new_returns),
)
# This function, given a function of one SchemaKind as well as a target SchemaKind,
# generates a new NativeFunction with the same properties, but using the target SchemaKind.
# We only actually generate functions for either functional or out= SchemaKinds.
# This function returns a tuple, with:
# - The generated NativeFunction
# - a dictionary of `BackendIndex` objects, describing which dispatch keys
# we will generate kernels for, for the new NativeFunction.
# Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
f: NativeFunction, k: SchemaKind
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
from torchgen.api import cpp
if k == SchemaKind.functional:
assert f.func.kind() != SchemaKind.functional
# The new "functional" NativeFunction has:
# - any mutable arguments have been converted into (immutable) returns.
# (if a mutable argument was not also a return, it gets converted to one)
# - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
# See Note [Overload Ambiguity With Functional Variants]
# The default grouping logic in signature() actually already does this,
# so we can piggy-back off it (but we still want return names)
func = f.func.signature(keep_return_names=True).with_name(
OperatorName(
name=BaseOperatorName(
base=f.func.name.name.base,
inplace=False,
dunder_method=f.func.name.name.dunder_method,
# See Note [Overload Ambiguity With Functional Variants]
functional_overload=f.func.kind() == SchemaKind.mutable,
),
overload_name=f.func.name.overload_name,
)
)
elif k == SchemaKind.out:
# We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
# but at least today, there is no good reason to actually use them.
# We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
if f.func.kind() == SchemaKind.inplace:
func = self_to_out_signature(f.func)
elif f.func.kind() == SchemaKind.mutable:
func = mutable_to_out_signature(f.func)
elif f.func.kind() == SchemaKind.functional:
func = functional_to_out_signature(f.func)
else:
raise AssertionError(
"We only bother generating out= functions from either inplace or mutable or functional variants"
)
else:
raise AssertionError(
"We currently only generate either functional or out= NativeFunctions"
)
# Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
# disambiguate operators with the same name but different overload names, e.g., `randn.names_out` and
# `randn.generator_with_names_out`.
kernel_name = (
func.name.unambiguous_name()
if func.kind() == SchemaKind.out
else cpp.name(func)
)
if f.func.has_symint():
kernel_name += "_symint"
backend_metadata = {
DispatchKey.CompositeExplicitAutograd: {
func.name: BackendMetadata(
kernel=kernel_name,
structured=False,
cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
)
}
}
tags = {"generated"} | set(
f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
)
return (
NativeFunction(
func=func,
use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
# These generated fns aren't meant to be user-friendly; don't generate methods.
variants={Variant.function},
structured=False,
structured_delegate=None,
structured_inherits=None,
precomputed=None,
autogen=[],
ufunc_inner_loop={},
manual_kernel_registration=False,
manual_cpp_binding=False,
python_module=None,
category_override=None,
device_guard=False,
device_check=DeviceCheckType.NoCheck,
loc=f.loc,
cpp_no_default_args=set(),
is_abstract=f.is_abstract,
has_composite_implicit_autograd_kernel=False,
has_composite_implicit_autograd_nested_tensor_kernel=False,
has_composite_explicit_autograd_kernel=True,
has_composite_explicit_autograd_non_functional_kernel=False,
# Every generated NativeFunction gets a "generated" tag, so it's easy to tell
# which NativeFunction objects did not come directly from native_functions.yaml.
tags=tags,
namespace=f.namespace,
),
backend_metadata,
)
# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
rs: list[NativeFunction],
indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
# The main code for generating new NativeFunctions
# First we group NativeFunctions by schema kind,
# then we detect which ones are missing and generate them.
pre_grouped_native_functions = pre_group_native_functions(rs)
for d in pre_grouped_native_functions.values():
has_functional = SchemaKind.functional in d
has_inplace = SchemaKind.inplace in d
has_mutable = SchemaKind.mutable in d
has_out = SchemaKind.out in d
# We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
# (1) If an operator has an inplace/out= variant but no functional variant, we can generate
# a simple functional variant that the functionalization pass can consume.
# (2) If an operator has an inplace or functional but no out= variant, we generate an out=
# variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
# while maintaining the constraint that the out= variant is "required".
if has_mutable or has_inplace or has_out or has_functional:
# Don't bother generating function trios for native functions that bypass the dispatcher.
are_manual = all(f.manual_cpp_binding for f in d.values())
# Don't bother generating functional + out= variants for view operators
# set_ is technically an inplace_view, but for now it is treated
# as a normal inplace op in the codegen
has_view_ops = any(
f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
)
# Don't generate the other variants for CompositeImplicitAutograd operators.
# We could probably do this, but the main benefit of generating the function triplets
# is for transforms that need them, and transforms don't need to act directly
# on CompositeImplicitAutograd operators (since we let them decompose).
are_composite_implicit = all(
f.has_composite_implicit_autograd_kernel for f in d.values()
)
if are_manual or has_view_ops or are_composite_implicit:
continue
if has_out and len(d.values()) == 1:
# Note: [Out ops with functional variants that don't get grouped properly]
# In theory we could validly have an out= operator in native_functions.yaml
# that has no other variants.
# But today, all of the operators where that's the case actually do have
# functional variants, that we are just unable to pair up properly.
# I think banning this altogether is probably safer
# (you can always add a functional variant yourself if you want to add a new out= operator).
#
# We should probably fix the existing cases; this check is to prevent us from adding more over time.
if (
str(d[SchemaKind.out].func.name)
not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
):
raise AssertionError(
f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
)
continue
# Some inplace ops that have problematic schemas (that we should fix), which prevent us
# from generating out= and functional variants
if (
has_inplace
and str(d[SchemaKind.inplace].func.name)
in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
):
continue
base_fn = (
d[SchemaKind.inplace]
if has_inplace
else d[SchemaKind.mutable]
if has_mutable
else d[SchemaKind.out]
if has_out
else d[SchemaKind.functional]
)
# Note: [Mutable ops that cannot get an out variant]
# We can only generate an out= variant if either:
# - the original function has tensor-like returns (since we can convert them to out kwargs)
# - or it's inplace (since we can convert `self` to an out kwarg)
# There are only two functions that don't fit these criteria today though,
# and they both look like they should be fixed to be out= variants,
# so it feels safer to ban this schema altogether
base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
r.type.is_tensor_like() for r in base_fn.func.returns
)
# Note: [Loosen the assertion that all functional should have out variant]
# By design, all functional operators should have out= variants. The needs_out check
# loosens this requirement, only generating an out= variant if there's an `autogen`
# block in the native function; in the long run it should be removed.
# FIXME: Remove this after figuring out CI job failures related to min, max, mean
needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
gets_out_variant = not has_out and base_fn_valid and needs_out
if not has_out and not base_fn_valid:
if (
str(base_fn.func.name)
not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
and str(base_fn.func.name)
not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
):
raise AssertionError(
f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
This type of operator doesn't have a tensor-like return, making it difficult to generate a proper out= variant. If an
out= variant is not needed, please add the function name to the FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
)
# Generate an out= variant
if gets_out_variant:
fn, metadata = generate_function(base_fn, SchemaKind.out)
d[SchemaKind.out] = fn
BackendIndex.grow_index(indices, metadata)
rs.append(fn)
# Generate a functional variant, but only do it if the operator got an out= variant
# (Functional variants are only useful if we can group up the variants,
# which we can only do if they have an out= variant)
if not has_functional and (has_out or gets_out_variant):
fn, metadata = generate_function(base_fn, SchemaKind.functional)
d[SchemaKind.functional] = fn
BackendIndex.grow_index(indices, metadata)
rs.append(fn)
def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
elif len(rets) == 1:
return f"return {names[0]};"
else:
return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
aliased_rets = func.aliased_return_names()
non_aliased_names = []
is_out_var_a_tuple = len(func.returns) > 1
for i, r in enumerate(aliased_rets):
if r is None:
non_aliased_names.append(
f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
)
return non_aliased_names
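# Illustrative sketch: for an out= schema whose returns all alias out
# arguments, this returns []; for a hypothetical functional schema returning
# (Tensor, Tensor) with out_var "tmp", it returns roughly
#   ["std::get<0>(tmp)", "std::get<1>(tmp)"].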
# Generates functional kernels in terms of their inplace.mutable counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.functional.tags:
return None
# And we always write the kernel for a generated op in terms of a non-generated op.
if g.inplace is not None and "generated" not in g.inplace.tags:
target_f = g.inplace
elif g.mutable is not None and "generated" not in g.mutable.tags:
target_f = g.mutable
else:
# We should be guaranteed to have a valid inplace/mutable variant to call into.
# See Note: [Mutable Ops Not Using Functionalization]
raise AssertionError(str(g.functional.func))
sig = DispatcherSignature(g.functional.func)
target_sig = DispatcherSignature(target_f.func)
context: list[Binding | Expr] = []
clone_mutable_inputs = []
cloned_return_names = []
# We can't just directly pass all of the arguments from the functional op into the mutating op.
# We need to check which inputs to the mutating operator are mutable,
# and clone those inputs first.
for a_curr, a_tgt in zip(
dispatcher.jit_arguments(g.functional.func),
dispatcher.jit_arguments(target_f.func),
):
if a_tgt.annotation is not None and a_tgt.annotation.is_write:
clone_mutable_inputs.append(
f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
)
context.append(
Expr(
expr=f"{a_curr.name}_clone",
type=dispatcher.argument_type(a_curr, binds=a_curr.name),
)
)
# Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
cloned_return_names.append(f"{a_curr.name}_clone")
else:
context.append(dispatcher.argument(a_curr))
exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
out_name = "output"
maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
ret_str = return_str(
g.functional.func.returns, inner_return_names + cloned_return_names
)
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
return f"""
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
{clone_mutable_inputs_str}
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
{ret_str}
}}
"""
# Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.out.tags:
return None
# And we always write the kernel for the out= op in terms of the functional.
# Note that the functional op might have also been generated, but we don't have to
# worry about cycles, because the generated functional kernels are always implemented
# in terms of non-generated kernels (see gen_composite_functional_kernel).
sig = DispatcherSignature(g.out.func)
target_sig = DispatcherSignature(g.functional.func)
exprs = ", ".join(
[e.expr for e in translate(sig.arguments(), target_sig.arguments())]
)
copy_outs = []
out_name = "tmp_output"
for i, out_arg in enumerate(g.out.func.arguments.out):
functional_return_name = (
out_name
if len(g.functional.func.returns) == 1
else f"std::get<{i}>({out_name})"
)
copy_outs.append(
f"""\
resize_out_helper({out_arg.name}, {functional_return_name});
copy_arg({out_arg.name}, {functional_return_name});"""
)
rets = []
# For each return arg in the calling (out=) operator,
# If it corresponds to an aliased input, return the input.
# Otherwise, return the corresponding output from calling the functional operator.
for i, ret_name in enumerate(g.out.func.aliased_return_names()):
if ret_name is not None:
rets.append(ret_name)
else:
functional_return_name = (
out_name
if len(g.functional.func.returns) == 1
else f"std::get<{i}>({out_name})"
)
rets.append(functional_return_name)
copy_outs_str = "\n".join(copy_outs)
# Kernel name needs to follow the naming convention defined in `generate_function()`
return f"""
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
{copy_outs_str}
{return_str(g.out.func.returns, rets)}
}}
"""

Some files were not shown because too many files have changed in this diff