I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,171 @@
from __future__ import annotations
from dataclasses import dataclass
# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    """Describes one operator's role in a selective/custom PyTorch build.

    A selective build omits registration code for unused operators to
    shrink the generated binary; each instance records whether a single
    operator (or a whole overload family) must be kept.
    """

    # Operator name, including the "aten::" style prefix. May or may not
    # carry an overload suffix; when it does not, 'include_all_overloads'
    # tells whether this entry covers the whole overload family or just
    # the operator with exactly this name.
    name: str

    # True if any model built against this instance of the pytorch
    # library calls the operator directly (e.g. from a TorchScript
    # model) — i.e. it is a root operator in at least one of those
    # models (it need not be one in all of them).
    is_root_operator: bool

    # True if at least one model uses this operator for on-device
    # training; drives code generation in VariableType_N.cpp for
    # registration of training related operators.
    is_used_for_training: bool

    # When True, this entry refers to an operator name without an
    # overload suffix and applies to every overload sharing that base
    # name. Only meaningful for names with no DOT (period) character.
    #
    # Note: temporary workaround to grandfather in the current static
    # selective (custom) build mechanism, which largely ignores overload
    # names when selecting operators for registration.
    include_all_overloads: bool

    # Operator-level debug information.
    _debug_info: tuple[str, ...] | None

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: dict[str, object]
    ) -> SelectiveBuildOperator:
        """Construct an operator entry from one parsed YAML mapping.

        Raises:
            Exception: if ``op_info`` carries keys outside the allowed set.
        """
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }
        unexpected_keys = set(op_info.keys()) - allowed_keys
        if unexpected_keys:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(unexpected_keys),
                )
            )
        if "name" in op_info:
            assert op_name == op_info["name"]

        # Every flag defaults to True when absent from the YAML entry.
        is_root_operator = op_info.get("is_root_operator", True)
        assert isinstance(is_root_operator, bool)
        is_used_for_training = op_info.get("is_used_for_training", True)
        assert isinstance(is_used_for_training, bool)
        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

        debug_info: tuple[str, ...] | None = None
        if "debug_info" in op_info:
            raw_debug_info = op_info["debug_info"]
            assert isinstance(raw_debug_info, list)
            debug_info = tuple(str(entry) for entry in raw_debug_info)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> SelectiveBuildOperator:
        """Build a maximally-permissive entry for a legacy base-name-only op."""
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> dict[str, object]:
        """Serialize to a YAML-friendly dict; 'name' is the caller's key, and
        debug info is omitted when unset."""
        serialized: dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            serialized["debug_info"] = self._debug_info
        return serialized
def merge_debug_info(
    lhs: tuple[str, ...] | None,
    rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
    """Merge two optional debug-info tuples, dropping duplicate entries.

    Returns None only when both inputs are None; otherwise a tuple whose
    element order is unspecified (set-based de-duplication).
    """
    if lhs is None and rhs is None:
        return None
    merged = set(lhs or ())
    merged.update(rhs or ())
    return tuple(merged)
def combine_operators(
    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
    """Merge two entries for the same operator into one.

    Each boolean flag is OR-combined: the operator is a root operator /
    training operator / overload-family entry if it is so in ANY of the
    models this pytorch library instance was built for. Debug info from
    both sides is merged and de-duplicated.

    Raises:
        Exception: if the two entries name different operators.
    """
    if str(lhs.name) != str(rhs.name):
        raise Exception(  # noqa: TRY002
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
        )

    return SelectiveBuildOperator(
        name=lhs.name,
        is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
        is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
        include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
    )
def merge_operator_dicts(
    lhs: dict[str, SelectiveBuildOperator],
    rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
    """Union two operator-name -> operator maps.

    Entries present on both sides are merged via combine_operators so
    that no selection information from either side is lost.
    """
    merged: dict[str, SelectiveBuildOperator] = {}
    for source in (lhs, rhs):
        for op_name, op in source.items():
            existing = merged.get(op_name)
            merged[op_name] = op if existing is None else combine_operators(existing, op)
    return merged
def strip_operator_overload_name(op_name: str) -> str:
    """Drop the overload suffix: 'aten::add.out' -> 'aten::add'."""
    base_name, _, _overload = op_name.partition(".")
    return base_name

View File

@ -0,0 +1,352 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING
import yaml
from torchgen.selective_build.operator import (
merge_debug_info,
merge_operator_dicts,
SelectiveBuildOperator,
strip_operator_overload_name,
)
if TYPE_CHECKING:
from torchgen.model import NativeFunction
# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
# It includes information about the build's selectivity, the debug_info
# associated with this selective build (opaque string), and the set of
# operators that should be included in the build.
#
@dataclass(frozen=True)
class SelectiveBuilder:
    """Selectivity information for a selective/custom PyTorch build.

    Built from a selective-build YAML specification; records whether the
    build is selective at all, the selected operators and their metadata,
    selected kernel dtype tags, custom torch-bind classes, and build
    features, plus opaque debug info for the selection itself.
    """

    # If true, then the build is not selective, and includes all
    # operators.
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: tuple[str, ...] | None

    # A dictionary of operator -> operator metadata.
    operators: dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
    # that are specialized for many many Tensor dtypes, so it's not
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
    kernel_metadata: dict[str, list[str]]

    # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
    # dtypes for tensor-like input args).
    # This is from selective.yaml
    et_kernel_metadata: dict[str, list[str]]

    # A set of all the custom torch bind classes used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    custom_classes: set[str]

    # A set of all the build features used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    build_features: set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included as well as all custom classes. This is typically set when any one of the
    # operator lists is generated from a mechanism other than
    # tracing based selective build.
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> SelectiveBuilder:
        """Return a no-op selector: a non-selective build that includes all operators."""
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
        """Build a SelectiveBuilder from a parsed YAML dictionary.

        All keys are optional; selectivity flags default to False and
        collections default to empty.

        Raises:
            Exception: if ``data`` contains keys outside the supported set.
        """
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
            "debug_info",
            "operators",
            "kernel_metadata",
            "et_kernel_metadata",
            "custom_classes",
            "build_features",
        }
        top_level_keys = set(data.keys())
        if len(top_level_keys - valid_top_level_keys) > 0:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(top_level_keys - valid_top_level_keys),
                )
            )
        include_all_operators = data.get("include_all_operators", False)
        assert isinstance(include_all_operators, bool)

        debug_info = None
        if "debug_info" in data:
            di_list = data["debug_info"]
            assert isinstance(di_list, list)
            # Entries are stringified so downstream code can treat them as opaque.
            debug_info = tuple(str(x) for x in di_list)

        operators = {}
        operators_dict = data.get("operators", {})
        assert isinstance(operators_dict, dict)

        for k, v in operators_dict.items():
            operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)

        kernel_metadata = {}
        kernel_metadata_dict = data.get("kernel_metadata", {})
        assert isinstance(kernel_metadata_dict, dict)

        # Normalize tags and dtypes to strings regardless of how YAML parsed them.
        for k, v in kernel_metadata_dict.items():
            kernel_metadata[str(k)] = [str(dtype) for dtype in v]

        et_kernel_metadata = data.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)

        custom_classes = data.get("custom_classes", [])
        assert isinstance(custom_classes, Iterable)
        custom_classes = set(custom_classes)

        build_features = data.get("build_features", [])
        assert isinstance(build_features, Iterable)
        build_features = set(build_features)

        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
        assert isinstance(include_all_non_op_selectives, bool)

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            et_kernel_metadata,
            custom_classes,  # type: ignore[arg-type]
            build_features,  # type: ignore[arg-type]
            include_all_non_op_selectives,
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
        """Parse a YAML string and build a SelectiveBuilder from it."""
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_yaml_path(config_path: str) -> SelectiveBuilder:
        """Read and parse the YAML file at ``config_path`` into a SelectiveBuilder."""
        with open(config_path) as f:
            contents = yaml.safe_load(f)
            return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> SelectiveBuilder:
        """Build a selector from a legacy operator allow-list.

        Every listed operator gets the same flags, covers all overloads,
        and all non-operator selectives (kernel dtypes, custom classes)
        are included unconditionally.
        """
        operators = {}
        for op in allow_list:
            operators[op] = {
                "name": op,
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": True,
            }
        return SelectiveBuilder.from_yaml_dict(
            {
                "operators": operators,
                "include_all_non_op_selectives": True,
            }
        )

    def is_operator_selected(self, name: str) -> bool:
        """Return True if ``name`` (possibly with an overload suffix) is selected.

        Falls back to the base (overload-stripped) name when the base
        entry opted in to all overloads.
        """
        if self.include_all_operators:
            return True

        if name in self.operators:
            return True
        name = strip_operator_overload_name(name)
        return name in self.operators and self.operators[name].include_all_overloads

    def is_native_function_selected(self, func: NativeFunction) -> bool:
        """Return True if the given NativeFunction's operator name is selected."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected(op_name)

    def is_operator_selected_for_training(self, name: str) -> bool:
        """Return True if ``name`` is selected AND used for training.

        Checks both the exact-name entry and the base-name entry; the
        base entry only counts when it covers all overloads.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        # Placeholder used when no matching entry exists, so the checks
        # below can treat "missing" as "not a training op".
        not_training_op = SelectiveBuildOperator(
            name="",
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op = not_training_op
        if name in self.operators:
            op = self.operators[name]

        name = strip_operator_overload_name(name)
        base_op = not_training_op
        if name in self.operators:
            base_op = self.operators[name]

        return op.is_used_for_training or (
            base_op.include_all_overloads and base_op.is_used_for_training
        )

    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
        """Return True if the given NativeFunction is selected for training."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected_for_training(op_name)

    def is_root_operator(self, name: str) -> bool:
        """Return True if ``name`` is selected and is a root operator.

        An exact-name entry takes precedence; otherwise the base-name
        entry counts only when it covers all overloads.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        if name in self.operators:
            op: SelectiveBuildOperator = self.operators[name]
            return op.is_root_operator
        name = strip_operator_overload_name(name)
        if name not in self.operators:
            return False
        base_op: SelectiveBuildOperator = self.operators[name]
        return base_op.include_all_overloads and base_op.is_root_operator

    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
        """Return True if the (kernel tag, dtype) specialization should be kept."""
        if self.include_all_operators or self.include_all_non_op_selectives:
            return True

        return (
            kernel_tag in self.kernel_metadata
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
        """
        Return a list of kernel keys that cover the used ops

        ExecuTorch only. Matches each model kernel key against the
        candidate ``kernel_key`` list by the second '/'-separated field
        (the version field is deliberately ignored), falling back to
        "default" when present.

        Raises:
            Exception: if a model kernel key has no match and "default"
                is not among the candidates.
        """
        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
        if op_name not in self.et_kernel_metadata:
            return kernel_key if self.include_all_operators else []
        # Otherwise, only return the specific kernel keys.

        result_set = set()

        for model_kernel_keys in self.et_kernel_metadata[op_name]:
            key_found = False
            for key in kernel_key:
                # Don't compare the version for now
                if (
                    key != "default"
                    and key.split("/")[1] == model_kernel_keys.split("/")[1]
                ):
                    result_set.add(key)
                    key_found = True
                    break
            if not key_found:
                if "default" not in kernel_key:
                    raise Exception("Missing kernel for the model")  # noqa: TRY002
                else:
                    result_set.add("default")

        return list(result_set)

    def to_dict(self) -> dict[str, object]:
        """Serialize back to a YAML-friendly dict.

        Sets are written as sorted lists and debug info is sorted, so the
        output is stable across runs.
        """
        ret: dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
        operators = {}
        for op_name, op in self.operators.items():
            operators[op_name] = op.to_dict()
        ret["operators"] = operators

        if self._debug_info is not None:
            ret["debug_info"] = sorted(self._debug_info)

        ret["kernel_metadata"] = {
            k: sorted(v) for (k, v) in self.kernel_metadata.items()
        }

        ret["et_kernel_metadata"] = self.et_kernel_metadata

        ret["custom_classes"] = sorted(self.custom_classes)

        ret["build_features"] = sorted(self.build_features)

        return ret
def merge_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Union two kernel-tag -> dtype-list maps.

    Dtype lists for a tag present on both sides are merged with
    duplicates removed; element order of each list is unspecified.
    """
    merged: dict[str, list[str]] = {}
    for source in (lhs, rhs):
        for tag_name, dtypes in source.items():
            combined = set(dtypes) | set(merged.get(tag_name, []))
            merged[tag_name] = list(combined)
    return merged
