I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,128 @@
# mypy: ignore-errors
import contextlib
import functools
import logging
from unittest.mock import patch
import torch
from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
class AotAutograd:
def __init__(self, **kwargs) -> None:
self.__name__ = "compiler_fn"
self.kwargs = kwargs
def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
if kwargs:
log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
return flatten_graph_inputs(
gm,
example_inputs,
self,
)
# Hack to get around circular import problems with aot_eager_decomp_partition
if callable(self.kwargs.get("decompositions")):
self.kwargs["decompositions"] = self.kwargs["decompositions"]()
# NB: don't delete the counter increment
counters["aot_autograd"]["total"] += 1
use_fallback = False
if use_fallback:
log.debug("Unable to use AOT Autograd because graph has mutation")
counters["aot_autograd"]["not_ok"] += 1
return gm
# OK attempt to compile
def _wrapped_bw_compiler(*args, **kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return disable(disable(bw_compiler)(*args, **kwargs))
bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
self.kwargs["bw_compiler"] = _wrapped_bw_compiler
self.kwargs["inference_compiler"] = (
self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
)
from functorch.compile import nop
from torch._inductor.debug import enable_aot_logging
# Debug asserts slow down compile time noticeably,
# so only enable them by default when the aot_eager backend is used.
if self.kwargs.get("fw_compiler", None) == nop:
patch_config = patch("functorch.compile.config.debug_assert", True)
else:
patch_config = contextlib.nullcontext()
try:
# NB: NOT cloned!
with enable_aot_logging(), patch_config:
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
counters["aot_autograd"]["ok"] += 1
return disable(cg)
except Exception:
counters["aot_autograd"]["not_ok"] += 1
raise
def aot_autograd(**kwargs):
return AotAutograd(**kwargs)
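# A hedged usage sketch (not part of the original file): aot_autograd() builds a
# Dynamo backend from AOT Autograd compiler callbacks. The names my_fw_compiler
# and my_backend below are hypothetical placeholders.
#
#     def my_fw_compiler(fx_g, example_inputs):
#         # inspect or transform the traced forward graph here
#         return fx_g.forward
#
#     my_backend = aot_autograd(fw_compiler=my_fw_compiler)
#     opt_fn = torch.compile(fn, backend=my_backend)
#
# If bw_compiler / inference_compiler are omitted, fw_compiler is reused for them
# (see AotAutograd.__call__ above).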
def mem_efficient_fusion_kwargs(use_decomps):
from functorch.compile import (
default_decompositions,
min_cut_rematerialization_partition,
ts_compile,
)
kwargs = {
# these are taken from memory_efficient_fusion()
"fw_compiler": ts_compile,
"bw_compiler": ts_compile,
"partition_fn": min_cut_rematerialization_partition,
}
if use_decomps:
kwargs["decompositions"] = default_decompositions
return kwargs
def fake_tensor_unsupported(fn):
"""
Decorator for backends that need real inputs. We swap out fake
tensors for zero tensors.
"""
@functools.wraps(fn)
def wrapper(model, inputs, **kwargs):
with _disable_current_modes():
inputs = list(map(defake, inputs))
return fn(model, inputs, **kwargs)
return wrapper
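# A hedged usage sketch: a backend that cannot handle FakeTensor inputs can be
# wrapped with the decorator above so it receives real zero-filled tensors with
# matching metadata instead. The backend name below is a hypothetical placeholder.
#
#     @fake_tensor_unsupported
#     def my_real_input_backend(gm, example_inputs):
#         gm(*example_inputs)  # safe: inputs are real (zero) tensors here
#         return gm.forward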
def device_from_inputs(example_inputs) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device
def dtype_from_inputs(example_inputs) -> torch.dtype:
for x in example_inputs:
if hasattr(x, "dtype"):
return x.dtype


@@ -0,0 +1,256 @@
# mypy: ignore-errors
import functools
from collections import defaultdict
from typing import Dict, List, Optional
import torch
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
check_multiple_devices_or_any_cpu_nodes,
format_default_skip_message,
get_mutation_stack_trace,
get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
BoxedBool,
count_tangents,
get_first_incompatible_cudagraph_node,
num_fw_fixed_arguments,
output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef
from .registry import register_backend
def find_input_mutations(g):
def meta_fk(meta):
return meta["val"] if "val" in meta else meta["fake_result"]
inputs = defaultdict(set)
input_idx = 0
mutated_inputs = set()
for n in g.nodes:
if n.op == "placeholder":
if isinstance(meta_fk(n.meta), torch.Tensor):
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == "call_function":
if not hasattr(n.target, "_schema"):
continue
schema = n.target._schema
for i, arg in enumerate(schema.arguments):
if i < len(n.args):
argument = n.args[i]
else:
if arg.name not in n.kwargs:
continue
argument = n.kwargs[arg.name]
mut_arg = False
if arg.alias_info:
if arg.alias_info.is_write:
mut_arg = True
if mut_arg:
# TODO: not correct for args that contain tensors in a struct
# like list
mutated_inputs |= inputs[
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
]
# TODO: error on unrecognized nodes
return mutated_inputs
def get_device_node_mapping(gm: torch.fx.GraphModule):
device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
for n in gm.graph.nodes:
t = n.meta.get("val", None)
if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
device_node_mapping[t.device] = n
return device_node_mapping
def check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model: torch.fx.GraphModule, num_fixed
) -> Optional[str]:
mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
if not mutation_indices:
return None
placeholders = get_placeholder_info(aot_model.graph)
return get_mutation_stack_trace(placeholders, mutation_indices)
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
if not config.cudagraph_backend_support_input_mutation:
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model, num_fixed
):
return mut_skip
if skip := check_multiple_devices_or_any_cpu_nodes(
get_device_node_mapping(aot_model)
):
return skip
if node := get_first_incompatible_cudagraph_node(aot_model):
return format_default_skip_message(f"incompatible op ({node.name})")
return None
def get_device_index(gm) -> int:
device = next(iter(get_device_node_mapping(gm)))
assert device.type == "cuda"
return device.index
def get_stack_traces(gm) -> List[Optional[str]]:
output = output_node(gm)
assert len(output.args) == 1
return [
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
for arg in output.args[0]
]
def cudagraphs(dynamo_model, dynamo_inputs):
from torch._inductor.cudagraph_trees import cudagraphify_impl
do_cudagraphs = BoxedBool(True)
boxed_device_index = BoxedDeviceIndex(None)
def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
interp = boxed_nop(aot_model, aot_inputs)
fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
if skip_msg := check_for_skip(aot_model, fixed):
BoxedBool.disable(do_cudagraphs)
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {skip_msg}"
)
return interp
boxed_device_index.set(get_device_index(aot_model))
out = cudagraphify_impl(
interp,
aot_inputs,
range(fixed),
device_index=boxed_device_index.value,
is_backward=False,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
return out
def backward_cudagraphs(aot_model, aot_inputs):
interp = boxed_nop(aot_model, aot_inputs)
if not do_cudagraphs:
return aot_model
fixed = count_tangents(aot_model)
if skip_msg := check_for_skip(aot_model, fixed):
log_cudagraph_skip_and_bump_counter(
"skipping cudagraphs due to %s", skip_msg
)
# See [Backward Generation Handling]
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_device_index.value, create_if_none_exists=False
)
assert manager is not None
def fn(inputs):
manager.set_to_running_backward()
return aot_model(inputs)
fn._boxed_call = True
return fn
out = cudagraphify_impl(
interp,
aot_inputs,
range(fixed),
device_index=get_device_index(aot_model),
is_backward=True,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
return out
aot_cudagraphs = aot_autograd(
fw_compiler=forward_cudagraphs,
bw_compiler=backward_cudagraphs,
inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
)
return aot_cudagraphs(dynamo_model, dynamo_inputs)
class CudagraphsBackend:
compiler_name = "cudagraphs"
@staticmethod
def reset():
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
@staticmethod
def __call__(model, inputs):
return cudagraphs(model, inputs)
# The cudagraphs backend only applies CUDA graphs to the captured graph. It is
# also helpful for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
"""This isn't registered as a backend, but is used in some benchmarks"""
assert isinstance(inputs, (list, tuple))
if copy_inputs:
static_inputs = [torch.zeros_like(x) for x in inputs]
else:
static_inputs = list(inputs)
# warmup
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
model(*inputs)
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()
# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(*static_inputs)
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)
def run(*new_inputs):
assert len(static_inputs) == len(new_inputs)
if copy_inputs:
for dst, src in zip(static_inputs, new_inputs):
dst.copy_(src)
graph.replay()
if copy_outputs:
return [x.clone() for x in static_outputs]
else:
return static_outputs
return run
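# A hedged usage sketch for the benchmark helper above (assumes a CUDA device and
# CUDA tensor inputs):
#
#     model = torch.nn.Linear(16, 16).cuda()
#     inputs = [torch.randn(8, 16, device="cuda")]
#     run = cudagraphs_inner(model, inputs)
#     out = run(*inputs)  # replays the recorded CUDA graph on copied-in inputs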


@@ -0,0 +1,336 @@
# mypy: ignore-errors
import dataclasses
import functools
import logging
from importlib import import_module
from typing import Any, List, Optional
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
log = logging.getLogger(__name__)
"""
This file contains TorchDynamo backends intended for debugging uses.
"""
@register_backend
def eager(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager backend ignoring extra kwargs %s", kwargs)
return gm.forward
@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
# This backend is intended to check that dynamo-generated GraphModules
# do not cause errors.
def inner(*args):
try:
return gm(*args)
except Exception as e:
raise torch._dynamo.exc.TorchDynamoException(
"Unexpected exception when running generated GraphModule"
) from e
return inner
@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
from torch.fx.experimental.proxy_tensor import make_fx
def runnable_gm(*args):
return torch.fx.Interpreter(gm).run(*args)
pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
pre_dispatch_gm.print_readable()
return pre_dispatch_gm
@register_backend
def eager_debug(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
from torch._subclasses.schema_check_mode import SchemaCheckMode
# We could add more debugging bits here.
# Right now, this backend can be used to check for and error on
# custom dispatcher ops that have incorrect schemas.
def inner(*args):
with SchemaCheckMode():
return torch.fx.Interpreter(gm).run(*args)
return inner
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
return torch.jit.script(gm)
# Use a boxed call so inputs can be discarded when they are no longer needed.
def boxed_nop(fx_g, example_inputs):
def run(args):
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
return run
# aot_eager uses the AOT Autograd backend with a nop compiler; it is useful for debugging.
aot_eager = aot_autograd(
fw_compiler=boxed_nop,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)
register_backend(name="aot_eager", compiler_fn=aot_eager)
aot_eager_default_partitioner = aot_autograd(
fw_compiler=boxed_nop, keep_inference_input_mutations=True
)
register_backend(
name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)
# aot_eager_decomp_partition uses TorchInductor's AOT Autograd decompositions and
# partitioner but replaces the Inductor compiler with a nop, which helps isolate
# aot_eager vs. Inductor problems.
def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning(
"aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
)
with functorch_config.patch(unlift_effect_tokens=True):
return aot_autograd(
# these are taken from memory_efficient_fusion()
fw_compiler=boxed_nop,
bw_compiler=boxed_nop,
# NB: lambda here is to delay import of inductor
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
),
)(gm, fake_tensor_inputs)
register_backend(
name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
# aot_ts uses AOT Autograd with the TorchScript backend and the default partitioner.
# It can be used with both nnc and nvfuser by selecting the relevant fuser with
# torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)
# These buggy backends are used for inducing bugs so that we can test
# our repro extraction / minifier scripts
class ReluCompileError(Exception):
pass
class TestingOnlyCompileError(Exception):
pass
@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
raise ReluCompileError
return gm
@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch._assert
node.args = (False, "ReluRuntimeError")
gm.recompile()
return gm
@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch.add
node.args = (node.args[0], 1)
gm.recompile()
return gm
@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
# Require at least one non-trivial thing in the graph,
# see https://github.com/pytorch/pytorch/issues/102898
for node in gm.graph.nodes:
if node.op == "call_function":
break
else:
return gm
for t in example_inputs:
if not t.is_leaf:
raise TestingOnlyCompileError
return gm
@dataclasses.dataclass
class ExplainOutput:
"""
This is the output of :func:`torch._dynamo.explain()`
There is no reason to create this class directly.
"""
graphs: List[torch.fx.GraphModule]
graph_count: int
graph_break_count: int
break_reasons: List[
Any
] # Type is GraphCompileReason but doesn't matter for this purpose
op_count: int
ops_per_graph: Optional[List[torch.fx.Node]] = None
out_guards: Optional[List[_guards.Guard]] = None
compile_times: Optional[str] = None
def __str__(self) -> str:
output = f"Graph Count: {self.graph_count}\n"
output += f"Graph Break Count: {self.graph_break_count}\n"
output += f"Op Count: {self.op_count}\n"
output += "Break Reasons:\n"
for idx, break_reason in enumerate(self.break_reasons):
output += f" Break Reason {idx+1}:\n"
output += f" Reason: {break_reason.reason}\n"
output += " User Stack:\n"
for frame_summary in break_reason.user_stack:
output += f" {frame_summary}\n"
if self.ops_per_graph is not None:
output += "Ops per Graph:\n"
for idx, ops in enumerate(self.ops_per_graph):
output += f" Ops {idx+1}:\n"
for op in ops:
output += f" {op}\n"
if self.out_guards is not None:
output += "Out Guards:\n"
for i, guard in enumerate(self.out_guards):
output += f" Guard {i+1}:\n"
output += f" {str(guard)}"
if self.compile_times is not None:
output += f"Compile Times: {self.compile_times}\n"
return output
def _explain_graph_detail(
gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
"""
This function is a utility which processes a torch.fx.GraphModule and
accumulates information about its ops, graph breaks, and other details. It
is intended to be used by the ExplainWithBackend class and
`torch._dynamo.explain()` to provide details from Dynamo's graph capture.
Parameters:
gm (torch.fx.GraphModule): The GraphModule to be processed.
graphs (list): A list that accumulates all the GraphModules processed.
op_count (int): The total count of operations in all GraphModules processed so far.
ops_per_graph (list): A list that accumulates the operations of each GraphModule.
break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.
Returns:
tuple: A tuple containing the processed GraphModule, the updated lists of graphs,
operations per graph, and break reasons, and the updated operation count.
"""
graphs.append(gm)
ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
op_count += len(ops)
ops_per_graph.append(ops)
if gm.compile_subgraph_reason.graph_break:
break_reasons.append(gm.compile_subgraph_reason)
return gm, graphs, op_count, ops_per_graph, break_reasons
class ExplainWithBackend:
"""
This class is intended to be used as a backend for `torch.compile`. It is
composable with other backends. When used in this way, it accumulates
information about graph breaks, ops, and other info and provides a string
representation summarizing this information.
Attributes:
backend (str): The name of the backend to use for optimization.
graphs (list): A list of the graphs captured by TorchDynamo.
op_count (int): The total number of operations in all optimized graphs.
break_reasons (list): A list of graph break reasons with stack traces.
Example Usage:
def fn(x):
x = torch.sigmoid(x)
return x
torch._dynamo.reset()
eb = ExplainWithBackend("inductor")
optimized_fn = torch.compile(fn, backend=eb)
result = optimized_fn(torch.randn(5))
print(eb.output())
"""
def __init__(self, backend) -> None:
from .registry import lookup_backend
self.backend = lookup_backend(backend)
self.graphs = []
self.op_count = 0
self.break_reasons = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
gm, self.graphs, self.op_count, [], self.break_reasons
)
return self.backend(gm, example_inputs)
def output(self) -> ExplainOutput:
graph_count = len(self.graphs)
output = ExplainOutput(
self.graphs,
graph_count,
graph_count - 1,
self.break_reasons,
self.op_count,
)
return output


@@ -0,0 +1,552 @@
# mypy: ignore-errors
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock
import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node
# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
def args_str(args):
# a debug helper
if torch.is_tensor(args):
return f"T[{args.shape}]"
elif isinstance(args, tuple):
return f"tuple({', '.join([args_str(x) for x in args])})"
elif isinstance(args, list):
return f"list({', '.join([args_str(x) for x in args])})"
else:
return str(args)
@dataclass
class Bucket:
size: int = 0
params: List[str] = field(default_factory=list)
nodes: List[fx.Node] = field(default_factory=list)
# param_ids is just used for unit testing
param_ids: List = field(default_factory=list)
# keep track of any buckets that were extended for logging purposes
opcount_increased_to_capture_external_output: int = 0
paramsize_before_opcount_increase: int = 0
def bucket_has_external_output(bucket: Bucket) -> bool:
nodes_in_bucket = set()
# we want to iterate in reverse order; conveniently, the bucket.nodes list was already
# created backwards, so we don't reverse it here
for node in bucket.nodes:
# assume node.op != output, since those are filtered in the original iteration
nodes_in_bucket.add(node)
for user in node.users:
if user not in nodes_in_bucket:
return True
return False
def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
headers = ("Index", "Size (b)", "Param Names")
rows = []
extended_buckets = []
for idx, bucket in enumerate(reversed(buckets)):
if len(bucket.params) > 0:
rows.append((idx, bucket.size, bucket.params[0]))
for param in bucket.params[1:]:
rows.append((None, None, param))
if bucket.opcount_increased_to_capture_external_output > 0:
extended_buckets.append(
(
idx,
bucket.opcount_increased_to_capture_external_output,
bucket.size - bucket.paramsize_before_opcount_increase,
)
)
if len(rows):
log.info(
"\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
bucket_bytes_cap,
len(buckets),
)
if len(extended_buckets):
log.warning(
"Some buckets were extended beyond their requested parameter capacities"
" in order to ensure each subgraph has an output node, required for fx graph partitioning."
" This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
" and returning no logical outputs. This should not be a problem, unless it results in too few graph"
" partitions for optimal DDP performance."
)
try:
from tabulate import tabulate
log.debug(
"\nDDPOptimizer produced the following bucket assignments:\n%s",
tabulate(rows, headers=headers, tablefmt="simple_grid"),
)
if len(extended_buckets):
log.warning(
"DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
tabulate(
extended_buckets,
headers=("Index", "Extra Ops", "Extra Param Size (b)"),
tablefmt="simple_grid",
),
)
except ImportError:
log.debug(
"Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
)
else:
log.debug("DDPOptimizer captured no parameters and did not split this graph.")
def has_higher_order_op(gm):
# Check if there is a higher order op in the graph
for node in gm.graph.nodes:
if node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if isinstance(maybe_param, torch.fx.GraphModule):
return True
return False
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler, fake_mode) -> None:
super().__init__(module)
self.compiler = compiler
self.fake_mode = fake_mode
def compile_submod(self, input_mod, args, kwargs):
"""
Compile the submodule, using a wrapper to make sure its output is always a
tuple, which is required by AOT Autograd-based compilers.
"""
assert len(kwargs) == 0, "We assume only args for these modules"
class WrapperModule(torch.nn.Module):
def __init__(self, submod, unwrap_singleton_tuple) -> None:
super().__init__()
self.submod = submod
self.unwrap_singleton_tuple = unwrap_singleton_tuple
def forward(self, *args):
x = self.submod(*args)
# TODO(whc)
# for some reason the isinstance check is necessary if I split one node per submod
# - even though I supposedly wrapped the output in a tuple in those cases, the real
# compiled module was still returning a tensor
if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
return x[0]
return x
unwrap_singleton_tuple = False
for sn in input_mod.graph.nodes:
if sn.op == "output":
if not isinstance(sn.args[0], tuple):
unwrap_singleton_tuple = True
sn.args = (sn.args,)
input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
[
# it's close to useless to get a real stacktrace here, and quite verbose.
traceback.FrameSummary(__file__, 0, DDPOptimizer),
],
)
wrapper = WrapperModule(
self.compiler(input_mod, args),
unwrap_singleton_tuple,
)
return wrapper
# Note:
#
# The way distributed works today around fake tensors can be somewhat confusing.
# Some of these codepaths are shared in both runtime, and compile time. The presence
# of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
#
# A few things to keep in mind:
#
# 1) We invoke `compile_submod` with a real module. The output of that gets stored
# on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
#
# 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
# module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
#
# 3) Fake tensors should always be around during compile time.
#
# 4) Fake tensors should never be around at runtime.
#
# 5) We end up with a compilation mode that takes a real submodule and fake tensors,
# to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
def run_node(self, n: Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
new_args = []
assert self.fake_mode
for arg in args:
if isinstance(arg, torch.Tensor) and not isinstance(
arg, torch._subclasses.FakeTensor
):
new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
else:
new_args.append(arg)
log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
if n.op == "call_module":
real_mod = self.fetch_attr(n.target)
if self.fake_mode:
curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
else:
curr_submod = real_mod
ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)
# When calling the compiler on the submod, inputs (new_args) are expected to
# be FakeTensors already since Dynamo would have made them FakeTensors in the
# non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
# since this wrapping happens during compilation
# Note: Returning Fake Tensors on First AOT Autograd Call
#
# Inductor will optimize strides of outputs when it deems it profitable.
# For instance, converting to channels last. When we split the graph here
# into multiple inductor compilations, we need to make sure that the
# output strides of one compilation is appropriately passed to the subsequent
# compilations. However, the mapping from inductor output to dynamo output
# is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
# subclass handling, etc. In order to replay all this logic we set a flag such that
# the first invocation of inductor in aot_autograd will return Fake Tensors with
# appropriate strides. Then, all of aot autograd's runtime logic is replayed.
# This gives us the appropriately strided outputs here which will reflect runtime strides.
class FakeifyFirstAOTInvocationGuard:
def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
def __del__(self) -> None:
self.tc.fakify_first_call = False
# For aot_eager and other backends, tracing context is not set
has_tracing_context = torch._guards.TracingContext.try_get() is not None
if has_tracing_context:
g = FakeifyFirstAOTInvocationGuard()
from torch._dynamo.utils import counters
init = counters["aot_autograd"]["total"]
compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)
# TODO - better way of doing this?
# Only aot autograd handles fakifying first call
invoked_aot_autograd = init != counters["aot_autograd"]["total"]
# We update the original (outer) graph with a call into the compiled module
# instead of the uncompiled one.
self.module.delete_submodule(n.target)
n.target = "compiled_" + n.target
self.module.add_submodule(n.target, compiled_submod_real)
# Finally, we have to produce inputs for use in compiling the next submodule,
# and these need to be FakeTensors, so we execute the module under fake_mode.
# Because parameters are not fake, we patch the fake tensor mode to allow non-fake inputs.
with self.fake_mode, mock.patch.object(
self.fake_mode, "allow_non_fake_inputs", True
):
if has_tracing_context and invoked_aot_autograd:
out = compiled_submod_real(*new_args, **kwargs)
# output should be fake or subclass
assert all(
(not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
for t in (out if isinstance(out, (list, tuple)) else [out])
)
return out
else:
return curr_submod(*new_args, **kwargs)
else:
# placeholder or output nodes don't need to get compiled, just executed
return getattr(self, n.op)(n.target, new_args, kwargs)
class DDPOptimizer:
"""Note [DDPOptimizer]
DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
the boundaries of gradient-allreduce buckets chosen by DDP.
Background/Motivation
- DDP uses allreduce collectives to synchronize partial gradients computed on different workers
- DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
- Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
at around the same time during backward and thus can share the same allreduce efficiently
- Allreduces must overlap with backward compute for optimal training performance
- DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
operates when individual grads become 'ready'
- Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
fused backward function executes, preventing any overlap of compute and communication
Algorithm
- DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
this graph in reverse order to determine the true order that gradients will become ready during backward.
- Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
and a graph break introduced
- Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
into an outer module that is returned to the user
Notes
- It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
in eager.
- If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
degradation approaching the baseline case where graph-splits are not used, but not worse.
- If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
subgraphs being compiled
- DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
- DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
DDPOptimizer)
Debugging
- Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
- In many cases, the log messages are helpful (they show bucket size assignments);
just set the TORCH_LOGS env var to include any of 'dynamo', 'distributed', or 'dist_ddp'.
- See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
in a single process (or with torchrun, in multiple processes)
Args:
bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
set to match the equivalent parameter on the original DDP module.
backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
"""
def __init__(
self,
bucket_bytes_cap: int,
backend_compile_fn,
first_bucket_cap: Optional[int] = None,
) -> None:
if first_bucket_cap is not None:
self.first_bucket_cap = first_bucket_cap
elif torch.distributed.is_available():
# this constant comes from C10D lib which is not always built
self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
else:
self.first_bucket_cap = bucket_bytes_cap
self.bucket_bytes_cap = bucket_bytes_cap
assert (
self.first_bucket_cap <= self.bucket_bytes_cap
), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
self.backend_compile_fn = backend_compile_fn
def _ignore_parameter(self, parameter):
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
def add_param(self, bucket, param, name):
bucket.size += param.untyped_storage().nbytes()
bucket.params.append(name)
bucket.param_ids.append(id(param))
def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
processed_modules.add(mod)
for name, param in mod.named_parameters():
if param.requires_grad and not self._ignore_parameter(param):
self.add_param(bucket, param, f"{prefix}_{name}")
def add_param_args(self, bucket, node):
for arg in node.args:
if not isinstance(arg, torch.fx.node.Node):
continue
if arg.op != "placeholder":
continue
param = arg.meta["example_value"]
if (
isinstance(param, torch.nn.Parameter)
and param.requires_grad
and not self._ignore_parameter(param)
):
self.add_param(bucket, param, arg.target)
def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
"""
Implements graph splitting, first determining a set of buckets by counting
parameter sizes in reverse graph order, then invoking the user/backend compiler
to compile each subgraph. Finally, stitches the compiled graphs into one
GraphModule and returns its callable.
"""
if has_higher_order_op(gm):
# This indicates presence of a higher order op. For now, we
# have no way to break the higher order op into two buckets.
# Allowing higher order ops in the graph also requires
# changes in split_module, because the graph splitter
# currently assumes that all the args of all ops are
# tensors, but in the case of higher order ops, an arg could be
# a graph module. As a workaround, we short-circuit with an error.
raise NotImplementedError(
"DDPOptimizer backend: Found a higher order op in the graph. "
"This is not supported. Please turn off DDP optimizer using "
"torch._dynamo.config.optimize_ddp=False. Note that this can "
"cause performance degradation because there will be one bucket "
"for the entire Dynamo graph. Please refer to this issue - "
"https://github.com/pytorch/pytorch/issues/104674."
)
# 1: compute the partition map according to DDP bucket logic
buckets = [Bucket()] # (size, param_names)
processed_modules = set()
for node in reversed(gm.graph.nodes):
if node.op in ("output", "placeholder"):
continue
if (
buckets[0].size >= self.bucket_bytes_cap
or len(buckets) == 1
and buckets[0].size >= self.first_bucket_cap
):
if bucket_has_external_output(buckets[0]):
buckets.insert(0, Bucket())
else:
# continue building this bucket past the point of filling its parameter capacity,
# to increase chances it contains at least one node that is either a global output or
# passed as input to a subsequent graph
if buckets[0].opcount_increased_to_capture_external_output == 0:
buckets[0].paramsize_before_opcount_increase = buckets[0].size
buckets[0].opcount_increased_to_capture_external_output += 1
if node.op == "call_function":
self.add_param_args(buckets[0], node)
elif node.op == "call_module":
target_mod = gm.get_submodule(node.target)
if target_mod not in processed_modules:
self.add_module_params_to_bucket(
target_mod, buckets[0], processed_modules, node.target
)
elif node.op == "call_method":
if isinstance(node.args[0].target, str):
target_mod = None
try:
target_mod = gm.get_submodule(node.args[0].target)
except AttributeError:
pass
if target_mod is not None and target_mod not in processed_modules:
self.add_module_params_to_bucket(
target_mod, buckets[0], processed_modules, node.target
)
# This handles situations like tmp = torch.mm(x, self.weight.t())
# t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
# tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
self.add_param_args(buckets[0], node)
elif node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if (
isinstance(maybe_param, torch.nn.Parameter)
and maybe_param.requires_grad
and not self._ignore_parameter(maybe_param)
):
self.add_param(buckets[0], maybe_param, node.target)
# All nodes have to be mapped to a bucket, even if they don't have their own params
# Ignored params still end up in buckets, we just don't count them towards the capacity
buckets[0].nodes.append(node)
if len(buckets) > 1 and buckets[0].size == 0:
# we collected a small preamble graph with ops that don't include parameters, fuse it back
buckets[1].nodes.extend(buckets[0].nodes)
assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
del buckets[0]
# stash buckets for testing/debugging purposes
self.buckets = buckets
pretty_print_buckets(buckets, self.bucket_bytes_cap)
if len(buckets) == 1:
# bypass split/fuse logic if there is only one bucket
return self.backend_compile_fn(gm, example_inputs)
# 2: partition the graphmodule according to bucket capacity
partition_map = {}
for idx, b in enumerate(buckets):
for node in b.nodes:
partition_map[node] = idx
split_gm = fx.passes.split_module.split_module(
gm, None, lambda node: partition_map[node]
)
debug_str = (
f"\n---orig graph---\n{gm.graph}\n"
+ f"\n---split graph---\n{split_gm.graph}\n"
)
for name, module in split_gm.named_modules():
if "." not in name and len(name):
# only print the submod graphs, not their children
debug_str += f"\n---{name} graph---\n{module.graph}\n"
debug_str += "\n---------------\n"
ddp_graph_log.debug(debug_str)
trace_structured(
"optimize_ddp_split_graph",
payload_fn=lambda: split_gm.print_readable(print_output=False),
)
for name, module in split_gm.named_modules():
if "." not in name and len(name):
trace_structured(
"optimize_ddp_split_child",
lambda: {"name": name},
payload_fn=lambda: module.print_readable(print_output=False),
)
fake_mode = detect_fake_mode(example_inputs)
if fake_mode is None:
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
submod_compiler.run(*example_inputs)
split_gm.recompile()
ddp_graph_log.debug(
"\n---final graph---\n%s\n---------------\n", split_gm.graph
)
return split_gm
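# A hedged usage sketch (see Note [DDPOptimizer] above): DDPOptimizer is normally
# constructed by Dynamo itself when torch._dynamo.config.optimize_ddp is enabled
# and the compiled module is wrapped in DistributedDataParallel, roughly:
#
#     ddp_model = torch.nn.parallel.DistributedDataParallel(model)
#     opt_model = torch.compile(ddp_model, backend="inductor")
#
# so that graph breaks land on DDP bucket boundaries and allreduce can overlap
# with the compiled backward.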


@@ -0,0 +1,12 @@
# mypy: ignore-errors
from torch._dynamo import register_backend
@register_backend
def inductor(*args, **kwargs):
# do import here to avoid loading inductor into memory when it is not used
from torch._inductor.compile_fx import compile_fx
return compile_fx(*args, **kwargs)
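# A hedged usage note: "inductor" is the default torch.compile backend; selecting
# it routes compilation to torch._inductor.compile_fx as above.
#
#     opt_fn = torch.compile(fn)                      # "inductor" is the default
#     opt_fn = torch.compile(fn, backend="inductor")  # explicit equivalent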


@@ -0,0 +1,38 @@
# mypy: ignore-errors
# This backend is maintained by ONNX team. To direct issues
# to the right people, please tag related GitHub issues with `module: onnx`.
#
# Maintainers' Github IDs: wschin, xadupre
from torch.onnx._internal.onnxruntime import (
is_onnxrt_backend_supported,
torch_compile_backend,
)
from .registry import register_backend
def has_onnxruntime():
# FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported()
return is_onnxrt_backend_supported()
if is_onnxrt_backend_supported():
register_backend(name="onnxrt", compiler_fn=torch_compile_backend)
else:
def information_displaying_backend(*args, **kwargs):
raise ImportError(
"onnxrt is not registered as a backend. "
"Please make sure all dependencies such as "
"numpy, onnx, onnxscript, and onnxruntime-training are installed. "
"Suggested procedure to fix dependency problem:\n"
" (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n"
" (2) Open a new python terminal.\n"
" (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n"
" (4) If it returns `True`, then you can use `onnxrt` backend.\n"
" (5) If it returns `False`, please execute the package importing section in "
"torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
)
register_backend(name="onnxrt", compiler_fn=information_displaying_backend)


@@ -0,0 +1,125 @@
# mypy: ignore-errors
import functools
import logging
import sys
from importlib.metadata import EntryPoint
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple
import torch
from torch import fx
log = logging.getLogger(__name__)
class CompiledFn(Protocol):
def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
...
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
_BACKENDS: Dict[str, Optional[EntryPoint]] = {}
_COMPILER_FNS: Dict[str, CompilerFn] = {}
def register_backend(
compiler_fn: Optional[CompilerFn] = None,
name: Optional[str] = None,
tags: Sequence[str] = (),
):
"""
Decorator to add a given compiler to the registry to allow calling
`torch.compile` with string shorthand. Note: for projects not
imported by default, it might be easier to pass a function directly
as a backend and not use a string.
Args:
compiler_fn: Callable taking a FX graph and fake tensor inputs
name: Optional name, defaults to `compiler_fn.__name__`
tags: Optional set of string tags to categorize backend with
"""
if compiler_fn is None:
# @register_backend(name="") syntax
return functools.partial(register_backend, name=name, tags=tags)
assert callable(compiler_fn)
name = name or compiler_fn.__name__
assert name not in _COMPILER_FNS, f"duplicate name: {name}"
if compiler_fn not in _BACKENDS:
_BACKENDS[name] = None
_COMPILER_FNS[name] = compiler_fn
compiler_fn._tags = tuple(tags)
return compiler_fn
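# A hedged usage sketch: registering a backend makes it addressable by string
# shorthand in torch.compile. The names below are hypothetical placeholders.
#
#     @register_backend
#     def my_compiler(gm: torch.fx.GraphModule, example_inputs):
#         return gm.forward
#
#     opt_fn = torch.compile(fn, backend="my_compiler")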
register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
register_backend, tags=("experimental",)
)
def lookup_backend(compiler_fn):
"""Expand backend strings to functions"""
if isinstance(compiler_fn, str):
if compiler_fn not in _BACKENDS:
_lazy_import()
if compiler_fn not in _BACKENDS:
from ..exc import InvalidBackend
raise InvalidBackend(name=compiler_fn)
if compiler_fn not in _COMPILER_FNS:
entry_point = _BACKENDS[compiler_fn]
register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
compiler_fn = _COMPILER_FNS[compiler_fn]
return compiler_fn
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
"""
Return valid strings that can be passed to:
torch.compile(..., backend="name")
"""
_lazy_import()
exclude_tags = set(exclude_tags or ())
backends = [
name
for name in _BACKENDS.keys()
if name not in _COMPILER_FNS
or not exclude_tags.intersection(_COMPILER_FNS[name]._tags)
]
return sorted(backends)
@functools.lru_cache(None)
def _lazy_import():
from .. import backends
from ..utils import import_submodule
import_submodule(backends)
from ..repro.after_dynamo import dynamo_minifier_backend
assert dynamo_minifier_backend is not None
_discover_entrypoint_backends()
@functools.lru_cache(None)
def _discover_entrypoint_backends():
# importing here so it will pick up the mocked version in test_backends.py
from importlib.metadata import entry_points
group_name = "torch_dynamo_backends"
if sys.version_info < (3, 10):
eps = entry_points()
eps = eps[group_name] if group_name in eps else []
eps = {ep.name: ep for ep in eps}
else:
eps = entry_points(group=group_name)
eps = {name: eps[name] for name in eps.names}
for backend_name in eps:
_BACKENDS[backend_name] = eps[backend_name]
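# A hedged packaging sketch: third-party backends can also be exposed through the
# "torch_dynamo_backends" entry-point group discovered above, e.g. in a package's
# pyproject.toml (names are hypothetical placeholders):
#
#     [project.entry-points."torch_dynamo_backends"]
#     my_compiler = "my_package.backend:my_compiler"
#
# After installation, lookup_backend("my_compiler") and
# torch.compile(..., backend="my_compiler") load it lazily via the EntryPoint.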


@@ -0,0 +1,14 @@
# mypy: ignore-errors
# import torch # type: ignore[import]
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
# from .registry import register_backend # type: ignore[import]
"""
Placeholder for TensorRT backend for dynamo via torch-tensorrt
"""
# @register_backend
# def tensorrt(gm, example_inputs):
# import torch_tensorrt # type: ignore[import]
# pass


@@ -0,0 +1,47 @@
# mypy: ignore-errors
import logging
from functorch.compile import make_boxed_func
from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend
log = logging.getLogger(__name__)
@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
return xla_backend_helper(model, fake_tensor_inputs, boxed=False)
def openxla_eval_boxed(model, fake_tensor_inputs):
return xla_backend_helper(model, fake_tensor_inputs, boxed=True)
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
try:
import torch_xla.core.dynamo_bridge as bridge
except ImportError as e:
raise ImportError(
"Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
) from e
compiled_graph = None
def fwd(*args):
nonlocal model
nonlocal compiled_graph
if compiled_graph is None:
compiled_graph = bridge.extract_compiled_graph(model, args)
del model
return compiled_graph(*args)
return make_boxed_func(fwd) if boxed else fwd
openxla = aot_autograd(
fw_compiler=openxla_eval_boxed,
)
register_backend(name="openxla", compiler_fn=openxla)


@@ -0,0 +1,194 @@
# mypy: ignore-errors
import functools
import importlib
import logging
import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional
import torch
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend
log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def tvm(
gm,
example_inputs,
*,
options: Optional[MappingProxyType] = MappingProxyType(
{"scheduler": None, "trials": 20000, "opt_level": 3}
),
):
import tvm # type: ignore[import]
from tvm import relay # type: ignore[import]
from tvm.contrib import graph_executor # type: ignore[import]
jit_mod = torch.jit.trace(gm, example_inputs)
device = device_from_inputs(example_inputs)
shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
example_outputs = gm(*example_inputs)
if len(example_outputs) == 0:
log.warning("Explicitly fall back to eager due to zero output")
return gm.forward
mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
if device.type == "cuda":
dev = tvm.cuda(device.index)
target = tvm.target.cuda()
else:
dev = tvm.cpu(0)
target = tvm.target.Target(llvm_target())
scheduler = options.get("scheduler", None)
if scheduler is None:
scheduler = os.environ.get("TVM_SCHEDULER", None)
trials = options.get("trials", 20000)
opt_level = options.get("opt_level", 3)
if scheduler == "auto_scheduler":
from tvm import auto_scheduler
log_file = tempfile.NamedTemporaryFile()
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
for task in tasks:
print(task.compute_dag)
else:
print("No tasks")
if len(tasks) != 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
if not os.path.exists(log_file):
assert trials > 0
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=trials,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
early_stopping=2000,
)
try:
tuner.tune(tune_option)
except Exception:
if os.path.exists(log_file):
os.unlink(log_file)
raise
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
):
lib = relay.build(mod, target=target, params=params)
elif scheduler == "meta_schedule":
from tvm import meta_schedule as ms
with tempfile.TemporaryDirectory() as work_dir:
if device.type != "cuda":
# meta_schedule needs num-cores to be specified
# here we use the maximum core count
target = tvm.target.Target(
f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
)
# TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
# once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
assert trials > 0
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
work_dir=work_dir,
max_trials_global=trials,
num_trials_per_iter=64,
params=params,
strategy="evolutionary",
opt_level=opt_level,
)
lib = ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
opt_level=opt_level,
)
elif scheduler == "default" or not scheduler:
# no autotuning
with tvm.transform.PassContext(opt_level=opt_level):
lib = relay.build(mod, target=target, params=params)
else:
raise NotImplementedError(
"This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
"There are three available options: default, auto_scheduler and meta_schedule."
)
m = graph_executor.GraphModule(lib["default"](dev))
def to_torch_tensor(nd_tensor):
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support boolean tensors, so they can't be handled by
# torch.utils.dlpack.from_dlpack. Work around this by going through
# numpy, although this brings additional data copy overhead.
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def to_tvm_tensor(torch_tensor):
"""A helper function to transfer a torch.tensor to NDArray."""
if torch_tensor.dtype == torch.bool:
# same reason as above: fall back to numpy conversion, which
# could introduce data copy overhead
return tvm.nd.array(torch_tensor.cpu().numpy())
return tvm.nd.from_dlpack(torch_tensor)
def exec_tvm(*i_args):
args = [a.contiguous() for a in i_args]
shape_info, _ = m.get_input_info()
active_inputs = {name for name, _ in shape_info.items()}
for idx, arg in enumerate(args, 0):
if arg.dim() != 0:
if arg.requires_grad:
arg = arg.detach()
inp_name = f"inp_{idx}"
if inp_name not in active_inputs:
log.warning(
"input %s skipped as not found in tvm's runtime library",
inp_name,
)
continue
m.set_input(
inp_name,
to_tvm_tensor(arg),
)
m.run()
return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]
return exec_tvm
tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
def has_tvm():
try:
importlib.import_module("tvm")
return True
except ImportError:
return False
@functools.lru_cache(None)
def llvm_target():
if sys.platform == "linux":
cpuinfo = open("/proc/cpuinfo").read()
if "avx512" in cpuinfo:
return "llvm -mcpu=skylake-avx512"
elif "avx2" in cpuinfo:
return "llvm -mcpu=core-avx2"
return "llvm"