I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,179 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Tuple
import torch.fx
import torch.utils._pytree as pytree
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
def compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
options: Optional[Dict[str, Any]] = None,
):
"""
Compile a given FX graph with TorchInductor. This allows compiling
FX graphs captured without using TorchDynamo.
Args:
gm: The FX graph to compile.
example_inputs: List of tensor inputs.
options: Optional dict of config options. See `torch._inductor.config`.
Returns:
Callable with same behavior as gm but faster.
"""
from .compile_fx import compile_fx
return compile_fx(gm, example_inputs, config_patches=options)
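# Illustrative usage sketch (editorial addition, not part of this module's API):
# compiling an FX graph captured with symbolic_trace, without going through
# TorchDynamo. The module and input shape here are arbitrary assumptions.
def _example_compile_usage():
    class _AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = torch.fx.symbolic_trace(_AddOne())
    example_inputs = [torch.randn(4)]
    compiled = compile(gm, example_inputs)  # `compile` defined above
    # The returned callable behaves like `gm`, so it is invoked the same way.
    return compiled(example_inputs[0])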
def aot_compile(
gm: torch.fx.GraphModule,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
options: Optional[Dict[str, Any]] = None,
) -> str:
"""
Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
Args:
gm: The FX graph to compile.
args: Example arguments
kwargs: Example keyword arguments
options: Optional dict of config options. See `torch._inductor.config`.
Returns:
Path to the generated shared library
"""
from .compile_fx import compile_fx_aot, graph_returns_tuple
assert graph_returns_tuple(gm), (
"Graph output must be a tuple(). This is so that we can avoid "
"pytree processing of the outputs. Please change the module to "
"have tuple outputs."
)
# We will serialize the pytree info into the .so as constant strings
in_spec = None
out_spec = None
if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen):
codegen = gm.graph._codegen
gm.graph._codegen = torch.fx.graph.CodeGen()
gm.recompile()
if codegen.pytree_info.in_spec is not None:
in_spec = codegen.pytree_info.in_spec
if codegen.pytree_info.out_spec is not None:
out_spec = codegen.pytree_info.out_spec
else:
if hasattr(gm, "_in_spec"):
in_spec = gm._in_spec
if hasattr(gm, "_out_spec"):
out_spec = gm._out_spec
serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else ""
serialized_out_spec = (
pytree.treespec_dumps(out_spec) if out_spec is not None else ""
)
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, kwargs or {})
)
# Replace non-tensor (constant) inputs with Nones, since they are not
# used by the graph anyway
flat_example_inputs = [
x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path
]
if in_spec is not None and received_spec != in_spec:
raise ValueError( # noqa: B904
"Trying to flatten user inputs with exported input tree spec: \n"
f"{in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}"
)
options = (
{
"aot_inductor.serialized_in_spec": serialized_in_spec,
"aot_inductor.serialized_out_spec": serialized_out_spec,
}
if options is None
else {
**options,
"aot_inductor.serialized_in_spec": serialized_in_spec,
"aot_inductor.serialized_out_spec": serialized_out_spec,
}
)
return compile_fx_aot(
gm,
flat_example_inputs, # type: ignore[arg-type]
config_patches=options,
)
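# Illustrative sketch (editorial addition): the pytree machinery used above.
# tree_flatten_with_path flattens the (args, kwargs) structure into leaves, and
# treespec_dumps produces the JSON string that is embedded into the .so as
# `aot_inductor.serialized_in_spec`. The concrete inputs are arbitrary.
def _example_pytree_spec() -> str:
    flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
        ((torch.randn(2),), {"scale": 2.0})
    )
    assert len(flat_args_with_path) == 2  # one tensor leaf and one float leaf
    return pytree.treespec_dumps(in_spec)  # JSON-serialized TreeSpec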
def list_mode_options(
mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
r"""Returns a dictionary describing the optimizations that each of the available
modes passed to `torch.compile()` performs.
Args:
mode (str, optional): The mode to return the optimizations for.
If None, returns optimizations for all modes
dynamic (bool, optional): Whether dynamic shape is enabled.
Example::
>>> torch._inductor.list_mode_options()
"""
mode_options: Dict[str, Dict[str, bool]] = {
"default": {},
# enable cudagraphs
"reduce-overhead": {
"triton.cudagraphs": True,
},
# enable max-autotune
"max-autotune-no-cudagraphs": {
"max_autotune": True,
},
# enable max-autotune
# enable cudagraphs
"max-autotune": {
"max_autotune": True,
"triton.cudagraphs": True,
},
}
return mode_options[mode] if mode else mode_options # type: ignore[return-value]
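# Illustrative sketch (editorial addition): what the table above yields for a
# single mode. The expected dict follows directly from mode_options.
def _example_mode_options() -> Dict[str, Any]:
    opts = list_mode_options("max-autotune")
    assert opts == {"max_autotune": True, "triton.cudagraphs": True}
    return opts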
def list_options() -> List[str]:
r"""Returns a dictionary describing the optimizations and debug configurations
that are available to `torch.compile()`.
The options are documented in `torch._inductor.config`.
Example::
>>> torch._inductor.list_options()
"""
from torch._inductor import config
current_config: Dict[str, Any] = config.shallow_copy_dict()
return list(current_config.keys())
def cudagraph_mark_step_begin():
"Indicates that a new iteration of inference or training is about to begin."
from .cudagraph_trees import mark_step_begin
mark_step_begin()
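# Illustrative usage sketch (editorial addition, assumptions about a typical
# caller): with mode="reduce-overhead", CUDA graphs may reuse output buffers
# across iterations, so a new step is marked explicitly before each call.
def _example_mark_step(model, batches):
    compiled = torch.compile(model, mode="reduce-overhead")
    for batch in batches:
        cudagraph_mark_step_begin()  # defined above
        compiled(batch)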


@@ -0,0 +1,298 @@
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest import mock
import torch
import torch._export
from torch._inductor.utils import is_cpu_device
from .runtime.runtime_utils import cache_dir
log = logging.getLogger(__name__)
def aoti_eager_cache_dir(namespace: str, device: str) -> Path:
return Path(cache_dir()) / "aoti_eager" / namespace / device
def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
from filelock import FileLock
# Avoid circular import
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
op_conf_lock_file = f"{op_func_name_with_overload}.lock"
lock_dir = get_lock_dir()
return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)
def load_aoti_eager_cache(
ns: str, op_func_name_with_overload: str, device_type: str
) -> List[Optional[Dict[str, Any]]]:
device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
if not op_conf.exists():
return []
try:
with aoti_eager_op_conf_lock(op_func_name_with_overload):
with open(op_conf) as f:
json_data = json.load(f)
for item in json_data:
# Get absolute path for the kernel library
kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
item["kernel_path"] = kernel_lib_abs_path.as_posix()
# Check if the kernel library exists
if not kernel_lib_abs_path.exists():
return []
for metadata in item["meta_info"]:
if metadata.get("is_dynamic"):
raise NotImplementedError(
"Only support static shape for now"
)
if (
"device_type" in metadata
and metadata["device_type"] == "cpu"
):
metadata["device_index"] = -1
for dtype_key in ["dtype", "dtype_value"]:
if dtype_key in metadata:
metadata[dtype_key] = getattr(
torch, metadata[dtype_key].split(".")[-1]
)
if "layout_value" in metadata:
metadata["layout_value"] = getattr(
torch, metadata["layout_value"].split(".")[-1]
)
if "memory_format_value" in metadata:
metadata["memory_format_value"] = getattr(
torch, metadata["memory_format_value"].split(".")[-1]
)
return json_data
except Exception as e:
err_msg = f"Failed to load aoti eager cache: {e}"
log.exception(err_msg)
return []
def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
return {int: torch.int32, float: torch.float, bool: torch.bool}
def supported_scalar_types() -> Tuple[type, ...]:
type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
return tuple(type_to_torch_dtype.keys())
def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
metadata: Dict[str, Any] = {}
metadata["is_dynamic"] = dynamic
assert isinstance(input, torch.Tensor)
metadata["device_type"] = f"{input.device.type}"
if is_cpu_device([input]):
metadata["device_index"] = -1
else:
metadata["device_index"] = input.device.index
metadata["dtype"] = f"{input.dtype}"
metadata["sizes"] = list(input.size())
metadata["strides"] = list(input.stride())
metadata["requires_grad"] = input.requires_grad
metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr()
return metadata
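# Illustrative sketch (editorial addition): the rough shape of the metadata dict
# produced by extract_tensor_metadata above for a small contiguous CPU tensor.
# The dispatch_key_set value is build-dependent, so it is not spelled out here.
def _example_tensor_metadata() -> Dict[str, Any]:
    meta = extract_tensor_metadata(False, torch.randn(2, 3))
    # Expected keys/values (float32 CPU tensor):
    #   is_dynamic=False, device_type="cpu", device_index=-1,
    #   dtype="torch.float32", sizes=[2, 3], strides=[3, 1],
    #   requires_grad=False, dispatch_key_set=<build-dependent int>
    return meta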
def extract_tensor_list_metadata(
dynamic: bool,
input: List[torch.Tensor],
) -> Dict[str, Any]:
metadata_list = []
for item in input:
assert isinstance(item, torch.Tensor)
metadata_list.append(extract_tensor_metadata(dynamic, item))
metadata: Dict[str, Any] = {}
metadata["tensor_list"] = metadata_list
return metadata
def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
assert isinstance(input, supported_scalar_types())
metadata: Dict[str, Any] = {}
metadata["is_dynamic"] = False
# Scalar tensor
metadata["device_type"] = device_type
metadata["device_index"] = -1 if device_type == "cpu" else 0
type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
metadata["scalar_value"] = input
return metadata
def extract_string_metadata(input: str) -> Dict[str, Any]:
assert isinstance(input, str)
metadata: Dict[str, Any] = {}
metadata["string_value"] = input
return metadata
def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
assert isinstance(input, torch.dtype)
metadata: Dict[str, Any] = {}
metadata["dtype_value"] = f"{input}"
return metadata
def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
assert isinstance(input, torch.device)
metadata: Dict[str, Any] = {}
metadata["device_type_value"] = f"{input.type}"
metadata["device_index_value"] = input.index
return metadata
def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
assert isinstance(input, torch.layout)
metadata: Dict[str, Any] = {}
metadata["layout_value"] = f"{input}"
return metadata
def aoti_compile_with_persistent_cache(
ns: str,
op_func_name_with_overload: str,
device_type: str,
dynamic: bool,
f: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
*,
dynamic_shapes: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
) -> str:
"""
Compile the given function with persistent cache for AOTI eager mode.
"""
assert not dynamic, "Only static shapes are supported for now"
flattened_inputs = list(args) + list(kwargs.values())
if not all(
isinstance(
input,
(
supported_scalar_types(),
torch.Tensor,
list,
str,
torch.dtype,
torch.device,
torch.layout,
),
)
for input in flattened_inputs
):
err_msg = f"Unsupported input types: {flattened_inputs}"
log.exception(err_msg)
raise NotImplementedError(err_msg)
for input in flattened_inputs:
if isinstance(input, list) and not all(
isinstance(item, torch.Tensor) for item in input
):
err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}"
log.exception(err_msg)
raise NotImplementedError(err_msg)
persistent_cache = aoti_eager_cache_dir(ns, device_type)
if not persistent_cache.exists():
persistent_cache.mkdir(parents=True)
persistent_cache_lib = persistent_cache / "lib"
if not persistent_cache_lib.exists():
persistent_cache_lib.mkdir()
with mock.patch.dict(
os.environ,
{"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
):
try:
kernel_lib_path = torch._export.aot_compile(
f,
args,
kwargs,
dynamic_shapes=dynamic_shapes,
remove_runtime_assertions=remove_runtime_assertions,
disable_constraint_solver=disable_constraint_solver,
# Some operations may have non-Tensor parameters like int, float, bool. These
# non-Tensor parameters will not be inputs of the graph. Therefore, we do not
# need to keep the same signature.
same_signature=False,
)
kernel_metadata_items = []
for idx, input in enumerate(flattened_inputs):
if isinstance(input, torch.Tensor):
metadata = extract_tensor_metadata(dynamic, input)
elif isinstance(input, list):
assert all(isinstance(item, torch.Tensor) for item in input)
metadata = extract_tensor_list_metadata(dynamic, input)
elif isinstance(input, supported_scalar_types()):
metadata = extract_scalar_metadata(device_type, input)
elif isinstance(input, str):
metadata = extract_string_metadata(input)
elif isinstance(input, torch.dtype):
metadata = extract_dtype_metadata(input)
elif isinstance(input, torch.device):
metadata = extract_device_metadata(input)
elif isinstance(input, torch.layout):
metadata = extract_layout_metadata(input)
else:
raise NotImplementedError(f"Unsupported input type: {type(input)}")
metadata["arg_order"] = idx
kernel_metadata_items.append(metadata)
kernel_meta_info: Dict[str, Any] = {}
kernel_meta_info["meta_info"] = kernel_metadata_items
kernel_meta_info["kernel_path"] = (
Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
)
json_data = []
update_json = True
op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
mode = "r" if op_conf.exists() else "w"
with aoti_eager_op_conf_lock(op_func_name_with_overload):
with open(op_conf, mode) as op_conf_file:
try:
json_data = json.load(op_conf_file)
except Exception as e:
json_data = []
assert isinstance(json_data, list)
for item in json_data:
assert isinstance(item, dict)
# Same kernel meta info already exists in the json file
if item["meta_info"] == kernel_metadata_items:
update_json = False
break
if update_json:
json_data.append(kernel_meta_info)
with open(op_conf, "w") as op_conf_file:
json.dump(json_data, op_conf_file, indent=4)
return kernel_lib_path
except Exception as e:
err_msg = f"Failed to compile {op_func_name_with_overload}: {e}"
log.exception(err_msg)
return ""


@@ -0,0 +1,297 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import logging
import multiprocessing
import os
import sys
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._inductor import config
from torch._inductor.codecache import (
CodeCacheFuture,
CppCodeCache,
CppPythonBindingsCodeCache,
CUDACodeCache,
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
TritonCodeCache,
TritonFuture,
)
from torch._inductor.compile_worker.subproc_pool import (
_warm_process_pool,
AnyPool,
SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
_set_triton_ptxas_path,
_worker_compile_triton,
)
from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package
if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta
# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
def pre_fork_setup():
"""
Setup that must be done prior to forking with a process pool.
"""
# ensure properties have been calculated before processes
# are forked
caching_device_properties()
# Computing the triton key can be slow. If we call it before fork,
# it will be cached for the forked subprocesses.
try:
from triton.compiler.compiler import triton_key
triton_key()
except ImportError:
# Triton might not be installed or might be an old version.
pass
def caching_device_properties():
for _, device_interface in get_registered_device_interfaces():
if device_interface.is_available():
device_interface.Worker.get_device_properties()
def _compile_start() -> None:
global _t0
if _t0 is None:
_t0 = time()
def _compile_end() -> None:
global _cumulative_compile_time, _t0
if _t0 is not None:
t1 = time()
_cumulative_compile_time += t1 - _t0
_t0 = None
# print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
_IS_WINDOWS = sys.platform == "win32"
log = logging.getLogger(__name__)
# Used to keep track of all process pools invoked so far.
_pool_set: Set[AnyPool] = set()
def shutdown_compile_workers() -> None:
"""Shut down all outstanding compile-worker pools."""
for pool in _pool_set:
pool.shutdown()
after_fork()
def after_fork():
"""Reset pools to initial state without shutting them down"""
_pool_set.clear()
AsyncCompile.process_pool.cache_clear()
try:
os.register_at_fork(after_in_child=after_fork)
except AttributeError:
pass  # register_at_fork does not exist on Windows
class AsyncCompile:
def __init__(self) -> None:
pass
@staticmethod
@functools.lru_cache(1)
def pool() -> ThreadPoolExecutor:
assert config.compile_threads > 1
return ThreadPoolExecutor(config.compile_threads)
@staticmethod
def _get_ready():
"""No-op function to help mark when the subprocess pool is ready."""
return "ready"
@staticmethod
@functools.lru_cache(1)
def process_pool() -> AnyPool:
assert config.compile_threads > 1
pool: AnyPool
if config.worker_start_method == "subprocess":
# Wrapper around ProcessPoolExecutor forks in a new process we control
pool = SubprocPool(config.compile_threads)
else:
pre_fork_setup()
ctx = multiprocessing.get_context(config.worker_start_method)
pool = ProcessPoolExecutor(
config.compile_threads,
mp_context=ctx,
initializer=partial(_async_compile_initializer, os.getpid()),
)
# when this pool is created in a subprocess object, the normal exit handler
# doesn't run, and we need to register our own handler.
# exitpriority has to be high, because another one of the finalizers will
# kill the worker thread that sends the shutdown message to the workers...
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
# Set an attribute we can check to see if the pool is ready.
pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr]
_pool_set.add(pool)
return pool
@classmethod
def warm_pool(cls) -> None:
if config.compile_threads <= 1:
return
_compile_start()
_warm_process_pool(cls.process_pool(), config.compile_threads)
_compile_end()
@classmethod
def submit(cls, task: Callable[..., Any]) -> Any:
if config.compile_threads <= 1:
return task()
return cls.pool().submit(task)
def _use_process_pool(self):
return (
config.compile_threads > 1
and self.process_pool().ready_future.done() # type: ignore[union-attr]
)
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
kernel_code_log.info("Triton Kernel:\n%s", source_code)
_compile_start()
_set_triton_ptxas_path()
kernel = TritonCodeCache.load(kernel_name, source_code)
if self._use_process_pool():
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
return TritonFuture(
kernel,
self.process_pool().submit(
_worker_compile_triton,
kernel._reload_in_subproc,
extra_env,
),
)
else:
kernel.precompile()
return kernel
def multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# no need to call this in parallel since the sub-kernels are already parallel tasks
return MultiKernelCall(*args, **kwargs)
def cpp(self, source_code: str):
kernel_code_log.info("CPP Kernel:\n%s", source_code)
if config.compile_threads <= 1:
return CppCodeCache.load(source_code).kernel
else:
get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
return LambdaFuture(lambda: get_result().kernel)
def cpp_pybinding(self, argtypes: List[str], source_code: str):
kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
if config.compile_threads <= 1:
return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
else:
get_result = CppPythonBindingsCodeCache.load_pybinding_async(
argtypes, source_code, submit_fn=self.submit
)
return LambdaFuture(get_result)
def cuda(self, source_code, dst_file_ext):
kernel_code_log.info("CUDA Kernel:\n%s", source_code)
def task():
return CUDACodeCache.load(source_code, dst_file_ext)[0]
return self.submit(task)
def rocm(self, source_code, dst_file_ext):
kernel_code_log.info("ROCm Kernel:\n%s", source_code)
def task():
return ROCmCodeCache.load(source_code, dst_file_ext)[0]
return self.submit(task)
def halide(self, meta: HalideMeta, source_code: str):
kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
if config.compile_threads <= 1:
return HalideCodeCache.generate_halide(meta, source_code)
else:
get_result = HalideCodeCache.generate_halide_async(
meta, source_code, submit_fn=self.submit
)
return LambdaFuture(get_result)
def wait(self, scope: Dict[str, Any]) -> None:
num_kernels = len(
[
value
for key, value in scope.items()
if isinstance(value, (Future, CodeCacheFuture))
]
)
pbar = tqdm(
total=num_kernels,
desc="Inductor Compilation",
disable=config.disable_progress,
delay=0,
)
if config.compile_threads > 1:
for key, result in scope.items():
if config.verbose_progress and not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(key)
if isinstance(result, (Future, CodeCacheFuture)):
try:
scope[key] = result.result()
except BrokenProcessPool as e:
raise RuntimeError(
"A compilation subprocess exited unexpectedly. This "
"is likely due to a crash. To facilitate debugging, "
"you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
"to cause compilation to occur in the main process."
) from e
pbar.update(1)
_compile_end()
if (
os.environ.get("TORCH_TNT_IN_USE", "0") == "1"
or os.environ.get("TORCH_WARM_POOL", "1") != "1"
# The subprocess pool is only used for the Triton backend
or not has_triton_package()
):
pass
else:
AsyncCompile.warm_pool()
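# Illustrative sketch (editorial addition, mirroring the pattern of the wrapper
# code Inductor generates): kernel builds are kicked off through AsyncCompile,
# collected into a scope dict, and wait() resolves every pending future in place
# before the kernels are first used. `source_codes` is a hypothetical mapping of
# kernel names to C++ kernel source strings.
def _example_async_compile(source_codes: Dict[str, str]) -> Dict[str, Any]:
    async_compile = AsyncCompile()
    scope: Dict[str, Any] = {
        name: async_compile.cpp(code) for name, code in source_codes.items()
    }
    async_compile.wait(scope)  # replaces futures with compiled kernels
    return scope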


@@ -0,0 +1,296 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class MMRankingA100(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 166912
and str(metadata.device_capa) == "(8, 0)"
)
def get_confidence_threshold(self) -> float:
return 0.0
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
self.choices.append('extern_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
def get_name(self) -> str:
return 'mm'
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if context.get_value('arith_intensity') <= 52.6245059967041:
if context.get_value('n') <= 34.0:
if context.get_value('n') <= 18.0:
if context.get_value('k*n') <= 312.0:
return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
else:
if context.get_value('k') <= 40.0:
return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
else:
return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
else:
if context.get_value('mat1_stride_0') <= 20.0:
return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
else:
if context.get_value('k') <= 68.0:
return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
else:
return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
else:
if context.get_value('k') <= 35.0:
if context.get_value('k') <= 18.0:
if context.get_value('m*n') <= 19505152.0:
return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
else:
return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
else:
if context.get_value('n') <= 68.0:
return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
else:
return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
else:
if context.get_value('m*n') <= 309760.0:
return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
else:
if context.get_value('n') <= 72.0:
return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
else:
return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
else:
if str(context.get_value('using_tf32')) != 'False':
if context.get_value('m*n') <= 815360.0:
if context.get_value('k') <= 1184.0:
return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
else:
return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
else:
if context.get_value('arith_intensity') <= 187.23922729492188:
if context.get_value('mat1_stride_0') <= 198.0:
return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
else:
return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
else:
return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
else:
return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
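# Illustrative sketch (editorial addition, not part of the generated file): how
# a caller could map the ranked (score, index) pairs returned by
# get_best_choices back to the readable choice strings registered in
# fill_choices. `heuristic` and `context` are assumed to come from the
# AutoHeuristic framework.
def _example_decode_ranking(heuristic: MMRankingA100, context: AHContext):
    ranked = heuristic.get_best_choices(context) or []
    return [(score, heuristic.get_choice(idx)) for score, idx in ranked]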


@@ -0,0 +1,321 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class MMRankingH100(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 232448
and str(metadata.device_capa) == "(9, 0)"
)
def get_confidence_threshold(self) -> float:
return 0.0
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
self.choices.append('extern_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
def get_name(self) -> str:
return 'mm'
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if context.get_value('arith_intensity') <= 29.89772129058838:
if context.get_value('n') <= 34.0:
if context.get_value('n') <= 18.0:
if context.get_value('k*n') <= 432.0:
if context.get_value('arith_intensity') <= 7.8700292110443115:
return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
else:
return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
else:
if context.get_value('k') <= 40.0:
return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
else:
return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
else:
if context.get_value('mat1_stride_0') <= 40.0:
if context.get_value('mat1_stride_0') <= 20.0:
return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
else:
return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
else:
if context.get_value('mat1_stride_0') <= 68.0:
return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
else:
return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
else:
if context.get_value('k') <= 18.0:
if context.get_value('m*k') <= 528.0:
return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
else:
if context.get_value('n') <= 80.0:
return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
else:
return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
else:
if context.get_value('k') <= 36.0:
if context.get_value('n') <= 68.0:
return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
else:
return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
else:
if context.get_value('mat2_stride_0') <= 384.0:
return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
else:
return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
else:
if context.get_value('arith_intensity') <= 56.995582580566406:
if context.get_value('n') <= 68.0:
if context.get_value('k*n') <= 4448.0:
if context.get_value('m*n') <= 29626368.0:
return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
else:
return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
else:
if context.get_value('k') <= 348.0:
return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
else:
return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
else:
if context.get_value('m') <= 3264.0:
return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
else:
if context.get_value('k') <= 62.5:
return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
else:
return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
else:
if context.get_value('m*n') <= 1097728.0:
return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
else:
if context.get_value('m*n') <= 3244032.0:
return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
else:
if context.get_value('n') <= 136.0:
return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
else:
return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
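# Note on the format returned by get_best_choices (added for clarity): each tuple appears to be
# (predicted probability, index into self.choices), ordered from most to least likely. For
# example, (0.460, 0) above means the choice appended at index 0 in fill_choices() is predicted
# best with probability 0.460. LearnedHeuristicDecision.get_decision(), defined later in this
# commit, maps the winning index back to its choice string via get_choice().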

View File

@@ -0,0 +1,150 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class MixedMMA100(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 166912
and str(metadata.device_capa) == "(8, 0)"
)
def get_confidence_threshold(self) -> float:
return 0.0
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
self.choices.append('extern_fallback_mixed_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
def get_name(self) -> str:
return 'mixed_mm'
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if str(context.get_value('1LEQmLEQ16')) != 'True':
if context.get_value('m') <= 32.5:
if context.get_value('n') <= 6976.0:
if context.get_value('n') <= 3520.0:
if context.get_value('m*n') <= 37632.0:
return None
else:
return [(1.000, 13)]
else:
if context.get_value('m*k') <= 452352.0:
return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
else:
return [(0.778, 8), (0.222, 13)]
else:
if context.get_value('k*n') <= 102776832.0:
if context.get_value('n') <= 14656.0:
return [(1.000, 11)]
else:
return [(0.889, 11), (0.111, 13)]
else:
return [(1.000, 11)]
else:
if context.get_value('m*n') <= 446464.0:
if context.get_value('m*n') <= 223424.0:
if context.get_value('mat1_stride_0') <= 3968.0:
return None
else:
return None
else:
if context.get_value('m*n') <= 346112.0:
return [(0.960, 16), (0.040, 7)]
else:
return [(0.750, 16), (0.136, 14), (0.114, 7)]
else:
if str(context.get_value('33LEQmLEQ64')) != 'True':
if context.get_value('n') <= 6976.0:
return [(1.000, 14)]
else:
return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
else:
if context.get_value('n') <= 13888.0:
return [(0.710, 14), (0.275, 21), (0.014, 12)]
else:
return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
else:
if context.get_value('n') <= 3520.0:
if context.get_value('arith_intensity') <= 3.994754433631897:
if str(context.get_value('mat2_dtype')) != 'torch.uint8':
if context.get_value('m*k') <= 18944.0:
return [(0.577, 5), (0.423, 6)]
else:
return [(0.988, 5), (0.012, 6)]
else:
if context.get_value('arith_intensity') <= 2.9899919033050537:
return None
else:
return None
else:
if context.get_value('arith_intensity') <= 7.956453561782837:
if context.get_value('k*n') <= 9244032.0:
return [(0.822, 5), (0.178, 6)]
else:
return [(0.977, 5), (0.023, 0)]
else:
if context.get_value('m*k') <= 978944.0:
return [(1.000, 5)]
else:
return [(0.971, 5), (0.029, 0)]
else:
if context.get_value('n') <= 13632.0:
if context.get_value('n') <= 6976.0:
return [(1.000, 6)]
else:
if context.get_value('k') <= 3968.0:
return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
else:
return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
else:
if context.get_value('k*n') <= 39518208.0:
return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
else:
if context.get_value('n') <= 20800.0:
return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
else:
return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]

View File

@@ -0,0 +1,149 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)
class MixedMMH100(LearnedHeuristicDecision):
def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 232448
and str(metadata.device_capa) == "(9, 0)"
)
def get_confidence_threshold(self) -> float:
return 0.0
def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None
def fill_choices(self) -> None:
self.choices.append('extern_fallback_mixed_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
def get_name(self) -> str:
return 'mixed_mm'
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if context.get_value('arith_intensity') <= 15.988086223602295:
if context.get_value('n') <= 25280.0:
if context.get_value('n') <= 1344.0:
if context.get_value('mat1_stride_0') <= 7808.0:
return [(0.581, 7), (0.419, 6)]
else:
if context.get_value('m*n') <= 7680.0:
return [(0.875, 0), (0.125, 6)]
else:
return [(0.833, 0), (0.167, 7)]
else:
if context.get_value('n') <= 8512.0:
if str(context.get_value('mat2_dtype')) != 'torch.int8':
return [(0.763, 6), (0.237, 7)]
else:
return [(0.725, 7), (0.275, 6)]
else:
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
else:
return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
else:
if context.get_value('n') <= 42254.0:
if context.get_value('n') <= 33856.0:
if context.get_value('k*n') <= 68157440.0:
return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
else:
return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
else:
return [(0.659, 5), (0.341, 6)]
else:
if context.get_value('k*n') <= 326052992.0:
if context.get_value('n') <= 55232.0:
return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
else:
return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
else:
if context.get_value('n') <= 57024.0:
return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
else:
return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
else:
if context.get_value('m*n') <= 543936.0:
if str(context.get_value('17LEQmLEQ32')) != 'True':
if context.get_value('m*n') <= 262272.0:
if context.get_value('n') <= 1592.5:
return [(0.860, 0), (0.140, 9)]
else:
return None
else:
if context.get_value('m*k') <= 1294336.0:
return [(0.833, 17), (0.150, 18), (0.017, 15)]
else:
return [(0.917, 17), (0.083, 8)]
else:
if context.get_value('n') <= 12416.0:
if context.get_value('m*n') <= 43008.0:
return None
else:
return [(0.853, 14), (0.147, 9)]
else:
return [(0.625, 12), (0.375, 14)]
else:
if context.get_value('m') <= 32.5:
if context.get_value('mat2_stride_1') <= 6656.0:
if context.get_value('n') <= 69184.0:
return [(0.611, 12), (0.361, 14), (0.028, 13)]
else:
return [(1.000, 12)]
else:
if context.get_value('mat2_stride_1') <= 20864.0:
return [(1.000, 12)]
else:
return [(0.958, 12), (0.042, 9)]
else:
if context.get_value('m*n') <= 1085440.0:
if context.get_value('n') <= 9152.0:
return [(1.000, 18)]
else:
return [(0.780, 18), (0.160, 16), (0.060, 20)]
else:
if context.get_value('m') <= 67.0:
return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
else:
return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]

View File

@@ -0,0 +1,109 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicRegression,
)
class PadMMA100(LearnedHeuristicRegression):
def __init__(self) -> None:
pass
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 166912
and str(metadata.device_capa) == "(8, 0)"
)
def get_feedback(self, context: AHContext, choice: Choice) -> float:
context.context_dict[CHOICE_COL] = choice
return self.predict(context)
def get_confidence_threshold(self) -> float:
return 1.7025303314066
def get_name(self) -> str:
return 'pad_mm'
def predict(self, context: AHContext) -> float:
if str(context.get_value('choice')) != 'pad':
if str(context.get_value('using_tf32')) != 'False':
if context.get_value('m*n') <= 4171264.0:
if context.get_value('m*k') <= 3999308.0:
return 1.8751469764071178
else:
if str(context.get_value('n_multiple_32')) != 'True':
return 0.9117231355626345
else:
return 1.1607689608873861
else:
if str(context.get_value('n_multiple_2')) != 'True':
if str(context.get_value('using_tf32')) != 'True':
return 0.7430382200435992
else:
return 0.8531269794448678
else:
if str(context.get_value('k_multiple_2')) != 'True':
return 0.7577181972719917
else:
return 0.8977349440424219
else:
if context.get_value('m*n') <= 1299712.0:
return 1.1669723418995592
else:
if context.get_value('mat2_stride_1') <= 45217.5:
if context.get_value('m*n') <= 55884158.0:
return 1.0262769936909601
else:
return 1.0022677428470845
else:
if context.get_value('m') <= 18478.0:
return 1.1127066261894312
else:
return 1.0337740659894263
else:
if str(context.get_value('mat1_dtype')) != 'torch.float32':
if str(context.get_value('n_multiple_2')) != 'False':
if str(context.get_value('k_multiple_2')) != 'True':
if context.get_value('mat1_stride_0') <= 561.0:
return 1.2900382135142956
else:
return 1.5761737616057887
else:
if context.get_value('num_dims_needs_padding') <= 1.5:
return 1.0472263310239422
else:
return 1.1727673465762514
else:
if context.get_value('k') <= 28238.5:
if context.get_value('k/(m*n)') <= 0.00026227018679492176:
return 1.6770542505397175
else:
return 1.3974785435105923
else:
if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
return 1.3952699800111992
else:
return 1.5759286511628336
else:
if str(context.get_value('using_tf32')) != 'False':
if context.get_value('m*n') <= 14119424.0:
return 0.8875772670422478
else:
if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
return 1.1467728924377265
else:
return 1.215842963532998
else:
if context.get_value('arith_intensity') <= 396.8774871826172:
return 0.89940161869551
else:
if context.get_value('mat2_stride_1') <= 45217.5:
return 0.9964328169353532
else:
return 0.9493479238294826

View File

@@ -0,0 +1,315 @@
import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional
import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
AHOperation,
Choice,
CHOICE_COL,
Feedback,
FEEDBACK_COL,
get_metadata_str_from_log,
)
from torch._inductor.autoheuristic.learned_heuristic_controller import (
LearnedHeuristicController,
)
from torch._inductor.ir import ChoiceCaller
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import get_gpu_shared_memory
class LocalFeedback:
"""
To collect data for a choice, the caller has to provide a function that returns feedback for a given choice.
LocalFeedback can be used when AutoHeuristic should immediately run that function to collect feedback for each choice
(see pad_mm.py, where the autotuning happens locally, for an example).
"""
def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
self.feedback_fn = feedback_fn
def __call__(self, choice: Choice) -> Feedback:
return self.feedback_fn(choice)
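# A minimal sketch of how LocalFeedback might be used; the timing function below is a
# hypothetical placeholder, not part of this file:
#
#   def time_choice(choice: Choice) -> Feedback:
#       ...  # run and time the kernel corresponding to `choice`
#
#   feedback = LocalFeedback(time_choice)
#   runtime = feedback("some_choice")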
class InconsistentMetadata(Exception):
"""
Exception that is raised when AutoHeuristic tries to log data to a file whose stored metadata does not match
the metadata it would write if the file didn't exist.
"""
class AutoHeuristic:
"""
AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree), and
generate code for the learned heuristic. This class handles the data collection. The collected data can then be
used to train a heuristic (see torchgen/autoheuristic/).
"""
collected_feedback: Dict[Choice, Feedback]
def __init__(
self,
fallback: Callable[[], Choice],
choices: List[Choice],
feedback: Optional[LocalFeedback],
context: AHContext,
name: str,
augment_context: Optional[List[AHOperation]] = None,
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
) -> None:
"""
Initializes an instance of the AutoHeuristic class.
Args:
fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or when
AutoHeuristic is in data collection mode.
choices: A list of possible choices the heuristic can make.
feedback: An instance of LocalFeedback that provides feedback for a given choice.
context: Context to store with each choice and feedback.
name: A string that identifies the heuristic.
augment_context: An optional list of AHOperation instances that augment the context.
precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
"""
self.fallback = fallback
self.choices = choices
self.feedback = feedback
self.context = context
self.name = name
self.collected_feedback = {}
self.augment_context = augment_context
self.metadata = AHMetadata(
get_gpu_shared_memory(),
torch.cuda.get_device_capability(),
self.choices,
self.name,
)
self.precondition = precondition
if not self.satisfies_precondition():
return
if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
self.log_path = self.get_default_log_path()
else:
self.log_path = torch._inductor.config.autoheuristic_log_path
if torch._inductor.config.collect_autoheuristic(self.name):
if self.feedback is not None:
for choice in self.choices:
feedback_val = self.feedback(choice)
self.save_data(choice, feedback_val)
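# A minimal, hypothetical sketch of how AutoHeuristic might be instantiated for data
# collection (the feature values and the `time_choice` function are placeholders, not
# part of this file):
#
#   context = AHContext()
#   context.add_feature("m", 1024)
#   context.add_feature("k", 512)
#   context.add_feature("n", 2048)
#   ah = AutoHeuristic(
#       fallback=lambda: "choice_a",
#       choices=["choice_a", "choice_b"],
#       feedback=LocalFeedback(time_choice),
#       context=context,
#       name="my_optimization",
#   )
#   choice = ah.get_choice()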
def satisfies_precondition(self) -> bool:
return self.precondition is None or self.precondition(
self.metadata, self.context
)
def get_choice(self) -> Choice:
"""
Returns the chosen option based on the value of autoheuristic_use.
If self.name is one of the comma-separated strings in autoheuristic_use,
it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
"""
if not self.satisfies_precondition():
return self.fallback()
if torch._inductor.config.use_autoheuristic(self.name):
if self.augment_context is not None:
self.context.apply_operations(self.augment_context)
controller = LearnedHeuristicController(
self.metadata,
self.context,
)
decision = controller.get_decision()
if decision not in self.choices:
# TODO(AlnisM): We might want to allow this in the future
return self.fallback()
if decision is not None:
return decision
return self.fallback()
def get_top_k_choices(
self, top_k: int, always_included: Optional[List[str]] = None
) -> Optional[List[Choice]]:
if not self.satisfies_precondition():
return None
if torch._inductor.config.use_autoheuristic(self.name):
if self.augment_context is not None:
self.context.apply_operations(self.augment_context)
controller = LearnedHeuristicController(
self.metadata,
self.context,
)
choices = controller.get_decisions_ranked(top_k)
if choices is None:
return None
if always_included is not None:
for choice in always_included:
if choice not in choices:
choices.append(choice)
return choices
return None
def get_collected_feedback(self, choice: Choice) -> Any:
return self.collected_feedback.get(choice, None)
@staticmethod
def get_device_identifier() -> str:
# a heuristic might work well for one GPU, but not for another
# we store the collected data per GPU model and learn a heuristic per GPU model
# TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
device_name = torch.cuda.get_device_name().replace(" ", "_")
return device_name
def get_default_log_path(self) -> str:
device_name = self.get_device_identifier()
path = f"{cache_dir()}/autoheuristic/{device_name}/"
os.makedirs(path, exist_ok=True)
path += f"{self.name}.txt"
return path
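# For illustration: this produces a path roughly like
# <cache_dir()>/autoheuristic/NVIDIA_A100-SXM4-40GB/pad_mm.txt, where the device
# directory comes from get_device_identifier(); the device and heuristic names shown
# here are just examples.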
def serialize_metadata(self) -> str:
metadata_dict = self.metadata.to_dict()
(
num_features,
cat_features,
) = self.context.get_numerical_and_categorical_features()
metadata_dict["numerical_features"] = num_features
metadata_dict["categorical_features"] = cat_features
return json.dumps(metadata_dict)
def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
self.collected_feedback[choice] = feedback_val
log_path = self.log_path
lines = []
log_exists = os.path.exists(log_path)
if log_exists:
# if log already exists, make sure it is consistent
metadata = self.serialize_metadata()
existing_metadata = get_metadata_str_from_log(self.log_path)
if existing_metadata != metadata:
raise InconsistentMetadata(
"Given metadata does not match existing metadata"
)
else:
lines.append(self.serialize_metadata())
feature_header = self.context.get_feature_names_csv()
header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
lines.append(header)
line = ""
feature_values = self.context.get_feature_values_csv()
line += feature_values + "," + choice + "," + str(feedback_val)
lines.append(line)
with open(log_path, "a") as f:
f.write("\n".join(lines) + "\n")
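# For illustration, the resulting log file holds one JSON metadata line, a CSV header,
# and one CSV row per (choice, feedback) pair; the feature names and values below are
# hypothetical:
#
#   {"shared_memory": 166912, "device_capa": [8, 0], "name": "pad_mm", ...}
#   m,k,n,choice,feedback
#   1024,512,2048,pad,1.23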
class AutoHeuristicSelectAlgorithm(AutoHeuristic):
"""
AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
when one wants to use AutoHeuristic for kernel choice selection.
"""
def __init__(
self,
fallback: Callable[[], Optional[ChoiceCaller]],
choices: List[ChoiceCaller],
input_nodes: List[Any],
context: AHContext,
name: str,
augment_context: Optional[List[AHOperation]] = None,
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
) -> None:
"""
The arguments choices, input_nodes, and name have to match the ones used in the call to
autotune_select_algorithm(), e.g. if autotune_select_algorithm(name, choices, input_nodes, layout)
is called, the same name, choices, and input_nodes have to be used here.
"""
self.input_nodes = input_nodes
self.choicestr2choice: Dict[str, ChoiceCaller] = {}
for choice in choices:
self.choicestr2choice[choice.autoheuristic_id()] = choice
choices_str = list(self.choicestr2choice.keys())
def fallback_str() -> str:
fallback_choice = fallback()
if fallback_choice is None:
# TODO: Find a nicer way to handle this
return "unsure"
return fallback_choice.autoheuristic_id()
super().__init__(
fallback_str,
choices_str,
None,
context,
name,
augment_context,
precondition,
)
if (
torch._inductor.config.collect_autoheuristic(self.name)
and self.satisfies_precondition()
):
self.register_global_feedback(input_nodes, choices)
def register_global_feedback(
self, input_nodes: List[Any], choices: List[ChoiceCaller]
) -> None:
"""
Registers a callback in select_algorithm, which is called with the timing of each choice.
"""
from torch._inductor.select_algorithm import (
add_feedback_saver,
create_inputs_key,
create_precompile_key,
)
def store_global_feedback(
ah_inputs_key: str,
ah_precompile_key: str,
timings: Dict[ChoiceCaller, float],
name: str,
input_nodes: List[Any],
choices: List[ChoiceCaller],
) -> None:
current_inputs_key = create_inputs_key(input_nodes)
if current_inputs_key != ah_inputs_key:
return
current_precompile_key = create_precompile_key(
name, current_inputs_key, choices
)
if current_precompile_key != ah_precompile_key:
return
for choice, time in timings.items():
self.save_data(choice.autoheuristic_id(), time)
inputs_key = create_inputs_key(input_nodes)
precompile_key = create_precompile_key(self.name, inputs_key, choices)
feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
add_feedback_saver(feedback_saver)
def get_choice_caller(self) -> Optional[ChoiceCaller]:
choice = self.get_choice()
return self.choicestr2choice.get(choice, None)
def get_top_k_choices_caller(
self, top_k: int, always_included: Optional[List[str]] = None
) -> Optional[List[ChoiceCaller]]:
choices = self.get_top_k_choices(top_k, always_included)
if choices is None:
return None
return [self.choicestr2choice[choice] for choice in choices]

View File

@@ -0,0 +1,339 @@
import functools
from typing import Any, Callable, Dict, List, Tuple
import torch
Feedback = float
Choice = str
Value = Any
CHOICE_COL = "choice"
FEEDBACK_COL = "feedback"
class AHFeature:
"""
The context that AutoHeuristic stores is a list of features. AutoHeuristic needs to know whether a feature is
categorical (i.e., not a continuous variable) in order to learn a machine learning model.
"""
def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
self.name = name
self.value = value
self.is_categorical = is_categorical
class AHOperation:
"""
AHOperation can be used to augment the data collected by AutoHeuristic.
One might for example store features like m, k, n, but also want to use
features like m*n, or k*n, to learn a heuristic. Instead of storing features
that can be created from the collected data, one can use AHOperation to
create new features from the collected data.
"""
def __init__(
self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
) -> None:
self.name = name
self.func = func
self.is_categorical = is_categorical
def apply_operation(self, data: Any) -> None:
data[self.name] = self.func(data)
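# For example (mirroring get_mult_dims_ops() further down in this file), an operation that
# derives m*n from already-collected features could look like:
#
#   m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
#   m_times_n_op.apply_operation({"m": 4, "n": 8})  # the dict now also contains {"m*n": 32}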
class AHContext:
"""
This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeuristic will
store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
information that will help to learn a heuristic.
"""
features: List[AHFeature]
context_dict: Dict[str, Value]
def __init__(self) -> None:
self.features = []
self.context_dict = {}
def add_feature(
self, name: str, value: Value, is_categorical: bool = False
) -> None:
self.features.append(AHFeature(name, value, is_categorical=is_categorical))
self.context_dict[name] = value
def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
numerical_features = []
categorical_features = []
for feature in self.features:
if feature.is_categorical:
categorical_features.append(feature.name)
else:
numerical_features.append(feature.name)
return numerical_features, categorical_features
def get_feature_names_csv(self) -> str:
return ",".join(feature.name for feature in self.features)
def get_feature_values_csv(self) -> str:
return ",".join(str(feature.value) for feature in self.features)
def get_value(self, name: str) -> Value:
return self.context_dict[name]
def apply_operations(self, operations: List[AHOperation]) -> None:
for op in operations:
op.apply_operation(self.context_dict)
class AHMetadata:
def __init__(
self,
shared_memory: Any,
device_capa: Tuple[int, int],
choices: List[Choice],
name: str,
) -> None:
# use amount of shared_memory and device_capability to identify GPU
# TODO(AlnisM): there might be a better way to do this
self.shared_memory = shared_memory
self.device_capa = device_capa
self.choices = choices
self.name = name
def to_dict(self) -> Dict[str, Value]:
return {
"shared_memory": self.shared_memory,
"device_capa": self.device_capa,
"name": self.name,
}
def get_metadata_str_from_log(log_path: str) -> str:
with open(log_path, newline="") as file:
json_string = file.readline().strip()
return json_string
def check_minsize(context: AHContext, minsize: int) -> bool:
return (
context.get_value("m") >= minsize
and context.get_value("k") >= minsize
and context.get_value("n") >= minsize
)
def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0):
# A100 precondition
return check_minsize(context, 512)
elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0):
# H100 precondition
return check_minsize(context, 768)
return True
def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
m = context.get_value("m")
k = context.get_value("k")
n = context.get_value("n")
if m > 128 or k < 1024 or n < 1024:
return False
mat1_iscontig = context.get_value("mat1_iscontig")
mat2_iscontig = context.get_value("mat2_iscontig")
return mat1_iscontig and not mat2_iscontig
def get_mult_dims_ops() -> List[AHOperation]:
m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
return [m_times_k_op, m_times_n_op, k_times_n_op]
def get_arith_intensity(data: Any) -> float:
m = data["m"]
k = data["k"]
n = data["n"]
if m == 0 or k == 0 or n == 0:
return 0.0
return m * k * n / (m * k + k * n + m * n)
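# Worked example: for m = k = n = 1024 this returns
# 1024**3 / (3 * 1024**2) = 1024 / 3 ≈ 341.3, i.e. the ratio of multiply-accumulates
# to matrix elements touched.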
def pad_mm_operations() -> List[AHOperation]:
mult_dims_ops = get_mult_dims_ops()
k_div_m_times_n_op = AHOperation(
"k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
)
def bfloat_perf_hit(data: Any) -> bool:
m = data["m"]
k = data["k"]
n = data["n"]
is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
return k > (m * 1024) and k > (n * 1024) and is_bfloat
bfloat_perf_hit_op = AHOperation(
"bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
)
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
dims_need_padding_ops = get_dims_need_padding_ops()
dims_multiple_ops = get_dims_multiple_ops()
is_contig_ops = get_is_contig_ops()
ah_operations = mult_dims_ops + [
k_div_m_times_n_op,
bfloat_perf_hit_op,
arith_intensity_op,
]
ah_operations.extend(dims_need_padding_ops)
ah_operations.extend(dims_multiple_ops)
ah_operations.extend(is_contig_ops)
return ah_operations
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
return data[dim] >= lower and data[dim] <= upper
def between_ops() -> List[AHOperation]:
dims = ["m", "k", "n"]
limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
ah_operations = []
for dim in dims:
for lower, upper in limits:
between_op_fn = functools.partial(
between_op, dim=dim, lower=lower, upper=upper
)
# using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
ah_operations.append(
AHOperation(between_op_name, between_op_fn, is_categorical=True)
)
return ah_operations
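# These categorical features are what the generated heuristics above query, e.g.
# context.get_value('17LEQmLEQ32') is True exactly when 17 <= m <= 32.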
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
return data[dim] == 2**exponent
def mm_operations() -> List[AHOperation]:
mult_dims_ops = get_mult_dims_ops()
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
return mult_dims_ops + [arith_intensity_op]
def mixed_mm_operations() -> List[AHOperation]:
return mm_operations() + between_ops()
def is_multiple(data: Any, dim: str, mult: int) -> bool:
return data[dim] % mult == 0
def get_dims_multiple_ops() -> List[AHOperation]:
multiples = [2, 4, 8, 16, 32]
dims = ["m", "k", "n"]
dims_multiple_ops = []
for dim in dims:
for mult in multiples:
is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
dims_multiple_op = AHOperation(
f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
)
dims_multiple_ops.append(dims_multiple_op)
return dims_multiple_ops
def get_dims_need_padding_ops() -> List[AHOperation]:
def mat1_innermost_needs_padding_fn(data: Any) -> bool:
mat1_stride_0 = data["mat1_stride_0"]
mat1_stride_1 = data["mat1_stride_1"]
m_padded_length = data["m_padded_length"]
k_padded_length = data["k_padded_length"]
mat1_innermost_needs_padding = False
if mat1_stride_0 == 1 and m_padded_length != 0:
mat1_innermost_needs_padding = True
if mat1_stride_1 == 1 and k_padded_length != 0:
mat1_innermost_needs_padding = True
return mat1_innermost_needs_padding
mat1_innermost_op = AHOperation(
"mat1_innermost_needs_padding",
mat1_innermost_needs_padding_fn,
is_categorical=True,
)
def mat2_innermost_needs_padding_fn(data: Any) -> bool:
mat2_stride_0 = data["mat2_stride_0"]
mat2_stride_1 = data["mat2_stride_1"]
k_padded_length = data["k_padded_length"]
n_padded_length = data["n_padded_length"]
mat2_innermost_needs_padding = False
if mat2_stride_0 == 1 and k_padded_length != 0:
mat2_innermost_needs_padding = True
if mat2_stride_1 == 1 and n_padded_length != 0:
mat2_innermost_needs_padding = True
return mat2_innermost_needs_padding
mat2_innermost_op = AHOperation(
"mat2_innermost_needs_padding",
mat2_innermost_needs_padding_fn,
is_categorical=True,
)
def num_dims_needs_padding_fn(data: Any) -> int:
m_padded_length = data["m_padded_length"]
k_padded_length = data["k_padded_length"]
n_padded_length = data["n_padded_length"]
num_dims_needs_padding = 0
if m_padded_length != 0:
num_dims_needs_padding += 1
if k_padded_length != 0:
num_dims_needs_padding += 1
if n_padded_length != 0:
num_dims_needs_padding += 1
return num_dims_needs_padding
num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
def get_is_contig_ops() -> List[AHOperation]:
def mat1_is_contig_fn(data: Any) -> bool:
stride_0 = data["mat1_stride_0"]
stride_1 = data["mat1_stride_1"]
k = data["k"]
return stride_0 == k and stride_1 == 1
mat1_is_contig_op = AHOperation(
"mat1_iscontig", mat1_is_contig_fn, is_categorical=True
)
def mat2_is_contig_fn(data: Any) -> bool:
stride_0 = data["mat2_stride_0"]
stride_1 = data["mat2_stride_1"]
n = data["n"]
return stride_0 == n and stride_1 == 1
mat2_is_contig_op = AHOperation(
"mat2_iscontig", mat2_is_contig_fn, is_categorical=True
)
return [mat1_is_contig_op, mat2_is_contig_op]
def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
for i, s in enumerate(stride):
context.add_feature(f"{name}_stride_{i}", s)
def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
using_tf32 = "not_float_32"
if dtype == torch.float32:
using_tf32 = torch.backends.cuda.matmul.allow_tf32
context.add_feature("using_tf32", using_tf32, is_categorical=True)

View File

@@ -0,0 +1,119 @@
import importlib
import inspect
import pkgutil
from collections import defaultdict
from typing import Any, Dict, List, Optional
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
def find_and_instantiate_subclasses(
package_name: str, base_class: Any
) -> List[LearnedHeuristic]:
instances = []
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
try:
module_basename = module_name.split(".")[-1]
if not module_basename.startswith("_"):
# learned heuristics start with an underscore
continue
module = importlib.import_module(module_name)
# look for classes that are subclasses of base_class
for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, base_class)
and obj != base_class
):
instance = obj()
instances.append(instance)
except Exception as e:
print(f"Error processing module {module_name}: {e}")
return instances
class LearnedHeuristicController:
"""
Class that finds and instantiates all learned heuristics. It also provides
a way to get the decision of a learned heuristic.
"""
existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
"""
A dictionary that stores all the learned heuristics for each optimization.
The key is the optimization name, and the value is a list of LearnedHeuristic objects.
"""
heuristics_initialized: bool = False
"""
A flag that indicates whether the learned heuristics have been initialized.
Set to true when the get_decision() function is called for the first time.
"""
def __init__(
self,
metadata: AHMetadata,
context: AHContext,
) -> None:
self.metadata = metadata
self.context = context
def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
"""
Returns a list of learned heuristics for the given optimization name.
"""
if not LearnedHeuristicController.heuristics_initialized:
# learned heuristics are generated into the following package
learned_heuristics_package = "torch._inductor.autoheuristic.artifacts"
# learned heuristics have to be of type LearnedHeuristic
base_class = LearnedHeuristic
found_heuristics = find_and_instantiate_subclasses(
learned_heuristics_package, base_class
)
for learned_heuristic in found_heuristics:
opt_name = learned_heuristic.get_name()
LearnedHeuristicController.existing_heuristics[opt_name].append(
learned_heuristic
)
LearnedHeuristicController.heuristics_initialized = True
return LearnedHeuristicController.existing_heuristics[name]
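# For example, with metadata.name == "mixed_mm" this returns instances of generated
# classes such as MixedMMA100 and MixedMMH100 shown earlier in this commit, assuming
# they were discovered in the artifacts package.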
def get_decision(self) -> Optional[Choice]:
"""
Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure
which choice to make.
"""
heuristics = self.get_heuristics(self.metadata.name)
for heuristic in heuristics:
if heuristic.check_precondition(self.metadata, self.context):
return heuristic.get_decision(self.context, self.metadata.choices)
return None
def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
heuristics = self.get_heuristics(self.metadata.name)
for heuristic in heuristics:
if heuristic.check_precondition(self.metadata, self.context):
choices = heuristic.get_decisions_ranked(self.context)
if choices is None:
return None
avail_choices = [
choice for choice in choices if choice in self.metadata.choices
]
return avail_choices[:top_k]
return None

View File

@@ -0,0 +1,92 @@
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
class LearnedHeuristic:
"""
LearnedHeuristic is a base class for all learned heuristics.
"""
def __init__(self) -> None:
pass
def check_precondition(
self,
metadata: AHMetadata,
context: AHContext,
) -> bool:
return True
def get_decision(
self, context: AHContext, choices: List[Choice]
) -> Optional[Choice]:
return None
def get_confidence_threshold(self) -> float:
return 1.0
def get_name(self) -> str:
return ""
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
return None
class LearnedHeuristicRegression(LearnedHeuristic):
def __init__(self) -> None:
super().__init__()
def get_feedback(self, context: AHContext, choice: Choice) -> float:
return 1.0
def get_decision(
self, context: AHContext, choices: List[Choice]
) -> Optional[Choice]:
choice2feedback = {}
for choice in choices:
predicted_feedback = self.get_feedback(context, choice)
choice2feedback[choice] = predicted_feedback
sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1])
highest_feedback = sorted_choices_feedback[-1][1]
second_highest_feedback = sorted_choices_feedback[-2][1]
if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
return sorted_choices_feedback[-1][0]
# We are not sure which choice is the best one
return None
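# Worked example (hypothetical predictions, using PadMMA100's threshold of ~1.7025 from
# above): if the predicted feedback is {"pad": 1.9, "orig": 1.0}, then 1.9 / 1.0 > 1.7025,
# so "pad" is returned; with {"pad": 1.2, "orig": 1.0} the ratio is below the threshold,
# the heuristic is unsure, and None is returned.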
class LearnedHeuristicDecision(LearnedHeuristic):
def __init__(self) -> None:
super().__init__()
def get_choice(self, idx: int) -> Optional[str]:
return None
def get_decision(
self, context: AHContext, choices: List[Choice]
) -> Optional[Choice]:
best_choices = self.get_best_choices(context)
if not best_choices:
return None
(best_choice_proba, best_choice_idx) = best_choices[0]
if best_choice_proba <= self.get_confidence_threshold():
return None
return self.get_choice(best_choice_idx)
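# Worked example (hypothetical probabilities): with best_choices == [(0.75, 16), (0.25, 14)]
# and a confidence threshold of 0.0 (as in the mixed_mm heuristics above), 0.75 > 0.0,
# so self.get_choice(16) is returned.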
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
feedback_idx_list = self.get_best_choices(context)
if feedback_idx_list is None:
return None
choices = [
self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
]
choices = [choice for choice in choices if choice is not None]
return choices
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
return []

View File

@@ -0,0 +1,876 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import ctypes
import dataclasses
import functools
import logging
import os
import queue
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch import multiprocessing
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
from torch._inductor.codecache import (
CppCodeCache,
CUDACodeCache,
DLLWrapper,
get_hash,
PyCodeCache,
)
if TYPE_CHECKING:
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from types import ModuleType
from torch._inductor.select_algorithm import TritonTemplateCaller
from . import config
from .runtime.benchmarking import benchmarker
from .virtualized import V
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
EXIT_HANDLER_REGISTERED = False
log = logging.getLogger(__name__)
# Used to synchronize between parent and child processes
class Ping:
pass
class Pong:
pass
class NonzeroWorkspaceNotSupportedError(Exception):
pass
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
"""
Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
specified single device. If device is None, don't manipulate the environment.
"""
if device is None:
yield
return
current = os.environ.get(CUDA_VISIBLE_DEVICES)
os.environ[CUDA_VISIBLE_DEVICES] = str(device)
try:
yield
finally:
if current is None:
del os.environ[CUDA_VISIBLE_DEVICES]
else:
os.environ[CUDA_VISIBLE_DEVICES] = current
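# Editor's sketch (hedged, not part of the original file): intended usage is to
# wrap the start of a child process so it sees a single GPU, e.g.
#
#     with set_cuda_visible_device(1):
#         proc.start()   # child inherits CUDA_VISIBLE_DEVICES=1
#
# On exit the previous value of CUDA_VISIBLE_DEVICES is restored (or the
# variable is removed again if it was not set before).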
@dataclasses.dataclass
class TuningProcess:
"""
    Abstraction for launching a helper process to benchmark kernels. The parent
    process spawns a child and uses multiprocessing queues to send benchmark
    requests and receive results.
"""
device: Optional[int] = None
process: Optional[BaseProcess] = None
request_queue: Optional[Queue[Any]] = None
response_queue: Optional[Queue[Any]] = None
@staticmethod
def process_main(
request_queue: Queue[Any],
response_queue: Queue[Any],
) -> None:
"""
Entry point for the child process.
"""
log.debug(
"Entering TuningProcess child. Visible devices = %s",
os.environ.get(CUDA_VISIBLE_DEVICES),
)
try:
TuningProcess.workloop(request_queue, response_queue)
        except Exception:
log.exception("Exception in TuningProcess")
@staticmethod
def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
"""
Work loop for the benchmarking subprocess.
"""
while True:
obj = request_queue.get()
if obj is None:
break # None is a sentinel for the child to terminate
elif isinstance(obj, Ping):
response_queue.put(Pong())
elif isinstance(obj, BenchmarkRequest):
response_queue.put(obj.benchmark())
else:
raise RuntimeError(f"Invalid request type {type(obj)}")
def valid(self) -> bool:
"""
True if the sub-process has been initialized.
"""
return (
self.process is not None
and self.request_queue is not None
and self.response_queue is not None
)
def clear(self) -> None:
"""
Reset to an uninitialized state.
"""
self.process = self.request_queue = self.response_queue = None
def initialize(self) -> None:
"""
Create child process, request/response queues, and do the warm up.
Set the environment to make only the provided GPU device visible
to the process.
"""
if self.valid():
return
# cuda runtime does not work with "fork", use "spawn" to start processes.
ctx = multiprocessing.get_context("spawn")
self.request_queue = ctx.Queue()
self.response_queue = ctx.Queue()
self.process = ctx.Process(
target=self.process_main,
args=(
self.request_queue,
self.response_queue,
),
)
assert self.process is not None
with set_cuda_visible_device(self.device):
self.process.start()
def put(self, obj: Any) -> None:
"""
Push a work item to the child process.
"""
# In case of a prior crash, ensure the subprocess is running
self.initialize()
assert self.request_queue is not None
self.request_queue.put(obj)
def get(
self, result_timeout=120.0, graceful_timeout=3.0, terminate_timeout=1.0
) -> Any:
"""
Get a response from the child process. Raises queue.Empty on timeout
or if the process dies.
        This method is currently only used by TuningProcessPool, which populates
        the timeouts from torch._inductor.config entries:
Arguments:
@param result_timeout: Timeout in seconds, defaults to 120.0 or to
config.max_autotune_subproc_result_timeout_seconds when called by TuningProcessPool
@param graceful_timeout: Timeout in seconds to allow graceful shutdown (SIGTERM is sent after this time).
Defaults to 3.0 or to config.max_autotune_subproc_graceful_timeout_seconds
@param terminate_timeout: Timeout in seconds after SIGTERM, until we send SIGKILL if the process
remains alive. Defaults to 1.0 or to
config.max_autotune_subproc_terminate_timeout_seconds.
Returns:
A response from the child process (Any type)
"""
assert self.process is not None
assert self.response_queue is not None
while True:
try:
remaining_timeout = result_timeout
res = None
while remaining_timeout is not None and remaining_timeout >= 1.0:
remaining_timeout -= 0.5
try:
res = self.response_queue.get(timeout=0.5)
break
except queue.Empty:
if not self.process.is_alive():
raise # is being caught a few lines below
if res is None:
res = self.response_queue.get(timeout=remaining_timeout)
return res
except queue.Empty:
status = self.process.exitcode
if status is None:
self.kill(
graceful_timeout=graceful_timeout,
terminate_timeout=terminate_timeout,
)
else:
# child process crashed
self.clear()
raise
def terminate(self) -> None:
"""
Signal the child process to terminate.
"""
if self.valid():
assert self.process is not None
assert self.request_queue is not None
self.request_queue.put(None)
def wait(self) -> None:
"""
Wait for the child process to exit.
"""
if self.process is not None:
self.process.join()
self.clear()
def kill(self, graceful_timeout=5.0, terminate_timeout=1.0) -> None:
# Tries to kill the process, using a graceful_timeout in which the process
# is allowed to exit gracefully. If the process is still alive,
# it will be terminated. If that is not sufficient to end it
# within terminate_timeout seconds, it will be killed.
if self.process is not None:
self.terminate()
self.process.join(timeout=graceful_timeout)
if self.process.is_alive():
log.warning(
"Sending SIGTERM to process with PID %d",
self.process.pid,
)
self.process.terminate()
self.process.join(timeout=terminate_timeout)
if self.process.is_alive():
log.error(
"Sending SIGKILL to process with PID %d",
self.process.pid,
)
self.process.kill() # This should definitely end the process
self.clear()
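# Editor's sketch (hedged, not part of the original file): a hand-driven walk
# through the TuningProcess lifecycle above. Defined only for illustration and
# never called by library code.
def _example_tuning_process_roundtrip() -> None:
    tp = TuningProcess(device=None)   # None: leave CUDA_VISIBLE_DEVICES alone
    tp.initialize()                   # spawn the child and create the queues
    tp.put(Ping())                    # warm-up request
    assert isinstance(tp.get(), Pong)
    tp.terminate()                    # push the None sentinel to the child
    tp.wait()                         # join the child and reset state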
@dataclasses.dataclass
class TuningProcessPool:
"""
Maintains a pool of TuningProcesses to benchmark kernels in parallel
across devices. By default, we create one TuningProcess per device and
set the sub-process environment to make only that device visible.
"""
processes: Optional[queue.Queue[TuningProcess]] = None
executor: Optional[ThreadPoolExecutor] = None
def initialize(self) -> None:
"""
Start the child processes.
"""
assert (self.processes is None) == (self.executor is None)
if self.processes is not None:
return
devices = self.get_device_list()
log.debug("Sub-process autotune device list: %s", devices)
# Launch the child processes and push a msg to "warm up"
self.processes = queue.Queue()
for device in devices:
p = TuningProcess(device=device)
p.initialize()
p.put(Ping())
self.processes.put(p)
# Wait for the initialization to finish
for p in self.processes.queue:
assert isinstance(p.get(result_timeout=None), Pong)
# Use a thread pool to manage distributing work to the subprocesses.
# Threads block on an available process, so it makes sense to match
# the number of threads with the number of devices.
self.executor = ThreadPoolExecutor(max_workers=len(devices))
# Register the exit handler for the parent process so it will terminate
# the child processes.
global EXIT_HANDLER_REGISTERED
if not EXIT_HANDLER_REGISTERED:
EXIT_HANDLER_REGISTERED = True
import atexit
atexit.register(self.terminate)
def get_device_list(self) -> Sequence[Optional[int]]:
"""
Gather the list of devices to be used in the pool.
"""
if not config.autotune_multi_device:
# Don't use multiple devices
return [None]
count = torch.cuda.device_count()
# If the user specified the visible devices in the env, use those.
if CUDA_VISIBLE_DEVICES in os.environ:
devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
assert len(devices) <= count
return devices
return list(range(count))
def terminate(self) -> None:
"""
Signal all child processes to terminate.
"""
if self.executor is not None:
self.executor.shutdown()
self.executor = None
if self.processes is not None:
for p in self.processes.queue:
p.terminate()
for p in self.processes.queue:
p.wait()
self.processes = None
def target(self, choice: TritonTemplateCaller) -> float:
"""
Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
remove it from the queue, execute the benchmark in that subprocess, and return
the TuningProcess to the queue.
"""
assert choice.bmreq is not None
assert self.processes is not None
process = self.processes.get()
process.put(choice.bmreq)
try:
return process.get(
config.max_autotune_subproc_result_timeout_seconds,
config.max_autotune_subproc_graceful_timeout_seconds,
config.max_autotune_subproc_terminate_timeout_seconds,
)
except queue.Empty:
warnings.warn(
f"Failed to benchmark choice '{choice}'. It will be ignored. "
"Please debug the root cause in case the choice can bring perf gains."
)
# set to INF so this choice will be ignored
return float("inf")
finally:
self.processes.put(process)
def benchmark(
self,
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
"""
Benchmark each choice in a separate process.
"""
assert self.processes is not None, "Tuning process pool is not initialized"
assert self.executor is not None
results = {}
        # Use the ThreadPoolExecutor to spread the work across the subprocesses
        # and to grab a subprocess as soon as it is free.
for choice, result in zip(choices, self.executor.map(self.target, choices)):
results[choice] = result
return results
tuning_pool = TuningProcessPool()
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
@dataclasses.dataclass
class TensorMeta:
device: torch.device
dtype: torch.dtype
sizes: torch._prims_common.ShapeType
strides: torch._prims_common.StrideType
offset: int
name: Optional[str] = None
@classmethod
def from_irnodes(
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
) -> Union[TensorMeta, List[TensorMeta]]:
if isinstance(irnodes, Sequence):
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
assert all(isinstance(x, TensorMeta) for x in result)
return result
node = irnodes
if isinstance(node, ir.Layout):
node = ir.Buffer("fake", node)
dtype = node.get_dtype()
assert dtype is not None
return TensorMeta(
device=node.get_device(),
dtype=dtype,
sizes=V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
),
strides=V.graph.sizevars.size_hints(
node.get_stride(),
fallback=config.unbacked_symint_fallback,
),
offset=V.graph.sizevars.size_hint(
node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
),
name=node.get_name(),
)
def to_tensor(self) -> torch.Tensor:
return rand_strided(
self.sizes,
self.strides,
device=self.device,
dtype=self.dtype,
extra_size=self.offset,
)
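# Editor's sketch (hedged, not part of the original file): building a TensorMeta
# by hand and materializing a matching random tensor. Real callers construct
# these via TensorMeta.from_irnodes on inductor IR nodes.
def _example_tensor_meta() -> torch.Tensor:
    meta = TensorMeta(
        device=torch.device("cpu"),
        dtype=torch.float32,
        sizes=(2, 3),
        strides=(3, 1),   # contiguous row-major strides for a 2x3 tensor
        offset=0,
    )
    return meta.to_tensor()   # rand_strided tensor matching the metadata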
@dataclasses.dataclass
class BenchmarkRequest:
"""
    Only handles Triton template benchmarks for now. Extern kernel benchmarks
    can be done inside the same process since they usually don't cause crashes.
Important: Instances of this class and subclasses have to be serializable
across process boundaries. Do not put CUDA Tensors in here!
"""
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
) -> None:
# the kernel name defined in the module
self.kernel_name = kernel_name
if isinstance(input_tensor_meta, TensorMeta):
input_tensor_meta = [input_tensor_meta]
self.input_tensor_meta = input_tensor_meta
if isinstance(output_tensor_meta, (tuple, list)):
assert len(output_tensor_meta) == 1
output_tensor_meta = output_tensor_meta[0]
self.output_tensor_meta = output_tensor_meta
self.extra_args = extra_args
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
raise NotImplementedError
def cleanup_run_fn(self) -> None:
pass
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
raise NotImplementedError
def benchmark(
self,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
debug = log.isEnabledFor(logging.DEBUG)
if debug:
start_ts = time.time()
# create args and out tensor
if output_tensor is None:
assert len(input_tensors) == 0
input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
output_tensor = self.output_tensor_meta.to_tensor()
if debug:
create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()
try:
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
except NonzeroWorkspaceNotSupportedError:
# Skipping all ops with nonzero workspace requirements
log.info("Skipping op due to nonzero workspace requirement")
return float("inf")
if debug:
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()
out = self.do_bench(fn, *input_tensors, output_tensor)
if debug:
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
log.debug(
"InChildProcess %s: load %f, create tensor %f, bench %f",
str(self),
load_elapse, # type: ignore[possibly-undefined]
create_tensor_elapse, # type: ignore[possibly-undefined]
bench_elapse,
)
self.cleanup_run_fn()
return out
class TestBenchmarkRequest(BenchmarkRequest):
"""
Supports unit testing. Defined in this file so that the TuningProcess
sub-process knows how to unpickle these objects.
"""
def __init__(self, value: Optional[float] = None) -> None:
self.value = value
def benchmark(
self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
) -> float:
if self.value is None:
raise Exception("Failed to run") # noqa: TRY002
return self.value
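# Editor's sketch (hedged, not part of the original file): how the unit tests
# use this class. A concrete value is returned as the fake latency; value=None
# makes the child raise, exercising the error handling in TuningProcess.get().
def _example_test_benchmark_request() -> float:
    req = TestBenchmarkRequest(value=1.23)
    return req.benchmark()   # 1.23, without touching any kernel or GPU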
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
device_idx_set = {
tensor.device.index
for tensor in [*input_tensors, output_tensor]
if isinstance(tensor, torch.Tensor)
and tensor.is_cuda
and tensor.device.index is not None
}
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
if len(device_idx_set) == 1:
device_idx = next(iter(device_idx_set))
else:
device_idx = torch.cuda.current_device()
with torch.cuda.device(device_idx):
out = benchmarker.benchmark_gpu(fn)
torch.cuda.synchronize() # shake out any CUDA errors
return out
class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
module_path: str, # the path of the module defining the triton kernel
module_cache_key: str,
grid: List[int],
num_stages: int,
num_warps: int,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.module_path = module_path
self.module_cache_key = module_cache_key
self.grid = grid
self.num_stages = num_stages
self.num_warps = num_warps
self.matrix_instr_nonkdim = matrix_instr_nonkdim
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
log.debug(
"benchmark module key: %s, path: %s",
self.module_cache_key,
self.module_path,
)
run_method = getattr(mod, self.kernel_name).run
extra_args = list(self.extra_args)
        # Newer versions of Triton add a warmup argument to JITFunction.run.
        # This code handles backward compatibility.
warmup_arg = {}
import inspect
if "warmup" in inspect.signature(run_method).parameters:
warmup_arg["warmup"] = False
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
        # The HIP matrix_instr_nonkdim case currently takes the same call path
        # as the default case, so a single return suffices.
        return functools.partial(
            run_method,
            *input_tensors,
            output_tensor,
            *self.extra_args,
            grid=self.grid,
            **warmup_arg,
            stream=get_raw_stream(self.output_tensor_meta.device.index),
        )
def precompile(self):
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
getattr(mod, self.kernel_name).precompile()
def __str__(self) -> str:
return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.workspace_size: int = 0
self.workspace: Optional[torch.Tensor] = None
self.DLL: Optional[DLLWrapper] = None
self._workspace_size_updated = False
self.hash_key: str = ""
self.source_file: str = ""
self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
def precompile(self):
# Prepopulate CUDACodeCache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
CUDACodeCache.compile(self.source_code, "so")
log.debug("Done precompiling %s", self)
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.ensure_dll_loaded()
self.update_workspace_size()
args = [
c_void_p(tensor.data_ptr())
for tensor in list(input_tensors) + [output_tensor]
]
log.debug(
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
run_method = getattr(self.DLL, self.kernel_name)
workspace_ptr = c_void_p(0)
if self.workspace_size > 0:
self.workspace = torch.zeros(
(self.workspace_size + 7) // 8,
dtype=torch.float64,
device=output_tensor.device,
)
workspace_ptr = c_void_p(self.workspace.data_ptr())
# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
None, # null workspace size ptr
workspace_ptr, # set workspace ptr,
stream_ptr,
)
def update_workspace_size(self) -> None:
if self._workspace_size_updated:
return
self.ensure_dll_loaded()
unique_input_count = len({meta.name for meta in self.input_tensor_meta})
args = [c_void_p(None) for _ in range(unique_input_count + 1)]
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
run_method = getattr(self.DLL, self.kernel_name)
# Retrieve workspace_size and initialize workspace.
c_workspace_size = c_size_t()
run_method(
*args, # input ptrs and output ptrs
*self.extra_args,
byref(
c_workspace_size
), # set workspace size ptr to retrieve workspace size
None, # null workspace ptr
stream_ptr,
)
torch.cuda.synchronize() # shake out any CUDA errors
self.workspace_size = c_workspace_size.value
log.debug(
"update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
self.workspace_size,
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
self._workspace_size_updated = True
def ensure_dll_loaded(self):
if self.DLL is None:
self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
self.source_code, "so"
)
def cleanup_run_fn(self) -> None:
if self.DLL is not None:
self.DLL.close()
self.workspace = None
def __str__(self) -> str:
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
class CPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
return benchmarker.benchmark_cpu(fn)
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put Tensors in here!
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.hash_key = get_hash(source_code)
self.DLL: Optional[Union[CDLL, ModuleType]] = None
def precompile(self):
# Prepopulate CppCodeCache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
CppCodeCache.load(self.source_code, cuda=False)
log.debug("Done precompiling %s", self)
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
# TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf
self.DLL = CppCodeCache.load(self.source_code, cuda=False)
args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
log.debug(
"make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.DLL,
args,
self.extra_args,
)
run_method = getattr(self.DLL, self.kernel_name)
        # Assume extra_args contains only sizes of type ctypes.c_ulonglong
assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args)
run_method.argtypes = [ctypes.c_ulonglong] * (
len(args) + len(list(self.extra_args))
)
# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
)
def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            # Check for a close attribute first; calling close() unconditionally
            # has been observed to crash on Windows.
            if hasattr(self.DLL, "close"):
                self.DLL.close()
def __str__(self) -> str:
return f"{self.kernel_name=}"
def benchmark_in_sub_process(
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
"""
Do benchmarking in a subprocess and return the perf number (latency).
"""
return tuning_pool.benchmark(choices)

View File

@ -0,0 +1,140 @@
# mypy: allow-untyped-defs
import logging
import operator
from functools import partial
from typing import Any, Callable, Dict
from sympy import Expr
import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
from .utils import cache_on_self, dominated_nodes
from .virtualized import V
log = logging.getLogger(__name__)
class BoundVars:
"""
    Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run().
    It exposes the ranges of the nodes in the `bounds` variable.
    Note: a current limitation of this analysis is that it only works on a per-loop
    basis. We should be able to propagate bounds across the whole graph, which would
    benefit the case where a bounded variable is returned by one kernel and fed into
    another.
"""
def __init__(self, loop_body: LoopBody) -> None:
def upper_bound(v):
return bound_sympy(v).upper if isinstance(v, Expr) else v
self.loop_body = loop_body
self.replacement_vals = {
k: ValueRanges[Expr](0, upper_bound(v) - 1)
for k, v in loop_body.var_ranges.items()
}
# avoid computing these values, pessimistically assume that they are unbounded
self.unbounded_vars = dominated_nodes(
node
for node in self.loop_body.get_nodes()
if node.target in ["load", "reduction", operator.getitem]
or "masked_subblock" in node.target
)
# To access this variable call `get_bounds()`
self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"loop_body={self.loop_body},\n "
f"replacement_vals={self.replacement_vals}, \n"
f"unbounded_vars={self.unbounded_vars}, \n"
f"_bounds={self._bounds})"
)
@cache_on_self
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
submodules = self.swap_submodules(self.loop_body.submodules)
# Initialize the environment with the unbounded variables
for node in self.unbounded_vars:
# we need to evaluate masked_subblock to recurse, and we need to set indirect values
if not isinstance(node.target, str) or (
"masked_subblock" not in node.target
and "set_indirect" not in node.target
):
self._bounds[node] = ValueRanges[Expr].unknown()
with V.set_ops_handler(ValueRangeAnalysis()):
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
return self._bounds
def swap_submodules(
self, submodules: Dict[str, Callable[..., Any]]
) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
for key in submodules.keys():
if key == "get_index":
result[key] = self.get_index
elif "masked_subblock" in key:
subblock = self.loop_body.subblocks[key]
                # Bind `subblock` inside make_fn: Python lambdas close over
                # variables by reference, so a lambda created directly in this
                # loop would end up referencing the final `subblock` (and the
                # final `result`) once the loop finishes.
def make_fn(subblock):
return lambda mask, value: self.masked_subblock(
subblock, self._bounds, mask, value, result
)
result[key] = make_fn(subblock)
elif "set_indirect" in key:
idx = int(key[len("set_indirect") :])
var = self.loop_body.indirect_vars[idx]
indirect = partial(self.set_indirect, var)
result[key] = indirect
else:
assert "scan" in key
result[key] = submodules[key]
return result
def masked_subblock(
self,
subblock: LoopBodyBlock,
env: Dict[torch.fx.Node, ValueRanges[Expr]],
mask: Any,
value: Any,
submodules: Dict[str, Callable[..., Any]],
) -> ValueRanges[Expr]:
interp = InterpreterShim(subblock.graph, submodules)
interp.run(V.get_ops_handler(), initial_env=env)
output = [node for node in subblock.graph.nodes if node.target == "output"]
assert len(output) == 1
        # don't bother unioning with `value` since the load from the buffer will
        # be pessimistically assumed to be inf anyway
return interp.env[output[0]]
def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
assert isinstance(new, ValueRanges)
self.replacement_vals[old] = new
return new
def get_index(self, name: Expr) -> ValueRanges[Expr]:
expr = self.loop_body.indexing_exprs[name]
bound = self.replacement_vals.get(expr)
if bound is None:
bound = bound_sympy(expr, self.replacement_vals)
        # The following assertion is true at the time of this writing.
        # We don't assert it, so as not to execute bound_sympy when bound is not None.
        # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
self.replacement_vals[name] = bound
return bound
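# Editor's sketch (hedged, not part of the original file): the bound_sympy
# helper used above, shown standalone. The symbol name and ranges are made up.
def _example_bound_sympy() -> ValueRanges[Expr]:
    import sympy
    x = sympy.Symbol("x", integer=True, nonnegative=True)
    # For x in [0, 7], 2*x + 1 is bounded by [1, 15].
    return bound_sympy(2 * x + 1, {x: ValueRanges(0, 7)})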

File diff suppressed because it is too large

View File

@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
import re
import torch
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
# "...
#     from ..codecache import CudaKernelParamCache
# ..."
# In such cases, we do not need to hipify the original class/file name in codegen/codecache
def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
if torch.version.hip is None and not force_hipify:
return source_codes
def c2_repl(m):
return PYTORCH_MAP[m.group(0)]
    # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch
    # a positive lookbehind (?<=\W) is applied to the pattern to avoid matching
    # a keyword at the beginning of a code line. However, this can happen in
    # codegen, which would cause the pattern not to match.
    # Note that the lookahead (?=\W) is still needed to keep hipification
    # idempotent; for example, we need to skip replacing "getStreamFromExternal"
    # in "getStreamFromExternalMasqueradingAsCUDA".
RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
return source_codes
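# Editor's sketch (hedged, not part of the original file): with force_hipify=True
# the wrapper rewrites CUDA identifiers even on a CUDA build, which makes the
# behavior easy to observe. The exact replacement text comes from PYTORCH_MAP.
def _example_hipify() -> str:
    src = "at::cuda::getCurrentCUDAStream();"
    return maybe_hipify_code_wrapper(src, force_hipify=True)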

Some files were not shown because too many files have changed in this diff