def merge_et_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Union two ExecuTorch op -> kernel-key-list maps.

    Kernel keys for an op present on both sides are merged with
    duplicates removed; each resulting list is sorted for stable output.
    """
    # Fix: the accumulator previously shadowed the function's own name
    # (`merge_et_kernel_metadata`), which was confusing and would break
    # any recursive/self-referencing use; behavior is unchanged.
    merged: dict[str, set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merged[op].update(lhs.get(op, []))
        merged[op].update(rhs.get(op, []))

    return {op: sorted(val) for op, val in merged.items()}
def combine_selective_builders(
    lhs: SelectiveBuilder, rhs: SelectiveBuilder
) -> SelectiveBuilder:
    """Merge two selective builds into one selecting the union of both.

    Boolean selectivity flags are OR-combined; operator maps, kernel
    metadata, debug info, custom classes, and build features are unioned
    via their respective merge helpers.
    """
    return SelectiveBuilder(
        lhs.include_all_operators or rhs.include_all_operators,
        merge_debug_info(lhs._debug_info, rhs._debug_info),
        merge_operator_dicts(lhs.operators, rhs.operators),
        merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata),
        merge_et_kernel_metadata(lhs.et_kernel_metadata, rhs.et_kernel_metadata),
        lhs.custom_classes.union(rhs.custom_classes),
        lhs.build_features.union(rhs.build_features),
        lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives,
    )
def op_name_from_native_function(f: NativeFunction) -> str:
    """Return the namespaced operator name (with overload) for a NativeFunction.

    This was originally read from the 'operator_name_with_overload' field in
    the declaration dict, which was the part before the first '(' in
    'schema_string'.
    """
    namespace = f.namespace
    operator = f.func.name
    return f"{namespace}::{operator}"