I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,19 @@
from torchgen.dest.lazy_ir import (
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
GenLazyIR as GenLazyIR,
GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
)
from torchgen.dest.native_functions import (
compute_native_function_declaration as compute_native_function_declaration,
)
from torchgen.dest.register_dispatch_key import (
gen_registration_headers as gen_registration_headers,
gen_registration_helpers as gen_registration_helpers,
RegisterDispatchKey as RegisterDispatchKey,
)
from torchgen.dest.ufunc import (
compute_ufunc_cpu as compute_ufunc_cpu,
compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
compute_ufunc_cuda as compute_ufunc_cuda,
)


@@ -0,0 +1,707 @@
from __future__ import annotations
import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
getValueT,
isValueType,
LazyArgument,
LazyIrProperties,
LazyIrSchema,
tensorListValueT,
)
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
deviceT,
DispatcherSignature,
kernel_signature,
NativeSignature,
OptionalCType,
VectorCType,
)
from torchgen.context import method_with_native_function
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
Argument,
BackendIndex,
BackendMetadata,
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
)

def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
"""
Given a LazyArgument,
generate a c++ string for materializing an rvalue of that arg for passing into
a lazy Node constructor.
"""
# TODO: Matching on CType seems wrong; should be matching on Type
if isValueType(arg.lazy_type):
if isinstance(arg.lazy_type, BaseCType):
if arg.is_wrapped_scalar:
return f"node_{arg.name}"
elif arg.lazy_type.type is tensorListValueT:
return f"lazy_{arg.name}_tensorlist"
elif arg.is_symint_or_list:
return f"GetSymIntValue({arg.name})"
return f"lazy_{arg.name}->GetIrValue()"
elif isinstance(arg.lazy_type, OptionalCType):
if arg.is_symint_or_list:
# TODO: I don't understand when you should put lazy_ in the name
# or not
return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
elif arg.is_wrapped_scalar:
return f"node_{arg.name}"
return (
f"lazy_{arg.name} ? "
f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
"::std::nullopt"
)
else:
raise AssertionError(
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
else:
# NB: this is here because right now we aren't treating SymInt[] as a
# value type; when we do this needs to move above
# NB: we cannot test arg.lazy_type as we've already specified it is an
# int64_t and so we cannot distinguish between SymInt and int64_t
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
BaseTy.SymInt
):
if arg.symint:
return f"GetSymIntArrayRefValue({arg.name})"
else:
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
arg.lazy_type.elem, BaseCType
):
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
elif (
isinstance(arg.lazy_type, OptionalCType)
and isinstance(arg.lazy_type.elem, VectorCType)
and isinstance(arg.lazy_type.elem.elem, BaseCType)
):
return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
else:
return f"{arg.name}"

def node_ctor_inputs(schema: LazyIrSchema) -> str:
"""
Produce a formatted string with the arguments as passed into the constructor of a node class.
"""
node_ctor_values = [
node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
]
return ", ".join(node_ctor_values)
def gen_fallback_code(
schema: LazyIrSchema,
sig: DispatcherSignature | NativeSignature,
overload_name: str,
) -> str:
"""
Generate code that falls back to eager conditioned on a predicate
"""
dispatcher_sig = DispatcherSignature.from_schema(schema.func)
exprs = translate(sig.arguments(), dispatcher_sig.arguments())
fallback_args = ",\n ".join([a.expr for a in exprs])
if len(overload_name):
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
else:
aten_op_str = f"ATEN_OP({schema.aten_name})"
return f"""
if (force_eager_fallback({aten_symbol(schema)})) {{
return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
{fallback_args}
);
}}
"""

def aten_symbol(schema: LazyIrSchema) -> str:
missing_interned_strings = {
"sigmoid_backward",
}
if schema.aten_name in missing_interned_strings:
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
if not schema.aten_name.startswith("at::"):
return f"at::aten::{schema.aten_name}"
else:
return schema.aten_name
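
# For example: "add" maps to "at::aten::add", while "sigmoid_backward" (which
# has no interned string) maps to
# c10::Symbol::fromQualString("aten::sigmoid_backward").
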
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta"
unwrapped_tensor_args.append(
f"auto {unwrapped_name} = to_meta({arg.name});"
)
context.append(arg.with_name(unwrapped_name))
else:
context.append(arg)
unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
return unwrap_tensor_args_str, context
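
# Sketch: for a dispatcher signature "(Tensor self, Scalar alpha)" the emitted
# conversion logic would be
#
#   auto self_meta = to_meta(self);
#
# and the returned context binds self_meta where self was, passing alpha
# through untouched.
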
@dataclass(frozen=True)
class GenLazyIR(ABC):
backend_index: BackendIndex
backend_name: str
node_base: str
use_lazy_shape: bool
@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
metadata = self.backend_index.get_kernel(
f.functional if isinstance(f, NativeFunctionsGroup) else f
)
schema = LazyIrSchema(
func, symint=metadata is not None and metadata.supports_symint()
)
return self.gen(schema)
# there is no lowering functionality generated unless this IR base class is subclassed and
# implemented as a backend-specific node
def lowering_function(self, schema: LazyIrSchema) -> str:
return ""
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return ""
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
return f"""bool CanBeReused({node_ctor_args}) const {{
return false;
}}"""
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# backends can customize the way the node base class constructor is called,
# as long as all of its arguments can be generated from information available from the schema
base_ctor_value_args_list = []
for arg in value_args:
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
base_ctor_value_args_list.append(f"{arg.name}")
elif isinstance(arg.lazy_type, OptionalCType):
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
else:
raise AssertionError(
f"Unsupported type ({arg.lazy_type}) - add support if necessary"
)
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
scalar_args = schema.filtered_args(values=False, scalars=True)
# Shape construction.
# Conditionally build shape depending on specified shape property
if schema.properties.ShapePrecompute:
shape_ctor_arg = "std::move(shapes),"
elif schema.properties.ShapeCompute:
shape_args = [a.name for a in value_args]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
elif schema.properties.ShapeCache:
shape_args = [f"operand({i})" for i in range(len(value_args))]
shape_args.extend(a.name for a in scalar_args)
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
else:
shape_ctor_arg = ""
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
return f"""{self.node_base}(
{schema.node_name}::ClassOpKind(),
OpList{{{base_ctor_value_args}}},
{shape_ctor_arg}
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""
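
# As a rough sketch, for a hypothetical node class "Add" with node_base
# "torch::lazy::TsNode" and precomputed shapes, the call rendered above would
# read:
#
#   torch::lazy::TsNode(
#       Add::ClassOpKind(),
#       OpList{self, other},
#       std::move(shapes),
#       /* num_outputs */ 1,
#       torch::lazy::MHash(alpha))
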
def gen(self, schema: LazyIrSchema) -> list[str]:
opkind = schema.opkind or aten_symbol(schema)
# For now we just want one IR class decl (and soon after, also the method
# defs), and we use the functional version, not out/inplace.
all_args = schema.filtered_args()
scalar_args = schema.filtered_args(values=False, scalars=True)
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
reuse_ctor_args = ", ".join(ctor_args)
if self.use_lazy_shape and schema.properties.ShapePrecompute:
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
node_ctor_args = ", ".join(ctor_args)
scalar_initializers = ",\n ".join(
[
# This code is just special casing the mapping from string_view -> strings
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.name}({a.name})"
for a in scalar_args
]
)
if len(scalar_initializers):
scalar_initializers = f",\n {scalar_initializers}"
scalar_decls = "\n ".join(
[
f"std::string {a.name};"
if a.lazy_type.cpp_type() == "c10::string_view"
else f"::std::optional<std::string> {a.name};"
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
else f"{a.lazy_type.cpp_type()} {a.name};"
for a in scalar_args
]
)
optional_values = [
arg.name
for arg in schema.filtered_args(values=True, scalars=False)
if isinstance(arg.lazy_type, OptionalCType)
]
has_optional_decls = "\n ".join(
[f"bool has_{value}: 1;" for value in optional_values]
)
has_optional_defs = "\n ".join(
[f"has_{value} = !!{value};" for value in optional_values]
)
members_to_string = []
for arg in scalar_args:
if isinstance(arg.lazy_type, OptionalCType):
value = f"{arg.name}.value()"
if arg.is_generator:
value = '"torch.Generator()"'
members_to_string.append(
f"""if ({arg.name}.has_value()) {{
ss << ", {arg.name}=" << {value};
}} else {{
ss << ", {arg.name}=null";
}}"""
)
else:
members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
members_to_string_str = "\n ".join(members_to_string)
return [
f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static torch::lazy::OpKind ClassOpKind() {{
return torch::lazy::OpKind({opkind});
}}
{schema.node_name}({node_ctor_args})
: {self.node_base_ctor_call(schema)}{scalar_initializers}
{{
{has_optional_defs}
}}
std::string ToString() const override {{
std::stringstream ss;
ss << {self.node_base}::ToString();
{members_to_string_str}
return ss.str();
}}
{self.create_function(schema, reuse_ctor_args)}
{self.can_be_reused_function(schema, reuse_ctor_args)}
{self.lowering_function(schema)}
{scalar_decls}
{has_optional_decls}
}};
""",
]

@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
def lowering_function(self, schema: LazyIrSchema) -> str:
signature = """
torch::lazy::TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const override"""
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif schema.properties.Lower:
return f"""{signature} {{
{ts_lowering_body(schema)}
}}
"""
else:
return ""
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"static NodePtr Create({node_ctor_args})"
if schema.properties.CreateFnDeclOnly:
return f"{signature};"
elif not schema.properties.CreateFn:
return ""
return f"""{signature} {{
return ReuseOrMakeNode<{schema.node_name}>(data);
}}"""
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
signature = f"bool CanBeReused({node_ctor_args}) const"
if schema.properties.CanBeReusedDeclOnly:
return f"{signature};"
elif not schema.properties.CanBeReused:
return ""
value_comparison = []
for arg in itertools.chain(schema.positional_values, schema.keyword_values):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
)
else:
value_comparison.append(f"operand(i++) == {arg.name}")
for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
if isinstance(arg.lazy_type, OptionalCType):
value_comparison.append(
f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
)
else:
value_comparison.append(f"this->{arg.name} == {arg.name}")
value_comparison_str = " &&\n ".join(value_comparison)
return f"""{signature} {{
size_t i = 0;
return ({value_comparison_str});
}}"""
@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
class_method_name: str
backend_index: BackendIndex
tensor_class: str
gen_forced_fallback_code: bool
backend_namespace: str
get_tensorlist: str
get_tensor_or_wrap_number: str
try_get_tensor: str
metrics_counter: str
create_tensor: str
create_from_first_tensor: bool
create_aten_from_ltc_tensor: str
tuple_aten_from_ltc_tensors: str
lazy_tensor_ptr: str
get_device_fn: str
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# Generates lazy_{name} variables for LazyTensors wrapping input tensors
lazy_tensor_decls: list[str] = []
for arg in value_args:
if arg.is_wrapped_scalar:
if isinstance(arg.lazy_type, OptionalCType):
lazy_tensor_decls.append(
f"""auto node_{arg.name} = {arg.name} ?
std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
::std::nullopt;"""
)
else:
lazy_tensor_decls.append(
f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
)
elif arg.is_symint_or_list:
continue # values are extracted in isValueType
elif isinstance(arg.lazy_type, BaseCType):
if arg.lazy_type.type is tensorListValueT:
lazy_tensor_decls.append(
f"auto lazy_{arg.name}_tensorlist = "
f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
)
else:
lazy_tensor_decls.append(
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
)
elif isinstance(arg.lazy_type, OptionalCType):
assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
# TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
# until we encounter a real world example.
lazy_tensor_decls.append(
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
)
else:
raise AssertionError(
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
)
return ("\n ").join(lazy_tensor_decls)
def force_eager_fallback(
self,
func: NativeFunction,
schema: LazyIrSchema,
metadata: BackendMetadata,
sig: DispatcherSignature | NativeSignature,
) -> str:
if self.gen_forced_fallback_code:
return gen_fallback_code(
schema, sig, overload_name=func.func.name.overload_name
)
return ""
def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
return f"{self.metrics_counter};"
def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
scalar_args = schema.filtered_args(values=False, scalars=True)
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
optional_device = OptionalCType(BaseCType(deviceT))
optional_devices = [
a.name for a in scalar_args if a.lazy_type == optional_device
]
assert (
len(value_types_names) > 0 or len(optional_devices) > 0
), "Expected at least one Value or Device type"
get_device_str = (
f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
)
return f"""auto common_device = {get_device_str};
TORCH_INTERNAL_ASSERT(common_device);
"""
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
all_args = schema.filtered_args()
returns_length = len(schema.returns)
# call the meta kernel if it exists, to compute output shape/dtype for our IR
# Note [Generated LTC Shape Functions]
# LTC uses meta tensors from core to do shape inference when possible, and otherwise
# we generate a shape function declaration that needs to be manually implemented.
# How do we detect which ops are eligible to use meta tensors?
# In general we should be able to use meta tensors not just on structured operators,
# but also on composite operators that are implemented in terms of structured kernels.
# We don't currently have a way of knowing at codegen time which ops are implemented that way.
# This is the case for all view and view_copy operators however, so we're going to
# use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
is_view_copy_op = "view_copy" in func.tags
is_structured = func.structured or func.structured_delegate is not None
if is_structured or is_view_copy_op:
meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
if returns_length > 1:
def this_shape(i: int) -> str:
return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
# Convert tensor args to the meta device and call it.
# (We can't pass in the input tensors directly, because they are "functional wrappers".
# If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
# Even at::meta:: functions might redispatch, e.g. if they call into view ops.
dispatcher_sig = DispatcherSignature.from_schema(func.func)
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
meta_call_args = [
e.expr
for e in translate(
meta_call_ctx, dispatcher_sig.arguments(), method=False
)
]
if is_view_copy_op:
# view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
assert func.has_composite_explicit_autograd_non_functional_kernel
dispatch_ns = "compositeexplicitautogradnonfunctional"
else:
dispatch_ns = "meta"
aten_name = schema.aten_name
# TODO: this is trolling
if func.func.has_symint() and metadata.supports_symint():
aten_name += "_symint"
shape_str = f"""\
{meta_conversion_str}
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
{meta_out}"""
else:
shape_sig = ComputeShapeSignature(
metadata.kernel, func, symint=metadata.supports_symint()
)
shape_str = f"""
auto shapes = {shape_sig.shape_call};"""
shape_str += f"""
TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
# Calculating which dimensions are symbolic
func_schema_str = "aten::" + str(func.func)
shape_str += f"""
if(torch::lazy::symbolicShapeEnabled()){{
std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
const char* schema_str = "{func_schema_str}";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}}
"""
return shape_str
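
# In short: structured and view_copy ops infer output shapes by running the op
# on meta tensors, roughly
#
#   auto self_meta = to_meta(self);
#   auto out_meta = at::meta::add(self_meta, other_meta, alpha);
#   std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
#
# while every other op calls a (possibly hand-written) compute_shape_<kernel>
# helper declared by ComputeShapeSignature below.
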
def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
node_ctor_input_str = node_ctor_inputs(schema)
return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
if (!node) {{
{self.shape_inference(func, schema)}
node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
CacheNode(node);
}}
"""
def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
# xla uses an instance method for tensor creation, for the time being
if self.create_from_first_tensor:
# TODO(whc) remove this if XLA switches to using static method for creation
assert (
first_tensor_name is not None
), "Requires first tensor to create lazy tensor"
return f"{first_tensor_name}.{self.create_tensor}"
return f"{self.backend_namespace}::{self.create_tensor}"
def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
returns_length = len(schema.returns)
value_args = schema.filtered_args(values=True, scalars=False)
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
if returns_length > 1:
assert (
len(value_types_names) > 0
), "Code below assumes there is at least one tensor arg"
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
for (int i = 0; i < {returns_length}; i++) {{
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
}}
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
if schema.name.name.inplace or func.func.is_out_fn():
assert returns_length == 1, (
"We assumed there was no such case where an op is an in-place variant "
f"and has tuple outputs, but got tuple of len {returns_length}."
)
bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
auto& result = {first_tensor_name};"""
bridge_str += """
return result;"""
return bridge_str
@method_with_native_function
def __call__(self, func: NativeFunction) -> list[str]:
sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
return [
f"""\
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
{self.force_eager_fallback(func, schema, metadata, sig)}
{self.metrics(func, schema)}
{self.get_device(func, schema)}
{self.lazy_tensor_decls(func, schema)}
{self.build_ir_node(func, schema)}
{self.return_aten_tensor(func, schema)}
}}\n
"""
]
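
# Assembled, each generated definition looks schematically like:
#
#   <sig.decl> {
#     <force_eager_fallback>   // optional guard that falls back to eager
#     <metrics>                // bump the backend's op counter
#     <get_device>             // resolve the common lazy device
#     <lazy_tensor_decls>      // wrap tensor inputs as lazy tensors
#     <build_ir_node>          // reuse or build (and shape-infer) the IR node
#     <return_aten_tensor>     // wrap the node back into at::Tensor(s)
#   }
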
class ComputeShapeSignature:
"""
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
"""
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join(
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
)
self.__call_args = ", ".join(
[f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
)
self.__kernel_name = kernel_name
def __decl_suffix(self) -> str:
return f"{self.__kernel_name}({self.__dispatch_args})"
def __call_suffix(self) -> str:
return f"{self.__kernel_name}({self.__call_args})"
@property
def shape_decl(self) -> str:
return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
@property
def shape_call(self) -> str:
return f"torch::lazy::compute_shape_{self.__call_suffix()}"
@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
backend_index: BackendIndex
tensor_class: str
@method_with_native_function
def __call__(self, f: NativeFunction) -> list[str]:
metadata = self.backend_index.get_kernel(f)
assert metadata is not None
# See Note [Generated LTC Shape Functions]
is_view_copy_op = "view_copy" in f.tags
is_structured = f.structured or f.structured_delegate is not None
if is_structured or is_view_copy_op:
return []
else:
shape_sig = ComputeShapeSignature(
metadata.kernel, f, symint=metadata.supports_symint()
)
return ["\n".join([f"{shape_sig.shape_decl};"])]
def generate_non_native_lazy_ir_nodes(
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
"""Generate the non-native lazy IR node classes"""
nodes = []
for op in non_native:
# Set default properties for Non-Native IRs
properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
for p in op.get("properties", []):
setattr(properties, p, True)
# non-native is assumed to want symint bindings if you wrote symint
schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
schema.opkind = op.get("opkind")
nodes.append(gen_lazy_ir.gen(schema)[0])
return nodes


@@ -0,0 +1,48 @@
from torchgen.api.lazy import LazyArgument, LazyIrSchema
from torchgen.api.types import OptionalCType

def ts_lowering_body(schema: LazyIrSchema) -> str:
# For now we just want one IR class decl (and soon after, also the method
# defs), and we use the functional version, not out/inplace.
emplace_arguments = []
def get_value(arg: LazyArgument) -> str:
if isinstance(arg.lazy_type, OptionalCType):
return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
return "loctx->GetOutputOp(operand(i++))"
for arg in schema.positional_args:
if arg.is_lazy_value:
emplace_arguments.append(get_value(arg))
continue
emplace_arguments.append(f'"{arg.name}", {arg.name}')
emplace_arguments_str = "\n ".join(
[f"arguments.emplace_back({a});" for a in emplace_arguments]
)
emplace_kwarg_values = [
f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
]
emplace_kwarg_scalars = [
f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
]
emplace_kwarguments = "\n ".join(
[
f"kwarguments.emplace_back({a});"
for a in emplace_kwarg_values + emplace_kwarg_scalars
]
)
return f"""\
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
"""


@@ -0,0 +1,63 @@
from __future__ import annotations
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.types import kernel_signature
from torchgen.context import with_native_function_and_index
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
from torchgen.utils import mapMaybe

@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
sig = kernel_signature(f, backend_index)
metadata = backend_index.get_kernel(f)
if metadata is None:
return None
if "legacy::" in metadata.kernel:
return None
else:
prefix = "static" if backend_index.external else "TORCH_API"
return f"{prefix} {sig.decl(name=metadata.kernel)};"
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
meta_name = meta.name(g)
out_args = structured.impl_arguments(g)
metadata = backend_index.get_kernel(g)
if metadata is None:
return []
prefix = "" if backend_index.external else "TORCH_API "
return [
f"""\
struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
void impl({', '.join(a.decl() for a in out_args)});
}};
"""
]

# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
metadata = backend_index.get_kernel(g)
if isinstance(g, NativeFunctionsGroup):
if metadata is not None and metadata.structured:
if backend_index.external:
# Structured hasn't been tested with external backends yet.
raise AssertionError(
"Structured external backend functions are not implemented yet."
)
else:
return gen_structured(g, backend_index)
else:
return list(
mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
)
else:
x = gen_unstructured(g, backend_index)
return [] if x is None else [x]

File diff suppressed because it is too large


@@ -0,0 +1,551 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
Expr,
NamedCType,
opmath_t,
scalar_t,
StructuredImplSignature,
VectorizedCType,
)
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
NativeFunctionsGroup,
ScalarType,
UfuncKey,
)
from torchgen.utils import OrderedSet
if TYPE_CHECKING:
from torchgen.api.ufunc import UfunctorBindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
# TODO: use BackendIndex
# dispatch_key: DispatchKey # only CPU/CUDA right now
# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated. A functor looks something like this:
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
# : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#

@dataclass(frozen=True)
class UfunctorSignature:
g: NativeFunctionsGroup
scalar_tensor_idx: int | None
name: str
def arguments(self) -> UfunctorBindings:
return ufunc.ufunctor_arguments(
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
)
def fields(self) -> list[Binding]:
# fields are renamed to have a trailing underscore, as is conventional
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
def returns_type(self) -> CType:
# TODO: don't hardcode; return type will be inferred based on tags on
# the native function
return BaseCType(scalar_t)
def decl_fields(self) -> str:
return "\n".join(f"{f.type} {f.name};" for f in self.fields())
def inline_defn_ctor(self) -> str:
args_str = ", ".join(a.decl() for a in self.arguments().ctor)
# NB: hypothetically could do this with translate but the
# transition here is very regular
init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
return f"{self.name}({args_str}) : {init_str} {{}}"
def decl_apply(self) -> str:
args_str = ", ".join(a.decl() for a in self.arguments().apply)
return f"{self.returns_type().cpp_type()} operator()({args_str}) const"

@dataclass(frozen=True)
class UfuncSignature:
g: NativeFunctionsGroup
name: str
compute_t: CType
def arguments(self) -> list[Binding]:
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
def call(self, ctx: Sequence[Binding | Expr]) -> str:
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"

# steps:
# 1. take the functional signature
# 2. use api.ufunc to convert it to template signature. this establishes
# the type of the template function
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
#    this establishes the context in which we call the template signature
#
# StructuredImplSignature context
# ~> functor constructor sig
#
# Functor constructor context
# ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
# ~> template sig
#

def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
num_tensors = sum(
1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
)
return num_tensors == 2
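
# E.g. a binary op like "add.out(Tensor self, Tensor other, *, Scalar alpha=1, ...)"
# has exactly two tensor-like non-out arguments, so it is eligible for the
# CUDAFunctorOnSelf/CUDAFunctorOnOther scalar specializations.
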
def compute_ufunc_cuda_functors(
g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
# First, build the functors.
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: list[str] = []
loops = g.out.ufunc_inner_loop
scalar_tensor_idx_lookup = {
UfuncKey.CUDAFunctorOnSelf: 1,
UfuncKey.CUDAFunctorOnOther: 0,
UfuncKey.CUDAFunctor: None,
}
if eligible_for_binary_scalar_specialization(g):
keys = [
UfuncKey.CUDAFunctorOnSelf,
UfuncKey.CUDAFunctorOnOther,
UfuncKey.CUDAFunctor,
]
else:
keys = [UfuncKey.CUDAFunctor]
for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
assert k not in loops, f"cannot use {k} on non-binary function"
for k in keys:
# If the key was directly defined, skip functor codegen; we assume the
# user has already done it for us
if k in loops:
ufunctor_sig = UfunctorSignature(
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
)
for dtype in loops[k].supported_dtypes:
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
continue
# Note [ScalarOnly and Generic must match names for CUDA]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Otherwise, look in ANY of the generic entries. For simplicity of
# codegen, if both ScalarOnly and Generic are defined, the ufunc name
# must match (if they didn't match, we'd have to generate distinct
# functors per dtype, which is awful, so we're not going to do it unless
# someone really forces us to)
ufunc_name = None
supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
if lk not in loops:
continue
if ufunc_name is None:
ufunc_name = loops[lk].name
else:
# See Note [ScalarOnly and Generic must match names for CUDA]
assert (
ufunc_name == loops[lk].name
), "ScalarOnly and Generic must have same ufunc name"
supported_dtypes |= loops[lk].supported_dtypes
assert ufunc_name is not None
name = f"{k}_{ufunc_name}"
ufunctor_sig = UfunctorSignature(
g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
)
for dtype in supported_dtypes:
ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
ufunc_sig = UfuncSignature(
g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
)
apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
ufunctors.append(
f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
using opmath_t = at::opmath_type<scalar_t>;
{ufunctor_sig.decl_fields()}
{ufunctor_sig.inline_defn_ctor()}
__device__ {ufunctor_sig.decl_apply()} {{
return {ufunc_sig.call(apply_ctx)};
}}
}};
"""
)
return ufunctor_sigs, "\n".join(ufunctors)

@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
scalar_idx: int
ctor_tensor: str
ufunc_key: UfuncKey

BinaryScalarSpecializationConfigs = [
BinaryScalarSpecializationConfig(
scalar_idx=0,
ctor_tensor="self",
ufunc_key=UfuncKey.CUDAFunctorOnOther,
),
BinaryScalarSpecializationConfig(
scalar_idx=1,
ctor_tensor="other",
ufunc_key=UfuncKey.CUDAFunctorOnSelf,
),
]

def compute_ufunc_cuda_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfunctorSignature],
parent_ctx: Sequence[Binding],
) -> str:
body = "using opmath_t = at::opmath_type<scalar_t>;"
body += "if (false) {}\n" # for ease of codegen
for config in BinaryScalarSpecializationConfigs:
if config.ufunc_key not in inner_loops:
continue
ufunctor_sig = inner_loops[config.ufunc_key]
scalar_idx = config.scalar_idx + 1
# Make a copy and at the same time widen the type (not permissible
# without copy; we don't want to mutate the input argument anyway)
ctx: list[Expr | Binding] = list(parent_ctx)
ctx.append(
Expr(
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
)
)
ufunctor_ctor_exprs_str = ", ".join(
a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
)
# NB: ufunctor must be allocated before iter.remove_operand is called,
# as it relies on iter
body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
{ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
iter.remove_operand({scalar_idx});
gpu_kernel(iter, ufunctor);
}}"""
ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
ufunctor_ctor_exprs_str = ", ".join(
a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
)
body += f"""
else {{
gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
"""
return body

@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
# First, build the functors, indexing them by dtype
ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
# Next, build the conditionals
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
dtype_cases = []
for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
dtype_cases.append(
f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
}}
)
"""
)
dtype_cases_str = "\n".join(dtype_cases)
stub_sig = StubSignature(g)
return f"""
{ufunctors}
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
{dtype_cases_str}
);
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
{sig.defn()} {{
{stub_sig.direct_call(sig.arguments())};
}}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

@dataclass(frozen=True)
class StubSignature:
g: NativeFunctionsGroup
@property
def name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_stub"
@property
def kernel_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_kernel"
@property
def type_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_fn"
def arguments(self) -> list[Binding]:
return ufunc.stub_arguments(self.g)
def type(self) -> str:
cpp_args = self.arguments()
return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
def dispatch_decl(self) -> str:
return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
def dispatch_defn(self) -> str:
return f"DEFINE_DISPATCH({self.name})"
def kernel_defn(self) -> str:
return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
def type_defn(self) -> str:
return f"using {self.type_name} = {self.type()}"
# must be called from a context where `this` is a TensorIteratorBase*
def call(self, ctx: Sequence[Binding]) -> str:
return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# used in CUDA to skip the unnecessary dynamic dispatch
def direct_call(self, ctx: Sequence[Binding]) -> str:
return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
stub_sig = StubSignature(g)
sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};
{sig.defn()} {{
{stub_sig.call(sig.arguments())};
}}
"""

def compute_ufunc_cpu_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfuncSignature],
parent_ctx: Sequence[Binding],
) -> str:
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
scalar_loop = inner_loops[UfuncKey.CPUScalar]
vec_loop = None
if UfuncKey.CPUVector in inner_loops:
vec_loop = inner_loops[UfuncKey.CPUVector]
# NB: We DON'T use translate here, because translate is
# incapable of CSE'ing the scalar accesses in case it is also
# used by Vectorized; also, the unpacking here is very simple
# and only affects Scalar; everything else is implicitly captured
# by the lambda
# Setup scalar in scope
body = []
ctx = []
for b in parent_ctx:
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
BaseTy.Scalar
):
continue
body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
if vec_loop is not None:
for b in parent_ctx:
if isinstance(b.argument, Argument) and b.argument.type != BaseType(
BaseTy.Scalar
):
continue
body.append(
f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
)
ctx.append(
Expr(
f"_v_{b.name}",
NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
)
)
# Setup lambda signature
# NB: simplified version of ufunctor_arguments
scalar_bindings = []
vec_bindings = []
for a in g.functional.func.arguments.flat_non_out:
if not a.type.is_tensor_like():
continue
assert a.type == BaseType(BaseTy.Tensor)
scalar_bindings.append(
Binding(
name=a.name,
nctype=NamedCType(a.name, BaseCType(scalar_t)),
argument=a,
)
)
if vec_loop is not None:
vec_bindings.append(
Binding(
name=a.name,
nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
argument=a,
)
)
def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
r: list[Expr | Binding] = []
r.extend(ctx)
r.extend(b)
return r
body_str = "\n".join(body)
if vec_loop is not None:
return f"""
{body_str}
cpu_kernel_vec(iter,
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
else:
return f"""
{body_str}
cpu_kernel(iter,
[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
stub_sig = StubSignature(g)
# Reindex the ufunc by dtypes; processing generic/scalaronly as well
loops = g.out.ufunc_inner_loop
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
lks = []
# ORDER MATTERS: this specifies overriding precedence
if k in loops: # should happen rarely
lks.append(k)
if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
lks.append(UfuncKey.ScalarOnly)
if UfuncKey.Generic in loops:
lks.append(UfuncKey.Generic)
# TODO: don't hardcode ufunc:: namespace here, should be centralized smh
for lk in lks:
for dtype in loops[lk].supported_dtypes:
compute_t: CType
if k is UfuncKey.CPUScalar:
compute_t = BaseCType(scalar_t)
elif k is UfuncKey.CPUVector:
compute_t = VectorizedCType(BaseCType(scalar_t))
else:
raise AssertionError
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
if k not in inner_ufunc_sigs:
inner_ufunc_sigs[k] = UfuncSignature(
g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
)
# Build the conditionals
dtype_cases = []
for dtype, inner_ufunc_sigs in ufunc_sigs.items():
dtype_cases.append(
f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
}}
)
"""
)
dtype_cases_str = "\n".join(dtype_cases)
return f"""
namespace {{
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
{dtype_cases_str}
);
}}
}} // anonymous namespace
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""