I am done

Date: 2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,388 @@
from __future__ import annotations
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
if isinstance(g, NativeFunctionsGroup):
return str(g.functional.func.name.name.base)
else:
return str(g.view.root_name)
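# Ops that already have hand-written static runtime implementations; the code
# generator skips them (see is_hand_written below and generator.is_supported).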
is_hand_written_ops_ = frozenset(
(
"abs",
"add",
"addmm",
"all",
"any",
"argmin",
"bmm",
"clamp",
"clamp_min",
"cumsum",
"div",
"fmod",
"index_select",
"leaky_relu",
"linear",
"log",
"matmul",
"mul",
"narrow_copy",
"nonzero",
"pow",
"remainder",
"sigmoid",
"sign",
"sub",
"tanh",
"detach",
"expand_as",
"flatten",
"narrow",
"reshape_as",
"select",
"slice",
"softmax",
"split",
"squeeze",
"transpose",
"view",
"where",
)
)
def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
name_base = func_name_base_str(g)
return name_base in is_hand_written_ops_
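# Replace the auto-generated C++ test inputs for ops whose generic values (see
# generator.test_value_expression) would not satisfy the op's shape or dtype
# requirements. `arg_map` maps argument names to C++ expression strings that are
# spliced into the generated gtest code; `index` selects the first (0) or second
# (1) test-case variant.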
def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
assert index == 0 or index == 1
if op_name == "addr":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["vec1"] = "at::rand({6})"
arg_map["vec2"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["vec1"] = "at::rand({22})"
arg_map["vec2"] = "at::rand({22})"
return
if op_name == "mv":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["vec"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["vec"] = "at::rand({22})"
return
if op_name == "addbmm":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
else:
arg_map["self"] = "at::rand({22, 22})"
return
if op_name == "cross":
if index == 0:
arg_map["self"] = "at::rand({3, 3, 3})"
arg_map["other"] = "at::rand({3, 3, 3})"
else:
arg_map["self"] = "at::rand({22, 3, 22})"
arg_map["other"] = "at::rand({22, 3, 22})"
return
if op_name == "take":
if index == 0:
arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
else:
arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
return
if op_name == "take_along_dim":
if index == 0:
arg_map["indices"] = "at::argsort(self0, 1, true)"
else:
arg_map["indices"] = "at::argsort(self1, 1, true)"
return
if op_name == "masked_select":
if index == 0:
arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
else:
arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
return
if op_name == "orgqr":
if index == 0:
arg_map["input2"] = "at::rand({6, 6})"
else:
arg_map["input2"] = "at::rand({22, 22})"
return
if op_name == "ormqr":
if index == 0:
arg_map["input2"] = "at::rand({6, 6})"
else:
arg_map["input2"] = "at::rand({22, 22})"
return
if op_name == "quantile":
if index == 0:
arg_map["q"] = "at::rand({6})"
arg_map["interpolation"] = '"linear"'
else:
arg_map["q"] = "at::rand({22})"
arg_map["interpolation"] = '"linear"'
return
if op_name == "nanquantile":
if index == 0:
arg_map["q"] = "at::rand({6})"
arg_map["interpolation"] = '"linear"'
else:
arg_map["q"] = "at::rand({22})"
arg_map["interpolation"] = '"linear"'
return
if op_name == "multi_margin_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name == "multilabel_margin_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
return
if op_name == "nll_loss":
if index == 0:
arg_map["self"] = "at::rand({6, 6})"
arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22})"
arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name == "nll_loss2d":
if index == 0:
arg_map["self"] = "at::rand({6, 6, 6, 6})"
arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
else:
arg_map["self"] = "at::rand({22, 22, 22, 22})"
arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
arg_map["weight"] = "at::rand({22})"
return
if op_name in (
"fft_fft",
"fft_ifft",
"fft_rfft",
"fft_irfft",
"fft_hfft",
"fft_ihfft",
):
arg_map["norm"] = '"forward"'
return
if op_name == "linalg_tensorinv":
if index == 0:
arg_map["self"] = "at::rand({6, 6, 6, 6})"
arg_map["ind"] = "2"
else:
arg_map["self"] = "at::rand({22, 22, 22, 22})"
arg_map["ind"] = "2"
return
if op_name == "addmv":
if index == 0:
arg_map["self"] = "at::rand({2})"
arg_map["mat"] = "at::rand({2, 2})"
arg_map["vec"] = "at::rand({2})"
else:
arg_map["self"] = "at::rand({35})"
arg_map["mat"] = "at::rand({35, 35})"
arg_map["vec"] = "at::rand({35})"
return
if op_name == "acosh":
if index == 0:
arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"
else:
arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
return
if op_name == "adaptive_max_pool2d_backward":
if index == 0:
arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)"
arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)"
arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)"
else:
arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)"
arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)"
arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)"
return
if op_name == "adaptive_max_pool3d_backward":
if index == 0:
arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)"
else:
arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)"
return
if op_name == "bitwise_left_shift":
if index == 0:
arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)"
arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)"
else:
arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)"
arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)"
return
if op_name == "bitwise_right_shift":
if index == 0:
arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)"
arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)"
else:
arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)"
arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)"
return
if op_name == "gather":
if index == 0:
arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)"
arg_map["dim"] = "1"
arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
arg_map["sparse_grad"] = "false"
else:
arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)"
arg_map["dim"] = "1"
arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
arg_map["sparse_grad"] = "false"
return
if op_name == "gelu":
if index == 0:
arg_map["self"] = "at::rand({6, 6, 6})"
arg_map["approximate"] = '"tanh"'
else:
arg_map["self"] = "at::rand({22, 22, 22})"
arg_map["approximate"] = '"tanh"'
return
if op_name == "gelu_backward":
if index == 0:
arg_map["grad_output"] = "at::rand({6, 6, 6})"
arg_map["self"] = "at::rand({6, 6, 6})"
arg_map["approximate"] = '"tanh"'
else:
arg_map["grad_output"] = "at::rand({22, 22, 22})"
arg_map["self"] = "at::rand({22, 22, 22})"
arg_map["approximate"] = '"tanh"'
return
if op_name == "index_add":
if index == 0:
arg_map["self"] = "at::rand({2})"
arg_map["dim"] = "0"
arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
arg_map["source"] = "at::rand({2})"
arg_map["alpha"] = "2"
else:
arg_map["self"] = "at::rand({16})"
arg_map["dim"] = "0"
arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
arg_map["source"] = "at::rand({16})"
arg_map["alpha"] = "2"
return
if op_name == "index_copy":
if index == 0:
arg_map["self"] = "at::rand({2})"
arg_map["dim"] = "0"
arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
arg_map["source"] = "at::rand({2})"
else:
arg_map["self"] = "at::rand({32})"
arg_map["dim"] = "0"
arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
arg_map["source"] = "at::rand({32})"
return
if op_name == "linalg_cross":
if index == 0:
arg_map["self"] = "at::rand({6, 3, 6})"
arg_map["other"] = "at::rand({6, 3, 6})"
arg_map["dim"] = "1"
else:
arg_map["self"] = "at::rand({22, 3, 22})"
arg_map["other"] = "at::rand({22, 3, 22})"
arg_map["dim"] = "1"
return
if op_name == "nll_loss_backward":
if index == 0:
arg_map["grad_output"] = "at::rand({})"
arg_map["self"] = "at::rand({6})"
arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)"
arg_map["weight"] = "at::rand({6})"
arg_map["reduction"] = "1"
arg_map["ignore_index"] = "1"
arg_map["total_weight"] = "at::rand({})"
else:
arg_map["grad_output"] = "at::rand({})"
arg_map["self"] = "at::rand({36})"
arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)"
arg_map["weight"] = "at::rand({36})"
arg_map["reduction"] = "1"
arg_map["ignore_index"] = "1"
arg_map["total_weight"] = "at::rand({})"
return
if op_name in ["scatter", "scatter_add", "_scatter_reduce"]:
if index == 0:
arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
else:
arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)"
arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
if "reduce" in arg_map:
arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
return
if op_name == "scatter_reduce":
arg_map["reduce"] = '"mean"'
if index == 0:
arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
else:
arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
return
if op_name == "special_zeta":
if index == 0:
arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
else:
arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
return
if op_name == "_convert_indices_from_csr_to_coo":
if index == 0:
arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)"
arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)"
arg_map["out_int32"] = "false"
else:
arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
arg_map[
"col_indices"
] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
arg_map["out_int32"] = "false"
return
if op_name == "_convert_indices_from_coo_to_csr":
if index == 0:
arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)"
arg_map["size"] = "10"
arg_map["out_int32"] = "false"
else:
arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)"
arg_map["size"] = "24"
arg_map["out_int32"] = "false"
return
if op_name in ("diagonal", "linalg_diagonal"):
arg_map["offset"] = "0"
arg_map["dim1"] = "2"
arg_map["dim2"] = "1"
return
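
The overrides above are consumed by the test-case generator, which first fills an
arg_map with generic C++ expression strings and then lets this module patch them.
A minimal sketch of that flow, assuming this file is importable as
torchgen.static_runtime.config and using `addr` as the example op (the initial
expressions are illustrative stand-ins for the generator's defaults):

    from torchgen.static_runtime import config

    # Generic placeholders such as the generator might produce for a 2-D op.
    arg_map = {
        "self": "at::rand({6, 6})",
        "vec1": "at::rand({6, 6})",  # wrong rank: addr expects vectors here
        "vec2": "at::rand({6, 6})",
    }
    config.override_test_values(arg_map, "addr", 0)
    # arg_map now holds addr-specific expressions:
    #   self -> "at::rand({6, 6})", vec1 -> "at::rand({6})", vec2 -> "at::rand({6})"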

@@ -0,0 +1,229 @@
from __future__ import annotations
import argparse
import itertools
import os
from typing import Sequence, TypeVar, Union
from libfb.py.log import set_simple_logging # type: ignore[import]
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
from torchgen.static_runtime import config, generator
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function.
NativeGroupT = TypeVar(
"NativeGroupT",
bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
)
def group_functions_by_op_name(
grouped_native_functions: Sequence[NativeGroupT],
) -> Sequence[Sequence[NativeGroupT]]:
if not grouped_native_functions:
return []
groups = []
def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
with native_function_manager(g):
return generator.is_supported(g)
eligible_ops = (g for g in grouped_native_functions if is_supported(g))
groups = [
list(group)
for k, group in (
itertools.groupby(
eligible_ops,
key=config.func_name_base_str,
)
)
]
return groups
def clang_format(cpp_file_path: str) -> None:
import subprocess
subprocess.check_call(["clang-format", "-i", cpp_file_path])
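# Write the generated operator implementations into a single .cpp file wrapped in
# the torch::jit namespace, then run clang-format over it so the emitted file is
# consistently formatted.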
def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
code = "\n".join(cpp_ops)
generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
#include <torch/csrc/jit/runtime/static/ops.h>
#include <ATen/CPUFunctions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/Fill.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/cpu/SerialStackImpl.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/ScalarType.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/te_wrapper.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
namespace torch {{
namespace jit {{
{code}
}} // namespace jit
}} // namespace torch
"""
with open(file_path, "w") as f:
f.write(generated)
clang_format(file_path)
def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
code = "\n".join(cpp_ops)
generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/torch.h>
#include "test_utils.h"
using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;
{code}
"""
with open(file_path, "w") as f:
f.write(generated)
clang_format(file_path)
def main() -> None:
    parser = argparse.ArgumentParser(description="Generate static runtime operator files")
parser.add_argument(
"-s",
"--source-path",
help="path to source directory for ATen",
default="caffe2/aten/src/ATen",
)
parser.add_argument(
"-p",
"--generated-ops-cpp-path",
help="path to directory to generate op dispatcher .cpp file",
default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
)
parser.add_argument(
"-t",
"--generated-ops-test-cpp-path",
help="path to directory to generate op dispatcher .cpp file",
default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
)
options = parser.parse_args()
native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions, backend_indices = (
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
)
op_generator = generator.GenOpDispatcher()
test_case_generator = generator.GenOpTestCase()
native_functions_groups = [
g
for g in gen.get_grouped_native_functions(native_functions)
if isinstance(g, NativeFunctionsGroup)
]
supported_functions_groups = group_functions_by_op_name(native_functions_groups)
out_variant_op_result = [
op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
for groups in supported_functions_groups
]
out_variant_test_result = [
test_case_generator.out_variant(groups) for groups in supported_functions_groups
]
native_functions_view_groups = [
g
for g in gen.get_grouped_by_view_native_functions(native_functions)
if isinstance(g, NativeFunctionsViewGroup)
]
supported_functions_view_groups = group_functions_by_op_name(
native_functions_view_groups
)
view_op_result = [
op_generator.view(groups, backend_indices[DispatchKey.CPU])
for groups in supported_functions_view_groups
]
view_test_result = [
test_case_generator.view(groups) for groups in supported_functions_view_groups
]
op_result = out_variant_op_result + ["\n\n"] + view_op_result
test_result = out_variant_test_result + ["\n\n"] + view_test_result
write_cpp(op_result, options.generated_ops_cpp_path)
write_test_cpp(test_result, options.generated_ops_test_cpp_path)
print(
"\ntotal grouped native ops: %d"
% len(gen.get_grouped_native_functions(native_functions))
)
print("grouped native ops with out variant: %d" % len(native_functions_groups))
supported_functions_num = sum(len(groups) for groups in supported_functions_groups)
print("generated functions groups with out variant: %d" % supported_functions_num)
print("\nview grouped native ops: %d" % len(native_functions_view_groups))
supported_view_functions_num = sum(
len(groups) for groups in supported_functions_view_groups
)
print("generated functions view groups: %d" % supported_view_functions_num)
print(
"\noverall generated : %d"
% (supported_functions_num + supported_view_functions_num)
)
if __name__ == "__main__":
set_simple_logging(escape_newlines=False)
main()
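
group_functions_by_op_name above relies on itertools.groupby, which only merges
adjacent items; this is why its input must already be sorted by op name. A minimal
sketch of the same grouping logic, using plain base-name strings as stand-ins for
the NativeFunctionsGroup objects the real function receives:

    import itertools

    sorted_names = ["mean", "mean.dim", "mm", "mul", "mul.Scalar"]
    groups = [
        list(group)
        for _, group in itertools.groupby(sorted_names, key=lambda n: n.split(".")[0])
    ]
    # groups == [["mean", "mean.dim"], ["mm"], ["mul", "mul.Scalar"]]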

@@ -0,0 +1,809 @@
from __future__ import annotations
import json
import logging
import math
from typing import Sequence
import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
FunctionSchema,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
OptionalType,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.static_runtime import config
logger: logging.Logger = logging.getLogger()
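# An argument aliases storage when its schema annotation carries a non-empty alias
# set (e.g. `Tensor(a!) self`). is_supported below uses this to reject ops whose
# non-out arguments carry such annotations, since they may create an alias of their
# inputs.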
def has_alias(
arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
) -> bool:
for arg in arguments:
annotation = getattr(arg, "annotation", None)
if not annotation:
continue
alias_set = getattr(annotation, "alias_set", ())
if alias_set:
return True
return False
BLOCKED_OPS = frozenset(
(
        # non-CPU ops
"sparse_sampled_addmm",
"hspmm",
"linalg_svdvals",
# sparse ops
"sspaddmm",
"coalesce",
"_indices",
"indices",
"_values",
"values",
"crow_indices",
"col_indices",
# deprecated ops
"floor_divide",
"ger",
# buggy ops
"conj_physical", # P495807361
"binary_cross_entropy", # P496394764
"arccosh",
# uncommon ops
"cholesky",
"lu_solve",
"linalg_cholesky",
"linalg_householder_product",
"linalg_ldl_solve",
"_compute_linear_combination",
        # training-related ops
"_make_dual",
# cannot call directly
"_fw_primal",
# no documentation
"_index_reduce",
# TODO: these ones got added recently and need manual inspection
"_new_zeros_with_same_feature_meta",
"_conj_physical",
"binary_cross_entropy_with_logits",
"bincount",
"conv_tbc",
"copy",
"_copy_from",
"_copy_from_and_resize",
"count_nonzero",
"cudnn_affine_grid_generator",
"cudnn_affine_grid_generator_backward",
"cudnn_grid_sampler",
"diag_embed",
"embedding",
"embedding_dense_backward",
"_embedding_bag_dense_backward",
"_embedding_bag_per_sample_weights_backward",
"grid_sampler_2d",
"_grid_sampler_2d_cpu_fallback",
"grid_sampler_3d",
"isnan",
"mkldnn_linear",
"median",
"nanmedian",
"_sparse_sparse_matmul",
"batch_norm_backward_elemt",
"_euclidean_dist",
"pixel_shuffle",
"pixel_unshuffle",
"channel_shuffle",
"_reshape_nested_backward",
"relu",
"prelu",
"celu",
"slice_scatter",
"select_scatter",
"diagonal_scatter",
"sum",
"_mkldnn_transpose",
"_nested_tensor_from_mask",
"_nested_from_padded",
"_nested_tensor_size",
"_nested_from_padded_and_nested_example",
"_standard_gamma_grad",
"_dirichlet_grad",
"native_norm",
"_sparse_softmax",
"_sparse_softmax_backward_data",
"_sparse_log_softmax",
"_sparse_log_softmax_backward_data",
"zero",
"_sparse_addmm",
"sparse_mask",
"_sparse_mask_projection",
"_to_dense",
"_coalesce",
"_coalesced",
"copy_sparse_to_sparse",
"to_sparse",
"to_sparse_csr",
"to_sparse_csc",
"to_mkldnn",
"quantize_per_tensor_dynamic",
"quantize_per_channel",
"q_per_channel_scales",
"q_per_channel_zero_points",
"int_repr",
"_make_per_channel_quantized_tensor",
"set",
"lift",
"lift_fresh",
"lift_fresh_copy",
"masked_scatter",
"_masked_softmax",
"_masked_softmax_backward",
"put",
"index_reduce",
"trace",
"_cholesky_solve_helper",
"dist",
"max",
"_torch_cuda_cu_linker_symbol_op",
"glu_jvp",
"glu_backward_jvp",
"hardswish_backward",
"rrelu_with_noise_backward",
"mkldnn_adaptive_avg_pool2d_backward",
"_adaptive_avg_pool2d_backward",
"_adaptive_avg_pool3d_backward",
"isinf",
"linalg_lu_solve",
"linalg_vecdot",
"linalg_matrix_exp",
"linalg_eigvalsh",
"_test_warn_in_autograd",
"_test_autograd_multiple_dispatch_view",
"_test_autograd_multiple_dispatch_view_copy",
"_segment_reduce",
"_segment_reduce_backward",
"_fw_primal_copy",
"_make_dual_copy",
"view_as_real_copy",
"view_as_complex_copy",
"_conj_copy",
"_neg_view_copy",
"diagonal_copy",
"detach_copy",
"squeeze_copy",
"t_copy",
"unsqueeze_copy",
"_indices_copy",
"_values_copy",
"indices_copy",
"values_copy",
"crow_indices_copy",
"col_indices_copy",
"ccol_indices",
"ccol_indices_copy",
"row_indices",
"row_indices_copy",
"unfold_copy",
"alias_copy",
"_triton_multi_head_attention",
"special_airy_ai",
"special_bessel_j0",
"special_bessel_j1",
"special_bessel_y0",
"special_bessel_y1",
"special_chebyshev_polynomial_t",
"special_chebyshev_polynomial_u",
"special_chebyshev_polynomial_v",
"special_chebyshev_polynomial_w",
"special_hermite_polynomial_h",
"special_hermite_polynomial_he",
"special_laguerre_polynomial_l",
"special_legendre_polynomial_p",
"special_modified_bessel_i0",
"special_modified_bessel_i1",
"special_modified_bessel_k0",
"special_modified_bessel_k1",
"special_scaled_modified_bessel_k0",
"special_scaled_modified_bessel_k1",
"special_shifted_chebyshev_polynomial_t",
"special_shifted_chebyshev_polynomial_u",
"special_shifted_chebyshev_polynomial_v",
"special_shifted_chebyshev_polynomial_w",
"special_spherical_bessel_j0",
"_foobar",
"_nested_tensor_strides",
"_nested_tensor_storage_offsets",
"_nested_get_values", # no CPU backend
"_nested_get_values_copy", # no CPU backend
"_nested_view_from_jagged", # testing needs to be patched
"_nested_view_from_jagged_copy", # testing needs to be patched
"_nested_view_from_buffer", # testing needs to be patched
"_nested_view_from_buffer_copy", # testing needs to be patched
"_int_mm", # testing needs to be patched
"_to_sparse_csc", # testing needs to be patched
"_to_sparse_csr", # testing needs to be patched
"segment_reduce", # testing needs to be patched
)
)
def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
base_op_name = ""
func = None
if isinstance(g, NativeFunctionsViewGroup):
base_op_name = g.view.root_name
func = g.view.func
else:
base_op_name = g.out.func.name.name.base
func = g.out.func
if config.is_hand_written(g):
logger.info("HAND WRITTEN: %s", base_op_name)
return False
if base_op_name in BLOCKED_OPS:
logger.info("BLOCKED: %s", base_op_name)
return False
for arg in func.schema_order_arguments():
maybe_method = ivalue_type_conversion_method(arg.type)
if not maybe_method:
            # Type conversion is not supported yet.
logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
return False
if isinstance(g, NativeFunctionsViewGroup):
# TODO: stop doing type tests by converting to C++ and then testing
# the string, just test the dang thing directly
if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
# Returns a non-Tensor value.
logger.info("NON-TENSOR RET TYPE: %s", str(func))
return False
return True
    # For out-variant ops, we need to check the arguments of their functional counterpart.
for arg in g.functional.func.schema_order_arguments():
maybe_method = ivalue_type_conversion_method(arg.type)
if not maybe_method:
            # Type conversion is not supported yet.
logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func)
return False
if not g.structured:
        # For an unstructured op, check whether it has an out-variant implementation.
        # The minimum requirement is that the out variant takes the output tensor as
        # its last parameter.
if (
not hasattr(g, "out")
or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
or not str(func.name).endswith(".out")
):
return False
# TODO: stop type testing by converting to C++
if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
logger.info("NON_TENSOR RET TYPE: %s", func)
return False
if has_alias(func.arguments.non_out):
        # This op may create an alias of its inputs.
logger.info("INPUTS ALIAS: %s", base_op_name)
return False
return True
def ivalue_type_conversion_method(
arg_type: BaseType | OptionalType | Type,
) -> tuple[bool, str] | None:
"""
    Return the method-call expression on a `c10::IValue` that converts its contained
    value to the type expected for `arg_type`. For example, for `arg_type == BaseTy.Tensor`
    this function returns ".toTensor()", which can be appended to an IValue variable's
    name to obtain a value of the expected type. The boolean in each returned tuple
    indicates whether the extracted value should be bound by const reference in the
    generated code (see generate_arg_extraction).
"""
type_conversion_methods = {
BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
BaseTy.ScalarType: (
(False, "toScalarType()"),
(False, "toOptional<at::ScalarType>()"),
),
BaseTy.str: (
(False, "toStringView()"),
(False, "toOptional<c10::string_view>()"),
),
}
base_ty_object = None
if isinstance(arg_type, BaseType):
base_ty_object = arg_type.name
elif isinstance(arg_type, OptionalType):
if not isinstance(arg_type.elem, BaseType):
# ListType is currently unsupported.
return None
base_ty_object = arg_type.elem.name
else:
return None
if base_ty_object not in type_conversion_methods:
return None
methods = type_conversion_methods[base_ty_object]
if isinstance(arg_type, BaseType):
return methods[0]
return methods[1]
should_use_int_tensor_ops_ = frozenset(
(
"bitwise_not",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"gcd",
"lcm",
"scatter",
"gather",
"_convert_indices_from_coo_to_csr",
"_convert_indices_from_csr_to_coo",
)
)
should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
def should_use_int_tensor(op_name: str) -> bool:
return op_name in should_use_int_tensor_ops_
def should_use_complex_tensor(op_name: str) -> bool:
return op_name in should_use_complex_tensor_ops_
test_tensor_dim_ops_1_ = frozenset(
(
"addmv",
"index_add",
"_convert_indices_from_coo_to_csr",
"_convert_indices_from_csr_to_coo",
"nll_loss_backward",
"dot",
"vdot",
"outer",
"ger",
)
)
test_tensor_dim_ops_2_ = frozenset(
("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
)
def test_tensor_dim(op_name: str) -> int:
if op_name in test_tensor_dim_ops_1_:
return 1
if op_name in test_tensor_dim_ops_2_:
return 2
return 3
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
def test_tensor_shape(op_name: str) -> str:
if op_name in test_tensor_shape_json:
return test_tensor_shape_json[op_name]
else:
return ""
def test_value_expression(
arg_type: BaseType | OptionalType | Type, index: int, op_name: str
) -> str:
tensor_size_ex = test_tensor_shape(op_name)
if tensor_size_ex == "":
num_tensors = 16 if index == 0 else 64
num_dim = test_tensor_dim(op_name)
size_per_dim = math.ceil(num_tensors / float(num_dim))
size_per_dim += size_per_dim % 2
tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
if should_use_int_tensor(op_name):
tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
elif should_use_complex_tensor(op_name):
tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
else:
tensor_expression = f"at::rand({tensor_size_ex})"
value_expressions = {
BaseTy.Tensor: tensor_expression,
BaseTy.int: "1",
BaseTy.bool: "false",
BaseTy.Scalar: "2",
BaseTy.ScalarType: "at::ScalarType::Float",
BaseTy.str: '"floor"',
}
base_ty_object = None
if isinstance(arg_type, BaseType):
base_ty_object = arg_type.name
else:
assert isinstance(arg_type, OptionalType) and isinstance(
arg_type.elem, BaseType
)
base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "unexpected type"
value_expression = value_expressions[base_ty_object]
return value_expression
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
assert not schema.is_out_fn()
schema_name = schema.name.name.base
arg_map = {}
for arg in schema.schema_order_arguments():
test_value_exp = test_value_expression(arg.type, index, schema_name)
arg_map[arg.name] = test_value_exp
config.override_test_values(arg_map, schema_name, index)
arg_populations = []
for arg_name, arg_value in arg_map.items():
arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
return ";\n ".join(arg_populations) + ";"
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
assert not schema.is_out_fn()
return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
generate_test_ir_arguments_base_ty_to_type_str_ = {
BaseTy.Tensor: "Tensor",
BaseTy.int: "int",
BaseTy.float: "float",
BaseTy.str: "str",
BaseTy.Scalar: "int",
BaseTy.ScalarType: "int",
BaseTy.bool: "bool",
}
def generate_test_ir_arguments(
schema: FunctionSchema,
) -> list[tuple[str, str | None]]:
def ir_argument(arg: Argument) -> tuple[str, str | None]:
t = arg.type
add_optional = False
if isinstance(t, OptionalType):
t = t.elem
add_optional = True
assert isinstance(t, BaseType)
type_str = None
if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
if type_str and add_optional:
type_str = f"{type_str}?"
return ("%" + arg.name, type_str)
return [ir_argument(arg) for arg in schema.schema_order_arguments()]
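# Emit one C++ statement per argument that pulls its value out of the ProcessedNode's
# inputs, e.g. `const auto& self = p_node->Input(0).toTensor()`. The reference
# qualifier comes from the flag returned by ivalue_type_conversion_method.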
def generate_arg_extraction(schema: FunctionSchema) -> str:
arg_populations = []
for i, arg in enumerate(schema.schema_order_arguments()):
maybe_method = ivalue_type_conversion_method(arg.type)
assert maybe_method
is_reference, type_conversion_method = maybe_method
reference = "&" if is_reference else ""
arg_populations.append(
f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
)
return ";\n ".join(arg_populations) + ";"
def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
kernel = backend_index.get_kernel(g.functional)
if g.structured or kernel is None:
return cpp.name(g.functional.func)
return kernel.kernel
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
kernel = backend_index.get_kernel(g.out)
if g.structured or kernel is None:
return cpp.name(g.out.func)
return kernel.kernel
def generate_non_out_variant_call(
g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
schema = g.functional.func
assert not schema.is_out_fn()
kernel_name = get_kernel_name(g, backend_index)
arg_names = (arg.name for arg in schema.schema_order_arguments())
namespace_name = "cpu" if g.structured else "native"
return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_call_to_view_ops(
g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
schema = g.view.func
kernel_name = cpp.name(schema)
kernel = backend_index.get_kernel(g.view)
if kernel:
kernel_name = kernel.kernel
arg_names = (arg.name for arg in schema.schema_order_arguments())
namespace_name = "native"
return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_out_variant_call(
g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
schema = g.out.func
assert schema.is_out_fn()
arg_names = []
kernel_name = get_out_kernel_name(g, backend_index)
if g.structured:
        # For a structured op, the kernel call starts with the output tensor argument(s).
arg_names = [out_arg.name for out_arg in schema.arguments.out]
else:
arg_names = []
for arg in schema.arguments.non_out:
if isinstance(arg, SelfArgument):
arg_names.append(arg.argument.name)
else:
assert isinstance(arg, Argument)
arg_names.append(arg.name)
if not g.structured:
assert len(schema.arguments.out) == 1
arg_names.append(schema.arguments.out[0].name)
cpp_arg_names = ",".join(arg_names)
namespace_name = "cpu" if g.structured else "native"
return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
no_memory_resize_ops = frozenset(
(
"isin.Scalar_Tensor",
"index_add",
"dot",
"vdot",
"nuclear_norm",
"histc",
"l1_loss",
"multi_margin_loss",
"multilabel_margin_loss",
"nll_loss",
"nll_loss2d",
"prod",
)
)
def should_check_resize(schema: FunctionSchema) -> bool:
schema_str = str(schema)
type_variant_op_name = schema_str[: schema_str.find("(")]
return type_variant_op_name not in no_memory_resize_ops
def op_name_from_group(g: NativeFunctionsGroup) -> str:
return g.functional.func.name.name.base
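# GenOpDispatcher emits the C++ registration blocks for the static runtime
# (REGISTER_OPERATOR_FUNCTOR for out variants, REGISTER_NATIVE_OPERATOR_FUNCTOR for
# view ops); GenOpTestCase emits the matching autogen_* gtest cases.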
class GenOpDispatcher:
def out_variant(
self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.out_variant_op_generator(g, backend_index)
generated_type_variants.append(generated_type_variant)
op_name = op_name_from_group(groups[0])
body = "\n".join(generated_type_variants)
generated = f"""
REGISTER_OPERATOR_FUNCTOR(
aten::{op_name},
aten_{op_name},
[](Node* n) -> SROperator {{
{body}
LogAndDumpSchema(n);
return nullptr;
}});
"""
return generated
def view(
self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsViewGroup)
generated_type_variant = self.view_op_generator(g, backend_index)
generated_type_variants.append(generated_type_variant)
op_name = config.func_name_base_str(groups[0])
body = "\n".join(generated_type_variants)
generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::{op_name},
aten_{op_name},
[](Node* n) -> SROperator {{
{body}
LogAndDumpSchema(n);
return nullptr;
}});
"""
return generated
def out_variant_op_generator(
self, g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
functional = g.functional
schema = str(functional.func)
populated_argument = generate_arg_extraction(g.functional.func)
functional_variant_call = generate_non_out_variant_call(g, backend_index)
assert len(g.out.func.arguments.out) == 1
out_variable_name = str(g.out.func.arguments.out[0].name)
out_variant_call = generate_out_variant_call(g, backend_index)
generated = f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
{populated_argument}
if (p_node->Output(0).isNone()) {{
p_node->Output(0) = {functional_variant_call};
return;
}}
auto& {out_variable_name} = p_node->Output(0).toTensor();
fastResizeToZero({out_variable_name});
{out_variant_call};
}};
}}"""
return generated
def view_op_generator(
self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
schema = str(g.view.func)
populated_argument = generate_arg_extraction(g.view.func)
functional_variant_call = generate_call_to_view_ops(g, backend_index)
generated = f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
{populated_argument}
p_node->Output(0) = {functional_variant_call};
}};
}}"""
return generated
class GenOpTestCase:
def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsGroup)
generated_type_variant = self.out_variant_op_test_case_generator(g)
generated_type_variants.append(generated_type_variant)
return "\n".join(generated_type_variants)
def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
if not groups:
return ""
generated_type_variants = []
for g in groups:
with native_function_manager(g):
assert is_supported(g)
assert isinstance(g, NativeFunctionsViewGroup)
generated_type_variant = self.view_op_test_case_generator(g)
generated_type_variants.append(generated_type_variant)
return "\n".join(generated_type_variants)
def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
schema = g.functional.func
schema_str = str(schema)
assert schema_str.find("(") > 0
type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
op_name = op_name_from_group(g)
assert type_variant_op_name.startswith(op_name)
arg_types = generate_test_ir_arguments(schema)
arg_declarations = ", ".join(
(
arg_name if arg_type is None else f"{arg_name}: {arg_type}"
for arg_name, arg_type in arg_types
)
)
arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
assert (
len(schema.returns) == 1
and isinstance(schema.returns[0].type, BaseType)
and schema.returns[0].type.name is BaseTy.Tensor
)
test_value_definitions = generate_test_value_definitions(schema, 0)
test_value_names = generate_test_value_names(schema, 0)
test_value_definitions2 = generate_test_value_definitions(schema, 1)
test_value_names2 = generate_test_value_names(schema, 1)
check_resize = "true" if should_check_resize(schema) else "false"
generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
graph({arg_declarations}):
%bias: None = prim::Constant()
%ret = aten::{op_name}({arg_names})
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
{test_value_definitions}
std::vector<IValue> args{{{test_value_names}}};
testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
{test_value_definitions2}
std::vector<IValue> args2{{{test_value_names2}}};
testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
}}
"""
return generated
def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
schema = g.view.func
schema_str = str(schema)
assert schema_str.find("(") > 0
type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
op_name = g.view.root_name
assert type_variant_op_name.startswith(op_name)
arg_types = generate_test_ir_arguments(schema)
arg_declarations = ", ".join(
(
arg_name if arg_type is None else f"{arg_name}: {arg_type}"
for arg_name, arg_type in arg_types
)
)
arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
assert (
len(schema.returns) == 1
and isinstance(schema.returns[0].type, BaseType)
and schema.returns[0].type.name is BaseTy.Tensor
)
test_value_definitions = generate_test_value_definitions(schema, 0)
test_value_names = generate_test_value_names(schema, 0)
generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
graph({arg_declarations}):
%bias: None = prim::Constant()
%ret = aten::{op_name}({arg_names})
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
{test_value_definitions}
std::vector<IValue> args{{{test_value_names}}};
testStaticRuntime(script, args);
}}
"""
return generated