I am done

commit 40e2a747cf
parent 720dc28c09
Date: 2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,109 @@
import torch
from . import convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
allow_in_graph,
assume_constant_result,
disable,
disallow_in_graph,
forbid_in_graph,
graph_break,
mark_dynamic,
mark_static,
mark_static_address,
maybe_mark_dynamic,
run,
substitute_in_graph,
)
from .eval_frame import (
_reset_guarded_backend_cache,
explain,
export,
is_dynamo_supported,
is_inductor_supported,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
# Register polyfill functions
from .polyfills import loader as _ # usort: skip # noqa: F401
__all__ = [
"allow_in_graph",
"assume_constant_result",
"disallow_in_graph",
"forbid_in_graph",
"substitute_in_graph",
"graph_break",
"mark_dynamic",
"maybe_mark_dynamic",
"mark_static",
"mark_static_address",
"optimize",
"optimize_assert",
"export",
"explain",
"run",
"replay",
"disable",
"reset",
"OptimizedModule",
"is_compiling",
"register_backend",
"list_backends",
"lookup_backend",
]
if torch.manual_seed is torch.random.manual_seed:
import torch.jit._builtins
# Wrap manual_seed with the disable decorator.
# This can't be done at its definition due to dependency issues.
torch.manual_seed = torch._disable_dynamo(torch.manual_seed)
# Add the new manual_seed to the builtin registry.
torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
def reset() -> None:
"""Clear all compile caches and restore initial state"""
with convert_frame.compile_lock:
reset_code_caches()
convert_frame.input_codes.clear()
convert_frame.output_codes.clear()
orig_code_map.clear()
guard_failures.clear()
graph_break_reasons.clear()
resume_execution.ContinueExecutionCache.cache.clear()
_reset_guarded_backend_cache()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()
convert_frame.FRAME_COUNTER = 0
convert_frame.FRAME_COMPILE_COUNTER.clear()
callback_handler.clear()
GenerationTracker.clear()
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)
def reset_code_caches() -> None:
"""Clear compile caches that are keyed by code objects"""
with convert_frame.compile_lock:
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):
code = weak_code()
if code:
reset_code(code)
code_context.clear()
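# Illustrative sketch of when reset() is typically called (the function `fn` below
# is hypothetical): clearing caches between independent compilation runs so that
# previously compiled code and guard state do not carry over.
#
#     compiled = torch.compile(fn, backend="eager")
#     compiled(torch.randn(4))
#     torch._dynamo.reset()          # drop all compile caches and guard state
#     compiled = torch.compile(fn, backend="inductor")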

@@ -0,0 +1,127 @@
# mypy: allow-untyped-defs
import torch
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only
__all__ = ["trace_wrapped"]
# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist:
# if you make_fx trace through this call, we will not actually trace into fn; instead,
# we will directly insert it as a call_function to fn in the graph.
# (Unlike make_fx, Dynamo WILL inline into fn.)
# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing.
#
# Because proxy tensor tracing does not actually run the function, there are
# requirements on the behavior of fn. We are still figuring it out, but here is the current state:
#
# 1) fn SHOULD only take a single argument, which must be a tensor
# 2) fn MUST return a new tensor with the same metadata as the original tensor
# (e.g., zeros_like(input) is a permissible implementation of fn).
# This is verified via an extra assert that is inserted into the traced graph.
# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors
# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)
# These requirements stem from the requirement that we need to continue performing proxy tensor tracing,
# which assumes accurate fake tensor metadata, without actually running fn.
# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.
#
# Note that tensors / Python state are allowed to be mutated.
# This relaxed constraint is not always sound, but it is sound for backward tracing with fake
# tensors as it takes place in AOTAutograd, since the backward pass is guaranteed not to depend on concrete
# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).
#
# The intended use case for this function is to allow AOTAutograd to defer complex
# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves
# the function call as is in the graph, and only when we Dynamo through the backward graph in
# compiled autograd do we inline into the function.
def trace_wrapped(*args, **kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **kwargs)
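# A minimal illustrative sketch (the hook `my_hook` and the input shape are
# hypothetical): under make_fx, `my_hook` shows up as a single call_function node
# instead of being traced into.
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def my_hook(grad):
#         # must preserve the metadata of the input, per rule (2) above
#         return torch.zeros_like(grad)
#
#     def f(x):
#         return trace_wrapped(x, fn=my_hook)
#
#     gm = make_fx(f)(torch.randn(3))
#     # gm.graph contains one call_function node wrapping my_hook rather than
#     # the ops my_hook would have produced if traced through.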
class TraceWrapped(HigherOrderOperator):
def __init__(self):
super().__init__("trace_wrapped")
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
# TODO(jansel): need to ensure this does not get DCEed
_trace_wrapped_op = TraceWrapped()
def _assert_meta(grad, size, stride, dtype):
assert grad.size() == size, "size mismatch"
assert grad.stride() == stride, "stride mismatch"
assert grad.dtype == dtype, "dtype mismatch"
return grad
@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
def inner_trace(mode, *args, bw_state=None, **kwargs):
def self_invoke(*args, **dyn_kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs)
def unwrap_proxies(x):
if isinstance(x, torch.Tensor):
return mode.tracer.unwrap_proxy(x)
if isinstance(x, (list, tuple)):
return type(x)(map(unwrap_proxies, x))
if x is None:
return None
raise AssertionError(f"unhandled type: {type(x)}")
proxy_kwargs = {}
if bw_state is not None:
assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None
proxy_kwargs["bw_state"] = bw_state.proxy
out_proxy = mode.tracer.create_proxy(
"call_function",
self_invoke,
unwrap_proxies(args),
proxy_kwargs,
name="trace_wrapped",
)
if args[0] is None:
grad = args[1] # module backward hooks
else:
grad = args[0] # other backward hooks
grad = tree_map_only(torch.Tensor, torch.empty_like, grad)
track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
return grad
@_trace_wrapped_op.py_impl(FakeTensorMode)
def inner_fake(*args, **kwargs):
raise RuntimeError("This op should never be invoked here")
@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def _trace_wrapped_op_dense(*args, fn, **kwargs):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return fn(*args, **kwargs)
_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(_trace_wrapped_op, deferred_error=True)
)
@_trace_wrapped_op.py_functionalize_impl
def _trace_wrapped_functionalized(ctx, *args, **kwargs):
unwrapped_args = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs))

@@ -0,0 +1,128 @@
# mypy: ignore-errors
import contextlib
import functools
import logging
from unittest.mock import patch
import torch
from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
class AotAutograd:
def __init__(self, **kwargs) -> None:
self.__name__ = "compiler_fn"
self.kwargs = kwargs
def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
if kwargs:
log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
return flatten_graph_inputs(
gm,
example_inputs,
self,
)
# Hack to get around circular import problems with aot_eager_decomp_partition
if callable(self.kwargs.get("decompositions")):
self.kwargs["decompositions"] = self.kwargs["decompositions"]()
# NB: don't delete this counter increment
counters["aot_autograd"]["total"] += 1
use_fallback = False
if use_fallback:
log.debug("Unable to use AOT Autograd because graph has mutation")
counters["aot_autograd"]["not_ok"] += 1
return gm
# OK attempt to compile
def _wrapped_bw_compiler(*args, **kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return disable(disable(bw_compiler)(*args, **kwargs))
bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
self.kwargs["bw_compiler"] = _wrapped_bw_compiler
self.kwargs["inference_compiler"] = (
self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
)
from functorch.compile import nop
from torch._inductor.debug import enable_aot_logging
# Debug asserts slow down compile time noticeably,
# so only default them on when the aot_eager backend is used.
if self.kwargs.get("fw_compiler", None) == nop:
patch_config = patch("functorch.compile.config.debug_assert", True)
else:
patch_config = contextlib.nullcontext()
try:
# NB: NOT cloned!
with enable_aot_logging(), patch_config:
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
counters["aot_autograd"]["ok"] += 1
return disable(cg)
except Exception:
counters["aot_autograd"]["not_ok"] += 1
raise
def aot_autograd(**kwargs):
return AotAutograd(**kwargs)
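# Illustrative sketch of building a backend from this factory (the compiler and
# `fn` below are hypothetical; make_boxed_func adapts a plain callable to the
# boxed calling convention expected by AOT Autograd):
#
#     from functorch.compile import make_boxed_func
#
#     def my_fw_compiler(gm, example_inputs):
#         return make_boxed_func(gm.forward)   # run the FX graph eagerly
#
#     my_backend = aot_autograd(fw_compiler=my_fw_compiler)
#     compiled_fn = torch.compile(fn, backend=my_backend)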
def mem_efficient_fusion_kwargs(use_decomps):
from functorch.compile import (
default_decompositions,
min_cut_rematerialization_partition,
ts_compile,
)
kwargs = {
# these are taken from memory_efficient_fusion()
"fw_compiler": ts_compile,
"bw_compiler": ts_compile,
"partition_fn": min_cut_rematerialization_partition,
}
if use_decomps:
kwargs["decompositions"] = default_decompositions
return kwargs
def fake_tensor_unsupported(fn):
"""
Decorator for backends that need real inputs. We swap out fake
tensors for zero tensors.
"""
@functools.wraps(fn)
def wrapper(model, inputs, **kwargs):
with _disable_current_modes():
inputs = list(map(defake, inputs))
return fn(model, inputs, **kwargs)
return wrapper
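# Illustrative sketch of using this decorator for a backend that cannot handle
# fake tensors (the backend name is hypothetical); the tvm backend in this
# commit follows the same pattern:
#
#     from .registry import register_backend
#
#     @register_backend
#     @fake_tensor_unsupported
#     def my_real_input_backend(gm, example_inputs):
#         gm(*example_inputs)   # example_inputs arrive as zero-filled real tensors
#         return gm.forward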
def device_from_inputs(example_inputs) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device
def dtype_from_inputs(example_inputs) -> torch.dtype:
for x in example_inputs:
if hasattr(x, "dtype"):
return x.dtype

@@ -0,0 +1,256 @@
# mypy: ignore-errors
import functools
from collections import defaultdict
from typing import Dict, List, Optional
import torch
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
check_multiple_devices_or_any_cpu_nodes,
format_default_skip_message,
get_mutation_stack_trace,
get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
BoxedBool,
count_tangents,
get_first_incompatible_cudagraph_node,
num_fw_fixed_arguments,
output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef
from .registry import register_backend
def find_input_mutations(g):
def meta_fk(meta):
return meta["val"] if "val" in meta else meta["fake_result"]
inputs = defaultdict(set)
input_idx = 0
mutated_inputs = set()
for n in g.nodes:
if n.op == "placeholder":
if isinstance(meta_fk(n.meta), torch.Tensor):
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == "call_function":
if not hasattr(n.target, "_schema"):
continue
schema = n.target._schema
for i, arg in enumerate(schema.arguments):
if i < len(n.args):
argument = n.args[i]
else:
if arg.name not in n.kwargs:
continue
argument = n.kwargs[arg.name]
mut_arg = False
if arg.alias_info:
if arg.alias_info.is_write:
mut_arg = True
if mut_arg:
# TODO: not correct for args that contain tensors in a struct
# like list
mutated_inputs |= inputs[
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
]
# TODO: error on unrecognized nodes
return mutated_inputs
def get_device_node_mapping(gm: torch.fx.GraphModule):
device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
for n in gm.graph.nodes:
t = n.meta.get("val", None)
if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
device_node_mapping[t.device] = n
return device_node_mapping
def check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model: torch.fx.GraphModule, num_fixed
) -> Optional[str]:
mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
if not mutation_indices:
return None
placeholders = get_placeholder_info(aot_model.graph)
return get_mutation_stack_trace(placeholders, mutation_indices)
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
if not config.cudagraph_backend_support_input_mutation:
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model, num_fixed
):
return mut_skip
if skip := check_multiple_devices_or_any_cpu_nodes(
get_device_node_mapping(aot_model)
):
return skip
if node := get_first_incompatible_cudagraph_node(aot_model):
return format_default_skip_message(f"incompatible op ({node.name})")
return None
def get_device_index(gm) -> int:
device = next(iter(get_device_node_mapping(gm)))
assert device.type == "cuda"
return device.index
def get_stack_traces(gm) -> List[Optional[str]]:
output = output_node(gm)
assert len(output.args) == 1
return [
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
for arg in output.args[0]
]
def cudagraphs(dynamo_model, dynamo_inputs):
from torch._inductor.cudagraph_trees import cudagraphify_impl
do_cudagraphs = BoxedBool(True)
boxed_device_index = BoxedDeviceIndex(None)
def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
interp = boxed_nop(aot_model, aot_inputs)
fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
if skip_msg := check_for_skip(aot_model, fixed):
BoxedBool.disable(do_cudagraphs)
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {skip_msg}"
)
return interp
boxed_device_index.set(get_device_index(aot_model))
out = cudagraphify_impl(
interp,
aot_inputs,
range(fixed),
device_index=boxed_device_index.value,
is_backward=False,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
return out
def backward_cudagraphs(aot_model, aot_inputs):
interp = boxed_nop(aot_model, aot_inputs)
if not do_cudagraphs:
return aot_model
fixed = count_tangents(aot_model)
if skip_msg := check_for_skip(aot_model, fixed):
log_cudagraph_skip_and_bump_counter(
"skipping cudagraphs due to %s", skip_msg
)
# See [Backward Generation Handling]
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_device_index.value, create_if_none_exists=False
)
assert manager is not None
def fn(inputs):
manager.set_to_running_backward()
return aot_model(inputs)
fn._boxed_call = True
return fn
out = cudagraphify_impl(
interp,
aot_inputs,
range(fixed),
device_index=get_device_index(aot_model),
is_backward=True,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
return out
aot_cudagraphs = aot_autograd(
fw_compiler=forward_cudagraphs,
bw_compiler=backward_cudagraphs,
inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
)
return aot_cudagraphs(dynamo_model, dynamo_inputs)
class CudagraphsBackend:
compiler_name = "cudagraphs"
@staticmethod
def reset():
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
@staticmethod
def __call__(model, inputs):
return cudagraphs(model, inputs)
# The cudagraphs backend applies CUDA graphs on top of AOT Autograd. It is also
# helpful for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
"""This isn't registered as a backend, but is used in some benchmarks"""
assert isinstance(inputs, (list, tuple))
if copy_inputs:
static_inputs = [torch.zeros_like(x) for x in inputs]
else:
static_inputs = list(inputs)
# warmup
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
model(*inputs)
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()
# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(*static_inputs)
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)
def run(*new_inputs):
assert len(static_inputs) == len(new_inputs)
if copy_inputs:
for dst, src in zip(static_inputs, new_inputs):
dst.copy_(src)
graph.replay()
if copy_outputs:
return [x.clone() for x in static_outputs]
else:
return static_outputs
return run
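# Illustrative sketch of how a benchmark might use cudagraphs_inner (the model and
# shapes are hypothetical; requires a CUDA device):
#
#     model = torch.nn.Linear(8, 8).cuda()
#     example = [torch.randn(4, 8, device="cuda")]
#     run = cudagraphs_inner(model, example)
#     out, = run(torch.randn(4, 8, device="cuda"))   # copies inputs in, replays the graph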

@@ -0,0 +1,336 @@
# mypy: ignore-errors
import dataclasses
import functools
import logging
from importlib import import_module
from typing import Any, List, Optional
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
log = logging.getLogger(__name__)
"""
This file contains TorchDynamo backends intended for debugging uses.
"""
@register_backend
def eager(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager backend ignoring extra kwargs %s", kwargs)
return gm.forward
@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
# This backend is intended to check that dynamo-generated GraphModules
# do not cause errors.
def inner(*args):
try:
return gm(*args)
except Exception as e:
raise torch._dynamo.exc.TorchDynamoException(
"Unexpected exception when running generated GraphModule"
) from e
return inner
@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
from torch.fx.experimental.proxy_tensor import make_fx
def runnable_gm(*args):
return torch.fx.Interpreter(gm).run(*args)
pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
pre_dispatch_gm.print_readable()
return pre_dispatch_gm
@register_backend
def eager_debug(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
from torch._subclasses.schema_check_mode import SchemaCheckMode
# We could add more debugging bits here.
# Right now, this backend can be used to check for and error on
# custom dispatcher ops that have incorrect schemas.
def inner(*args):
with SchemaCheckMode():
return torch.fx.Interpreter(gm).run(*args)
return inner
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
return torch.jit.script(gm)
# Uses the boxed calling convention so inputs can be discarded as soon as they are no longer needed.
def boxed_nop(fx_g, example_inputs):
def run(args):
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
return run
# aot_eager uses the AOT Autograd backend with a nop compiler; it is useful for debugging.
aot_eager = aot_autograd(
fw_compiler=boxed_nop,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)
register_backend(name="aot_eager", compiler_fn=aot_eager)
aot_eager_default_partitioner = aot_autograd(
fw_compiler=boxed_nop, keep_inference_input_mutations=True
)
register_backend(
name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)
# aot_eager_decomp_partition uses TorchInductor's decompositions and partitioner
# but replaces the Inductor compiler with a nop, which helps isolate AOT Autograd
# problems from Inductor problems.
def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
if kwargs:
log.warning(
"aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
)
with functorch_config.patch(unlift_effect_tokens=True):
return aot_autograd(
# these are taken from memory_efficient_fusion()
fw_compiler=boxed_nop,
bw_compiler=boxed_nop,
# NB: lambda here is to delay import of inductor
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
),
)(gm, fake_tensor_inputs)
register_backend(
name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
# AOT Autograd with torchscript backend. Default partitioner.
# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
# by using the relevant fuser with torch.jit.fuser(...)
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)
# These buggy backends are used for inducing bugs so that we can test
# our repro extraction / minifier scripts
class ReluCompileError(Exception):
pass
class TestingOnlyCompileError(Exception):
pass
@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
raise ReluCompileError
return gm
@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch._assert
node.args = (False, "ReluRuntimeError")
gm.recompile()
return gm
@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch.add
node.args = (node.args[0], 1)
gm.recompile()
return gm
@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
# Require at least one non-trivial thing in the graph,
# see https://github.com/pytorch/pytorch/issues/102898
for node in gm.graph.nodes:
if node.op == "call_function":
break
else:
return gm
for t in example_inputs:
if not t.is_leaf:
raise TestingOnlyCompileError
return gm
@dataclasses.dataclass
class ExplainOutput:
"""
This is the output of :func:`torch._dynamo.explain()`
There is no reason to create this class directly.
"""
graphs: List[torch.fx.GraphModule]
graph_count: int
graph_break_count: int
break_reasons: List[
Any
] # Type is GraphCompileReason but doesn't matter for this purpose
op_count: int
ops_per_graph: Optional[List[torch.fx.Node]] = None
out_guards: Optional[List[_guards.Guard]] = None
compile_times: Optional[str] = None
def __str__(self) -> str:
output = f"Graph Count: {self.graph_count}\n"
output += f"Graph Break Count: {self.graph_break_count}\n"
output += f"Op Count: {self.op_count}\n"
output += "Break Reasons:\n"
for idx, break_reason in enumerate(self.break_reasons):
output += f" Break Reason {idx+1}:\n"
output += f" Reason: {break_reason.reason}\n"
output += " User Stack:\n"
for frame_summary in break_reason.user_stack:
output += f" {frame_summary}\n"
if self.ops_per_graph is not None:
output += "Ops per Graph:\n"
for idx, ops in enumerate(self.ops_per_graph):
output += f" Ops {idx+1}:\n"
for op in ops:
output += f" {op}\n"
if self.out_guards is not None:
output += "Out Guards:\n"
for i, guard in enumerate(self.out_guards):
output += f" Guard {i+1}:\n"
output += f" {str(guard)}"
if self.compile_times is not None:
output += f"Compile Times: {self.compile_times}\n"
return output
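# Illustrative sketch of how an ExplainOutput is typically obtained (`fn` and the
# example input are hypothetical); the printed summary is produced by __str__ above:
#
#     explanation = torch._dynamo.explain(fn)(torch.randn(10))
#     print(explanation)
#     print(explanation.graph_break_count)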
def _explain_graph_detail(
gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
"""
This function is a utility which processes a torch.fx.GraphModule and
accumulates information about its ops, graph breaks, and other details. It
is intended to be used by the ExplainWithBackend class and
`torch._dynamo.explain()` to provide details from Dynamo's graph capture.
Parameters:
gm (torch.fx.GraphModule): The GraphModule to be processed.
graphs (list): A list that accumulates all the GraphModules processed.
op_count (int): The total count of operations in all GraphModules processed so far.
ops_per_graph (list): A list that accumulates the operations of each GraphModule.
break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.
Returns:
tuple: A tuple containing the processed GraphModule, the updated lists of graphs,
operations per graph, and break reasons, and the updated operation count.
"""
graphs.append(gm)
ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
op_count += len(ops)
ops_per_graph.append(ops)
if gm.compile_subgraph_reason.graph_break:
break_reasons.append(gm.compile_subgraph_reason)
return gm, graphs, op_count, ops_per_graph, break_reasons
class ExplainWithBackend:
"""
This class is intended to be used as a backend for `torch.compile`. It is
composable with other backends. When used in this way, it accumulates
information about graph breaks, ops, and other info and provides a string
representation summarizing this information.
Attributes:
backend (str): The name of the backend to use for optimization.
graphs (list): A list of the graphs captured by TorchDynamo.
op_count (int): The total number of operations in all optimized graphs.
break_reasons (list): A list of graph break reasons with stack traces.
Example Usage:
def fn(x):
x = torch.sigmoid(x)
return x
torch._dynamo.reset()
eb = ExplainWithBackend("inductor")
optimized_fn = torch.compile(fn, backend=eb)
result = optimized_fn(torch.randn(5))
print(eb.output())
"""
def __init__(self, backend) -> None:
from .registry import lookup_backend
self.backend = lookup_backend(backend)
self.graphs = []
self.op_count = 0
self.break_reasons = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
gm, self.graphs, self.op_count, [], self.break_reasons
)
return self.backend(gm, example_inputs)
def output(self) -> ExplainOutput:
graph_count = len(self.graphs)
output = ExplainOutput(
self.graphs,
graph_count,
graph_count - 1,
self.break_reasons,
self.op_count,
)
return output

@@ -0,0 +1,552 @@
# mypy: ignore-errors
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock
import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node
# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
def args_str(args):
# a debug helper
if torch.is_tensor(args):
return f"T[{args.shape}]"
elif isinstance(args, tuple):
return f"tuple({', '.join([args_str(x) for x in args])})"
elif isinstance(args, list):
return f"list({', '.join([args_str(x) for x in args])})"
else:
return str(args)
@dataclass
class Bucket:
size: int = 0
params: List[str] = field(default_factory=list)
nodes: List[fx.Node] = field(default_factory=list)
# param_ids is just used for unit testing
param_ids: List = field(default_factory=list)
# keep track of any buckets that were extended for logging purposes
opcount_increased_to_capture_external_output: int = 0
paramsize_before_opcount_increase: int = 0
def bucket_has_external_output(bucket: Bucket) -> bool:
nodes_in_bucket = set()
# we want to iterate in reverse order, and conveniently the bucket.nodes list was already created backwards,
# so we don't reverse it here
for node in bucket.nodes:
# assume node.op != output, since those are filtered in the original iteration
nodes_in_bucket.add(node)
for user in node.users:
if user not in nodes_in_bucket:
return True
return False
def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
headers = ("Index", "Size (b)", "Param Names")
rows = []
extended_buckets = []
for idx, bucket in enumerate(reversed(buckets)):
if len(bucket.params) > 0:
rows.append((idx, bucket.size, bucket.params[0]))
for param in bucket.params[1:]:
rows.append((None, None, param))
if bucket.opcount_increased_to_capture_external_output > 0:
extended_buckets.append(
(
idx,
bucket.opcount_increased_to_capture_external_output,
bucket.size - bucket.paramsize_before_opcount_increase,
)
)
if len(rows):
log.info(
"\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
bucket_bytes_cap,
len(buckets),
)
if len(extended_buckets):
log.warning(
"Some buckets were extended beyond their requested parameter capacities"
" in order to ensure each subgraph has an output node, required for fx graph partitioning."
" This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
" and returning no logical outputs. This should not be a problem, unless it results in too few graph"
" partitions for optimal DDP performance."
)
try:
from tabulate import tabulate
log.debug(
"\nDDPOptimizer produced the following bucket assignments:\n%s",
tabulate(rows, headers=headers, tablefmt="simple_grid"),
)
if len(extended_buckets):
log.warning(
"DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
tabulate(
extended_buckets,
headers=("Index", "Extra Ops", "Extra Param Size (b)"),
tablefmt="simple_grid",
),
)
except ImportError:
log.debug(
"Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
)
else:
log.debug("DDPOptimizer captured no parameters and did not split this graph.")
def has_higher_order_op(gm):
# Check if there is a higher order op in the graph
for node in gm.graph.nodes:
if node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if isinstance(maybe_param, torch.fx.GraphModule):
return True
return False
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler, fake_mode) -> None:
super().__init__(module)
self.compiler = compiler
self.fake_mode = fake_mode
def compile_submod(self, input_mod, args, kwargs):
"""
Compile the submodule,
using a wrapper to make sure its output is always a tuple,
which is required by AotAutograd based compilers
"""
assert len(kwargs) == 0, "We assume only args for these modules"
class WrapperModule(torch.nn.Module):
def __init__(self, submod, unwrap_singleton_tuple) -> None:
super().__init__()
self.submod = submod
self.unwrap_singleton_tuple = unwrap_singleton_tuple
def forward(self, *args):
x = self.submod(*args)
# TODO(whc)
# for some reason the isinstance check is necessary if I split one node per submod
# - even though I supposedly wrapped the output in a tuple in those cases, the real
# compiled module was still returning a tensor
if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
return x[0]
return x
unwrap_singleton_tuple = False
for sn in input_mod.graph.nodes:
if sn.op == "output":
if not isinstance(sn.args[0], tuple):
unwrap_singleton_tuple = True
sn.args = (sn.args,)
input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
[
# it's close to useless to get a real stacktrace here, and quite verbose.
traceback.FrameSummary(__file__, 0, DDPOptimizer),
],
)
wrapper = WrapperModule(
self.compiler(input_mod, args),
unwrap_singleton_tuple,
)
return wrapper
# Note:
#
# The way distributed works today around fake tensors can be somewhat confusing.
# Some of these codepaths are shared in both runtime, and compile time. The presence
# of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
#
# A few things to keep in mind:
#
# 1) We invoke `compile_submod` with a real module. The output of that gets stored
# on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
#
# 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
# module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
#
# 3) Fake tensors should always be around during compile time.
#
# 4) Fake tensors should never be around at runtime.
#
# 5) We end up with a compilation mode that takes a real submodule and fake tensors,
# to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
def run_node(self, n: Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
new_args = []
assert self.fake_mode
for arg in args:
if isinstance(arg, torch.Tensor) and not isinstance(
arg, torch._subclasses.FakeTensor
):
new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
else:
new_args.append(arg)
log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
if n.op == "call_module":
real_mod = self.fetch_attr(n.target)
if self.fake_mode:
curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
else:
curr_submod = real_mod
ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)
# When calling the compiler on the submod, inputs (new_args) are expected to
# be FakeTensors already since Dynamo would have made them FakeTensors in the
# non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
# since this wrapping happens during compilation
# Note: Returning Fake Tensors on First AOT Autograd Call
#
# Inductor will optimize strides of outputs when it deems it profitable.
# For instance, converting to channels last. When we split the graph here
# into multiple inductor compilations, we need to make sure that the
# output strides of one compilation is appropriately passed to the subsequent
# compilations. However, the mapping from inductor output to dynamo output
# is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
# subclass handling, etc. In order to replay all this logic we set a flag such that
# the first invocation of inductor in aot_autograd will return Fake Tensors with
# appropriate strides. Then, all of aot autograd's runtime logic is replayed.
# This gives us the appropriately strided outputs here which will reflect runtime strides.
class FakeifyFirstAOTInvocationGuard:
def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
def __del__(self) -> None:
self.tc.fakify_first_call = False
# For aot_eager and other backends, tracing context is not set
has_tracing_context = torch._guards.TracingContext.try_get() is not None
if has_tracing_context:
g = FakeifyFirstAOTInvocationGuard()
from torch._dynamo.utils import counters
init = counters["aot_autograd"]["total"]
compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)
# TODO - better way of doing this?
# Only aot autograd handles fakifying first call
invoked_aot_autograd = init != counters["aot_autograd"]["total"]
# We update the original (outer) graph with a call into the compiled module
# instead of the uncompiled one.
self.module.delete_submodule(n.target)
n.target = "compiled_" + n.target
self.module.add_submodule(n.target, compiled_submod_real)
# Finally, we have to produce inputs for use in compiling the next submodule,
# and these need to be FakeTensors, so we execute the module under fake_mode.
# Because parameters are not fake, we patch fake tensor mode to allow non-fake inputs.
with self.fake_mode, mock.patch.object(
self.fake_mode, "allow_non_fake_inputs", True
):
if has_tracing_context and invoked_aot_autograd:
out = compiled_submod_real(*new_args, **kwargs)
# output should be fake or subclass
assert all(
(not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
for t in (out if isinstance(out, (list, tuple)) else [out])
)
return out
else:
return curr_submod(*new_args, **kwargs)
else:
# placeholder or output nodes don't need to get compiled, just executed
return getattr(self, n.op)(n.target, new_args, kwargs)
class DDPOptimizer:
"""Note [DDPOptimizer]
DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
the boundaries of gradient-allreduce buckets chosen by DDP.
Background/Motivation
- DDP uses allreduce collectives to synchronize partial gradients computed on different workers
- DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
- Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
at around the same time during backward and thus can share the same allreduce efficiently
- Allreduces must overlap with backward compute for optimal training performance
- DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
operates when individual grads become 'ready'
- Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
fused backward function executes, preventing any overlap of compute and communication
Algorithm
- DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
this graph in reverse order to determine the true order that gradients will become ready during backward.
- Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
and a graph break introduced
- Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
into an outer module that is returned to the user
Notes
- It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
in eager.
- If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
degradation approaching the baseline case where graph-splits are not used, but not worse.
- If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
subgraphs being compiled
- DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
- DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
DDPOptimizer)
Debugging
- Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
- In many cases, the log messages are helpful (they show bucket size assignments);
just set the TORCH_LOGS env var to include any of 'dynamo', 'distributed', or 'dist_ddp'.
- See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
in a single process (or with torchrun, in multiple processes)
Args:
bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
set to match the equivalent parameter on the original DDP module.
backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
"""
def __init__(
self,
bucket_bytes_cap: int,
backend_compile_fn,
first_bucket_cap: Optional[int] = None,
) -> None:
if first_bucket_cap is not None:
self.first_bucket_cap = first_bucket_cap
elif torch.distributed.is_available():
# this constant comes from C10D lib which is not always built
self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
else:
self.first_bucket_cap = bucket_bytes_cap
self.bucket_bytes_cap = bucket_bytes_cap
assert (
self.first_bucket_cap <= self.bucket_bytes_cap
), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
self.backend_compile_fn = backend_compile_fn
def _ignore_parameter(self, parameter):
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
def add_param(self, bucket, param, name):
bucket.size += param.untyped_storage().nbytes()
bucket.params.append(name)
bucket.param_ids.append(id(param))
def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
processed_modules.add(mod)
for name, param in mod.named_parameters():
if param.requires_grad and not self._ignore_parameter(param):
self.add_param(bucket, param, f"{prefix}_{name}")
def add_param_args(self, bucket, node):
for arg in node.args:
if not isinstance(arg, torch.fx.node.Node):
continue
if arg.op != "placeholder":
continue
param = arg.meta["example_value"]
if (
isinstance(param, torch.nn.Parameter)
and param.requires_grad
and not self._ignore_parameter(param)
):
self.add_param(bucket, param, arg.target)
def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
"""
Implements graph splitting, first determining a set of buckets by counting
parameter sizes in reverse graph order, then invoking the user/backend compiler
to compile each subgraph. Finally, stitches the compiled graphs back into one
GraphModule and returns its callable.
"""
if has_higher_order_op(gm):
# This indicates the presence of a higher order op. For now, we
# have no way to break the higher order op into two buckets.
# Allowing higher order ops in the graph also requires
# changes in split_module, because the graph splitter
# currently assumes that all the args of all ops are
# tensors, but in the case of higher order ops, an arg could be
# a graph module. As a workaround, we short-circuit and raise here.
raise NotImplementedError(
"DDPOptimizer backend: Found a higher order op in the graph. "
"This is not supported. Please turn off DDP optimizer using "
"torch._dynamo.config.optimize_ddp=False. Note that this can "
"cause performance degradation because there will be one bucket "
"for the entire Dynamo graph. Please refer to this issue - "
"https://github.com/pytorch/pytorch/issues/104674."
)
# 1: compute the partition map according to DDP bucket logic
buckets = [Bucket()] # (size, param_names)
processed_modules = set()
for node in reversed(gm.graph.nodes):
if node.op in ("output", "placeholder"):
continue
if (
buckets[0].size >= self.bucket_bytes_cap
or len(buckets) == 1
and buckets[0].size >= self.first_bucket_cap
):
if bucket_has_external_output(buckets[0]):
buckets.insert(0, Bucket())
else:
# continue building this bucket past the point of filling its parameter capacity,
# to increase chances it contains at least one node that is either a global output or
# passed as input to a subsequent graph
if buckets[0].opcount_increased_to_capture_external_output == 0:
buckets[0].paramsize_before_opcount_increase = buckets[0].size
buckets[0].opcount_increased_to_capture_external_output += 1
if node.op == "call_function":
self.add_param_args(buckets[0], node)
elif node.op == "call_module":
target_mod = gm.get_submodule(node.target)
if target_mod not in processed_modules:
self.add_module_params_to_bucket(
target_mod, buckets[0], processed_modules, node.target
)
elif node.op == "call_method":
if isinstance(node.args[0].target, str):
target_mod = None
try:
target_mod = gm.get_submodule(node.args[0].target)
except AttributeError:
pass
if target_mod is not None and target_mod not in processed_modules:
self.add_module_params_to_bucket(
target_mod, buckets[0], processed_modules, node.target
)
# This handles situations like tmp = torch.mm(x, self.weight.t())
# t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
# tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
self.add_param_args(buckets[0], node)
elif node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if (
isinstance(maybe_param, torch.nn.Parameter)
and maybe_param.requires_grad
and not self._ignore_parameter(maybe_param)
):
self.add_param(buckets[0], maybe_param, node.target)
# All nodes have to be mapped to a bucket, even if they don't have their own params
# Ignored params still end up in buckets, we just don't count them towards the capacity
buckets[0].nodes.append(node)
if len(buckets) > 1 and buckets[0].size == 0:
# we collected a small preamble graph with ops that don't include parameters, fuse it back
buckets[1].nodes.extend(buckets[0].nodes)
assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
del buckets[0]
# stash buckets for testing/debugging purposes
self.buckets = buckets
pretty_print_buckets(buckets, self.bucket_bytes_cap)
if len(buckets) == 1:
# bypass split/fuse logic if there is only one bucket
return self.backend_compile_fn(gm, example_inputs)
# 2: partition the graphmodule according to bucket capacity
partition_map = {}
for idx, b in enumerate(buckets):
for node in b.nodes:
partition_map[node] = idx
split_gm = fx.passes.split_module.split_module(
gm, None, lambda node: partition_map[node]
)
debug_str = (
f"\n---orig graph---\n{gm.graph}\n"
+ f"\n---split graph---\n{split_gm.graph}\n"
)
for name, module in split_gm.named_modules():
if "." not in name and len(name):
# only print the submod graphs, not their children
debug_str += f"\n---{name} graph---\n{module.graph}\n"
debug_str += "\n---------------\n"
ddp_graph_log.debug(debug_str)
trace_structured(
"optimize_ddp_split_graph",
payload_fn=lambda: split_gm.print_readable(print_output=False),
)
for name, module in split_gm.named_modules():
if "." not in name and len(name):
trace_structured(
"optimize_ddp_split_child",
lambda: {"name": name},
payload_fn=lambda: module.print_readable(print_output=False),
)
fake_mode = detect_fake_mode(example_inputs)
if fake_mode is None:
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
submod_compiler.run(*example_inputs)
split_gm.recompile()
ddp_graph_log.debug(
"\n---final graph---\n%s\n---------------\n", split_gm.graph
)
return split_gm
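# Illustrative end-to-end sketch (assumes torch.distributed is initialized and a
# CUDA device is available; MyModel and batch are hypothetical). DDPOptimizer is
# applied automatically when a DDP-wrapped module is compiled and
# torch._dynamo.config.optimize_ddp is enabled:
#
#     model = torch.nn.parallel.DistributedDataParallel(MyModel().cuda())
#     compiled = torch.compile(model, backend="inductor")
#     loss = compiled(batch).sum()
#     loss.backward()   # per-bucket allreduces can overlap with backward compute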

@@ -0,0 +1,12 @@
# mypy: ignore-errors
from torch._dynamo import register_backend
@register_backend
def inductor(*args, **kwargs):
# do import here to avoid loading inductor into memory when it is not used
from torch._inductor.compile_fx import compile_fx
return compile_fx(*args, **kwargs)
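# Illustrative usage (`fn` is hypothetical): the registration above is what the
# string shorthand resolves to.
#
#     compiled = torch.compile(fn, backend="inductor")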

@@ -0,0 +1,38 @@
# mypy: ignore-errors
# This backend is maintained by ONNX team. To direct issues
# to the right people, please tag related GitHub issues with `module: onnx`.
#
# Maintainers' Github IDs: wschin, xadupre
from torch.onnx._internal.onnxruntime import (
is_onnxrt_backend_supported,
torch_compile_backend,
)
from .registry import register_backend
def has_onnxruntime():
# FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported()
return is_onnxrt_backend_supported()
if is_onnxrt_backend_supported():
register_backend(name="onnxrt", compiler_fn=torch_compile_backend)
else:
def information_displaying_backend(*args, **kwargs):
raise ImportError(
"onnxrt is not registered as a backend. "
"Please make sure all dependencies such as "
"numpy, onnx, onnxscript, and onnxruntime-training are installed. "
"Suggested procedure to fix dependency problem:\n"
" (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n"
" (2) Open a new python terminal.\n"
" (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n"
" (4) If it returns `True`, then you can use `onnxrt` backend.\n"
" (5) If it returns `False`, please execute the package importing section in "
"torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
)
register_backend(name="onnxrt", compiler_fn=information_displaying_backend)

@@ -0,0 +1,125 @@
# mypy: ignore-errors
import functools
import logging
import sys
from importlib.metadata import EntryPoint
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple
import torch
from torch import fx
log = logging.getLogger(__name__)
class CompiledFn(Protocol):
def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
...
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
_BACKENDS: Dict[str, Optional[EntryPoint]] = {}
_COMPILER_FNS: Dict[str, CompilerFn] = {}
def register_backend(
compiler_fn: Optional[CompilerFn] = None,
name: Optional[str] = None,
tags: Sequence[str] = (),
):
"""
Decorator to add a given compiler to the registry to allow calling
`torch.compile` with string shorthand. Note: for projects not
imported by default, it might be easier to pass a function directly
as a backend and not use a string.
Args:
compiler_fn: Callable taking a FX graph and fake tensor inputs
name: Optional name, defaults to `compiler_fn.__name__`
tags: Optional set of string tags to categorize backend with
"""
if compiler_fn is None:
# @register_backend(name="") syntax
return functools.partial(register_backend, name=name, tags=tags)
assert callable(compiler_fn)
name = name or compiler_fn.__name__
assert name not in _COMPILER_FNS, f"duplicate name: {name}"
if compiler_fn not in _BACKENDS:
_BACKENDS[name] = None
_COMPILER_FNS[name] = compiler_fn
compiler_fn._tags = tuple(tags)
return compiler_fn
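# Illustrative sketch of the two registration forms described above (the backend
# functions are hypothetical):
#
#     @register_backend
#     def my_backend(gm, example_inputs):
#         return gm.forward
#
#     @register_backend(name="my_alias", tags=("experimental",))
#     def my_other_backend(gm, example_inputs):
#         return gm.forward
#
#     # both can then be selected by name, e.g.
#     # torch.compile(fn, backend="my_backend")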
register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
register_backend, tags=("experimental",)
)
def lookup_backend(compiler_fn):
"""Expand backend strings to functions"""
if isinstance(compiler_fn, str):
if compiler_fn not in _BACKENDS:
_lazy_import()
if compiler_fn not in _BACKENDS:
from ..exc import InvalidBackend
raise InvalidBackend(name=compiler_fn)
if compiler_fn not in _COMPILER_FNS:
entry_point = _BACKENDS[compiler_fn]
register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
compiler_fn = _COMPILER_FNS[compiler_fn]
return compiler_fn
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
"""
Return valid strings that can be passed to:
torch.compile(..., backend="name")
"""
_lazy_import()
exclude_tags = set(exclude_tags or ())
backends = [
name
for name in _BACKENDS.keys()
if name not in _COMPILER_FNS
or not exclude_tags.intersection(_COMPILER_FNS[name]._tags)
]
return sorted(backends)
@functools.lru_cache(None)
def _lazy_import():
from .. import backends
from ..utils import import_submodule
import_submodule(backends)
from ..repro.after_dynamo import dynamo_minifier_backend
assert dynamo_minifier_backend is not None
_discover_entrypoint_backends()
@functools.lru_cache(None)
def _discover_entrypoint_backends():
# importing here so it will pick up the mocked version in test_backends.py
from importlib.metadata import entry_points
group_name = "torch_dynamo_backends"
if sys.version_info < (3, 10):
eps = entry_points()
eps = eps[group_name] if group_name in eps else []
eps = {ep.name: ep for ep in eps}
else:
eps = entry_points(group=group_name)
eps = {name: eps[name] for name in eps.names}
for backend_name in eps:
_BACKENDS[backend_name] = eps[backend_name]
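# Illustrative sketch of how a third-party package would expose a backend through
# this mechanism (package and function names are hypothetical), e.g. in setup.py:
#
#     setup(
#         ...,
#         entry_points={
#             "torch_dynamo_backends": [
#                 "my_backend = my_package.compiler:my_backend_fn",
#             ],
#         },
#     )
#
# After installation, torch.compile(fn, backend="my_backend") resolves the entry
# point lazily via lookup_backend().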

@@ -0,0 +1,14 @@
# mypy: ignore-errors
# import torch # type: ignore[import]
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
# from .registry import register_backend # type: ignore[import]
"""
Placeholder for TensorRT backend for dynamo via torch-tensorrt
"""
# @register_backend
# def tensorrt(gm, example_inputs):
# import torch_tensorrt # type: ignore[import]
# pass

@@ -0,0 +1,47 @@
# mypy: ignore-errors
import logging
from functorch.compile import make_boxed_func
from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend
log = logging.getLogger(__name__)
@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
return xla_backend_helper(model, fake_tensor_inputs, boxed=False)
def openxla_eval_boxed(model, fake_tensor_inputs):
return xla_backend_helper(model, fake_tensor_inputs, boxed=True)
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
try:
import torch_xla.core.dynamo_bridge as bridge
except ImportError as e:
raise ImportError(
"Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
) from e
compiled_graph = None
def fwd(*args):
nonlocal model
nonlocal compiled_graph
if compiled_graph is None:
compiled_graph = bridge.extract_compiled_graph(model, args)
del model
return compiled_graph(*args)
return make_boxed_func(fwd) if boxed else fwd
openxla = aot_autograd(
fw_compiler=openxla_eval_boxed,
)
register_backend(name="openxla", compiler_fn=openxla)
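# Illustrative usage sketch (assumes torch_xla is installed and the model and
# inputs live on an XLA device; `model` and `xla_input` are hypothetical):
#
#     compiled = torch.compile(model, backend="openxla")
#     out = compiled(xla_input)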

@@ -0,0 +1,194 @@
# mypy: ignore-errors
import functools
import importlib
import logging
import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional
import torch
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend
log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def tvm(
gm,
example_inputs,
*,
options: Optional[MappingProxyType] = MappingProxyType(
{"scheduler": None, "trials": 20000, "opt_level": 3}
),
):
import tvm # type: ignore[import]
from tvm import relay # type: ignore[import]
from tvm.contrib import graph_executor # type: ignore[import]
jit_mod = torch.jit.trace(gm, example_inputs)
device = device_from_inputs(example_inputs)
shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
example_outputs = gm(*example_inputs)
if len(example_outputs) == 0:
log.warning("Explicitly fall back to eager due to zero output")
return gm.forward
mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
if device.type == "cuda":
dev = tvm.cuda(device.index)
target = tvm.target.cuda()
else:
dev = tvm.cpu(0)
target = tvm.target.Target(llvm_target())
scheduler = options.get("scheduler", None)
if scheduler is None:
scheduler = os.environ.get("TVM_SCHEDULER", None)
trials = options.get("trials", 20000)
opt_level = options.get("opt_level", 3)
if scheduler == "auto_scheduler":
from tvm import auto_scheduler
log_file = tempfile.NamedTemporaryFile()
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
for task in tasks:
print(task.compute_dag)
else:
print("No tasks")
if len(tasks) != 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
if not os.path.exists(log_file):
assert trials > 0
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=trials,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
early_stopping=2000,
)
try:
tuner.tune(tune_option)
except Exception:
if os.path.exists(log_file):
os.unlink(log_file)
raise
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
):
lib = relay.build(mod, target=target, params=params)
elif scheduler == "meta_schedule":
from tvm import meta_schedule as ms
with tempfile.TemporaryDirectory() as work_dir:
if device.type != "cuda":
# meta_schedule needs num-cores to be specified
# here we use the maximum core count
target = tvm.target.Target(
f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
)
# TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
# once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
assert trials > 0
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
work_dir=work_dir,
max_trials_global=trials,
num_trials_per_iter=64,
params=params,
strategy="evolutionary",
opt_level=opt_level,
)
lib = ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
opt_level=opt_level,
)
elif scheduler == "default" or not scheduler:
# no autotuning
with tvm.transform.PassContext(opt_level=opt_level):
lib = relay.build(mod, target=target, params=params)
else:
raise NotImplementedError(
"This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
"There are three available options: default, auto_scheduler and meta_schedule."
)
m = graph_executor.GraphModule(lib["default"](dev))
def to_torch_tensor(nd_tensor):
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support booleans, so this can't be handled by
# torch.utils.dlpack.from_dlpack. Work around by going through
# numpy, although this brings additional data copy overhead.
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def to_tvm_tensor(torch_tensor):
"""A helper function to transfer a torch.tensor to NDArray."""
if torch_tensor.dtype == torch.bool:
# same reason as above, fallback to numpy conversion which
# could introduce data copy overhead
return tvm.nd.array(torch_tensor.cpu().numpy())
return tvm.nd.from_dlpack(torch_tensor)
def exec_tvm(*i_args):
args = [a.contiguous() for a in i_args]
shape_info, _ = m.get_input_info()
active_inputs = {name for name, _ in shape_info.items()}
for idx, arg in enumerate(args, 0):
if arg.dim() != 0:
if arg.requires_grad:
arg = arg.detach()
inp_name = f"inp_{idx}"
if inp_name not in active_inputs:
log.warning(
"input %s skipped as not found in tvm's runtime library",
inp_name,
)
continue
m.set_input(
inp_name,
to_tvm_tensor(arg),
)
m.run()
return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]
return exec_tvm
tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
def has_tvm():
try:
importlib.import_module("tvm")
return True
except ImportError:
return False
@functools.lru_cache(None)
def llvm_target():
if sys.platform == "linux":
cpuinfo = open("/proc/cpuinfo").read()
if "avx512" in cpuinfo:
return "llvm -mcpu=skylake-avx512"
elif "avx2" in cpuinfo:
return "llvm -mcpu=core-avx2"
return "llvm"

View File

@ -0,0 +1,257 @@
# mypy: allow-untyped-defs
import bisect
import dataclasses
import dis
import sys
from typing import Any, Set, Union
TERMINAL_OPCODES = {
dis.opmap["RETURN_VALUE"],
dis.opmap["JUMP_FORWARD"],
dis.opmap["RAISE_VARARGS"],
# TODO(jansel): double check exception handling
}
if sys.version_info >= (3, 9):
TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
if sys.version_info >= (3, 11):
TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
else:
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
if sys.version_info >= (3, 12):
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
HASLOCAL = set(dis.haslocal)
HASFREE = set(dis.hasfree)
stack_effect = dis.stack_effect
def get_indexof(insts):
"""
Get a mapping from instruction to its index in the instruction list.
Additionally checks that each instruction appears only once in the list.
"""
indexof = {}
for i, inst in enumerate(insts):
assert inst not in indexof
indexof[inst] = i
return indexof
def remove_dead_code(instructions):
"""Dead code elimination"""
indexof = get_indexof(instructions)
live_code = set()
def find_live_code(start):
for i in range(start, len(instructions)):
if i in live_code:
return
live_code.add(i)
inst = instructions[i]
if inst.exn_tab_entry:
find_live_code(indexof[inst.exn_tab_entry.target])
if inst.opcode in JUMP_OPCODES:
find_live_code(indexof[inst.target])
if inst.opcode in TERMINAL_OPCODES:
return
find_live_code(0)
# change exception table entries if start/end instructions are dead
# assumes that exception table entries have been propagated,
# e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
# and that instructions with an exn_tab_entry lie within its start/end.
if sys.version_info >= (3, 11):
live_idx = sorted(live_code)
for i, inst in enumerate(instructions):
if i in live_code and inst.exn_tab_entry:
# find leftmost live instruction >= start
start_idx = bisect.bisect_left(
live_idx, indexof[inst.exn_tab_entry.start]
)
assert start_idx < len(live_idx)
# find rightmost live instruction <= end
end_idx = (
bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
)
assert end_idx >= 0
assert live_idx[start_idx] <= i <= live_idx[end_idx]
inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
inst.exn_tab_entry.end = instructions[live_idx[end_idx]]
return [inst for i, inst in enumerate(instructions) if i in live_code]
def remove_pointless_jumps(instructions):
"""Eliminate jumps to the next instruction"""
pointless_jumps = {
id(a)
for a, b in zip(instructions, instructions[1:])
if a.opname == "JUMP_ABSOLUTE" and a.target is b
}
return [inst for inst in instructions if id(inst) not in pointless_jumps]
def propagate_line_nums(instructions):
"""Ensure every instruction has line number set in case some are removed"""
cur_line_no = None
def populate_line_num(inst):
nonlocal cur_line_no
if inst.starts_line:
cur_line_no = inst.starts_line
inst.starts_line = cur_line_no
for inst in instructions:
populate_line_num(inst)
def remove_extra_line_nums(instructions):
"""Remove extra starts line properties before packing bytecode"""
cur_line_no = None
def remove_line_num(inst):
nonlocal cur_line_no
if inst.starts_line is None:
return
elif inst.starts_line == cur_line_no:
inst.starts_line = None
else:
cur_line_no = inst.starts_line
for inst in instructions:
remove_line_num(inst)
@dataclasses.dataclass
class ReadsWrites:
reads: Set[Any]
writes: Set[Any]
visited: Set[Any]
def livevars_analysis(instructions, instruction):
indexof = get_indexof(instructions)
must = ReadsWrites(set(), set(), set())
may = ReadsWrites(set(), set(), set())
def walk(state, start):
if start in state.visited:
return
state.visited.add(start)
for i in range(start, len(instructions)):
inst = instructions[i]
if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
if "LOAD" in inst.opname or "DELETE" in inst.opname:
if inst.argval not in must.writes:
state.reads.add(inst.argval)
elif "STORE" in inst.opname:
state.writes.add(inst.argval)
elif inst.opname == "MAKE_CELL":
pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")
if inst.exn_tab_entry:
walk(may, indexof[inst.exn_tab_entry.target])
if inst.opcode in JUMP_OPCODES:
walk(may, indexof[inst.target])
state = may
if inst.opcode in TERMINAL_OPCODES:
return
walk(must, indexof[instruction])
return must.reads | may.reads
@dataclasses.dataclass
class FixedPointBox:
value: bool = True
@dataclasses.dataclass
class StackSize:
low: Union[int, float]
high: Union[int, float]
fixed_point: FixedPointBox
def zero(self):
self.low = 0
self.high = 0
self.fixed_point.value = False
def offset_of(self, other, n):
prior = (self.low, self.high)
self.low = min(self.low, other.low + n)
self.high = max(self.high, other.high + n)
if (self.low, self.high) != prior:
self.fixed_point.value = False
def exn_tab_jump(self, depth):
prior = (self.low, self.high)
self.low = min(self.low, depth)
self.high = max(self.high, depth)
if (self.low, self.high) != prior:
self.fixed_point.value = False
def stacksize_analysis(instructions) -> Union[int, float]:
assert instructions
fixed_point = FixedPointBox()
stack_sizes = {
inst: StackSize(float("inf"), float("-inf"), fixed_point)
for inst in instructions
}
stack_sizes[instructions[0]].zero()
for _ in range(100):
if fixed_point.value:
break
fixed_point.value = True
for inst, next_inst in zip(instructions, instructions[1:] + [None]):
stack_size = stack_sizes[inst]
# CALL_FINALLY in Python 3.8 is handled differently when determining stack depth.
# See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450.
# Essentially, the stack effect of CALL_FINALLY is computed with jump=True,
# but the resulting stack depth is propagated to the next instruction, not the
# jump target.
is_call_finally = (
sys.version_info < (3, 9) and inst.opcode == dis.opmap["CALL_FINALLY"]
)
if inst.opcode not in TERMINAL_OPCODES:
assert next_inst is not None, f"missing next inst: {inst}"
# total stack effect of CALL_FINALLY and END_FINALLY in 3.8 is 0
eff = (
0
if is_call_finally
else stack_effect(inst.opcode, inst.arg, jump=False)
)
stack_sizes[next_inst].offset_of(stack_size, eff)
if inst.opcode in JUMP_OPCODES and not is_call_finally:
stack_sizes[inst.target].offset_of(
stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
)
if inst.exn_tab_entry:
# see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
# on why depth is computed this way.
depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1
stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth)
if False:
for inst in instructions:
stack_size = stack_sizes[inst]
print(stack_size.low, stack_size.high, inst)
low = min(x.low for x in stack_sizes.values())
high = max(x.high for x in stack_sizes.values())
assert fixed_point.value, "failed to reach fixed point"
assert low >= 0
return high
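# Illustrative sketch (standard library only; straight-line code): the analysis above
# propagates dis.stack_effect through jumps until it reaches a fixed point. For branch-free
# code, a running maximum over the linear instruction stream already approximates
# co_stacksize.
def _linear_stack_peak(fn):
    depth = peak = 0
    for inst in dis.get_instructions(fn):
        depth += dis.stack_effect(inst.opcode, inst.arg, jump=False)
        peak = max(peak, depth)
    return peak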

File diff suppressed because it is too large

View File

@ -0,0 +1,185 @@
# mypy: allow-untyped-defs
import logging
import types
import weakref
from dataclasses import dataclass
from typing import Tuple
from torch._guards import CompileId
from . import config
log = logging.getLogger(__name__)
"""
[Note on cache size limit]
Background - TorchDynamo cache is a linked list. Each cache entry is a
(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
scratch space. When a frame is invoked, we walk this linked list and run
check_fn in each cache_entry to decide if the frame needs recompilation. If none
of the check_fn's returns True, we recompile and add a new entry. To ensure we
don't end up recompiling infinitely, we put limits on the cache size.
There are two limits:
1) cache_size_limit
2) accumulated_cache_size_limit
Earlier we used to have only one limit - the maximum number of entries in one
cache line (which is now represented by (2) above). So, why do we need two
limits? Let's try to understand that.
In general, we want our cache limit value to be a small number (e.g. 8 or even
lower). This ensures that frames that cause too many recompilations fall back
to eager quickly. However, there is another problem that prevents us from
lowering the value of cache_size_limit. This is due to ID_MATCH'd guards.
Today, we put ID_MATCH guards on nn modules if there is a graph break. This
means we will have many recompilations for the same code object because the
ID_MATCH guard fails for different instances of the nn module. This is a
common pattern in how models are authored. Therefore, this requires us to keep
the cache_size_limit high.
We resolve this by introducing these two limits. The first limit (1) caps the
number of cache entries that have an ID_MATCH'd guard for an nn module
instance. The second limit (2) becomes a safeguard mechanism that caps the
total number of compilations for a code object. One important question is -
what is the limit for a code object that does not have any ID_MATCH guard? For
such code objects, we choose (1) as the cache size limit.
Let's take an example to understand how these limits help. Suppose we have 16
instances of an nn module and we ID_MATCH on the self object. Further, suppose
the inputs to these functions have varying batch sizes, leading to one
recompilation per instance. In total, there will be 32 recompilations, and
therefore 32 cache entries on the forward code object. In the older scheme,
when we had only one limit, our cache size limit had to be >= 32 to capture
all these recompilations. Now, suppose there is a separate function in the
same program which is very dynamic and unsuitable for compilation. Such a
function would need to undergo 32 compilations to burst the cache and fall
back to eager. These 32 recompilations are too many, and we want to fall back
sooner for such compilation-unfriendly functions.
In the new scenario, we can have (1) cache_size_limit = 2, (2)
accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can
have a maximum of two cache entries, and the maximum number of cache entries
(irrespective of ID_MATCH'd objects) is 32. This covers the case of the
forward code object, which has 32 recompilations. For the other function, the
one unsuitable for compilation, our limit is 2. So, we will burst the cache in
just 2 recompilations. In this manner, these two limits help us resolve the
tension mentioned earlier.
"""
@dataclass
class CacheSizeRelevantForFrame:
"""
We track the number of cache entries that have same id_match objects as the
given frame.
TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count -
https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this
could be useful for debugging as well.
"""
# Total number of CacheEntry objects in the Dynamo linked list
num_cache_entries: int = 0
# Number of CacheEntry objects having same ID_MATCH'd objects as given frame.
num_cache_entries_with_same_id_matched_objs: int = 0
def will_compilation_exceed(self, limit: int) -> bool:
# Checks if a compilation will exceed the given limit (that's why >=).
return (
self.will_compilation_exceed_accumulated_limit()
or self.will_compilation_exceed_specific_limit(limit)
)
def will_compilation_exceed_accumulated_limit(self) -> bool:
return self.num_cache_entries >= config.accumulated_cache_size_limit
def will_compilation_exceed_specific_limit(self, limit: int) -> bool:
return self.num_cache_entries_with_same_id_matched_objs >= limit
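# Illustrative sketch (hypothetical limits) of how the two thresholds interact, assuming
# config.cache_size_limit == 2 and config.accumulated_cache_size_limit == 32:
#
#     sizes = CacheSizeRelevantForFrame(
#         num_cache_entries=32,
#         num_cache_entries_with_same_id_matched_objs=2,
#     )
#     sizes.will_compilation_exceed_accumulated_limit()  # True: 32 >= 32
#     sizes.will_compilation_exceed_specific_limit(2)    # True: 2 >= 2
#     sizes.will_compilation_exceed(2)                   # True: either check suffices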
def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
obj = frame.f_locals.get(local_name, None)
weak_id = None
try:
weak_id = weakref.ref(obj)
except TypeError:
pass # cannot weakref bool object
return weak_id
def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
"""
Checks if the ID_MATCH'd objects saved on cache_entry are the same as the ones
in frame.f_locals.
"""
if not cache_entry:
return False
for (
local_name,
weakref_from_cache_entry,
) in cache_entry.check_fn.id_matched_objs.items():
if weakref_from_cache_entry() is not None:
weakref_from_frame = _get_weakref_from_f_locals(frame, local_name)
if weakref_from_frame != weakref_from_cache_entry:
return False
# Also covers the case where no ID_MATCH objects are saved in frame.f_locals
return True
def compute_cache_size(
frame: types.FrameType, cache_entry
) -> CacheSizeRelevantForFrame:
# Walk the linked list to calculate the cache size
num_cache_entries = 0
num_cache_entries_with_same_id_matched_objs = 0
while cache_entry:
num_cache_entries += 1
# Track the number of cache entries having same ID_MATCH'd objects as
# that of frame.f_locals. This will be used later to compare against the
# cache_size_limit.
if _has_same_id_matched_objs(frame, cache_entry):
num_cache_entries_with_same_id_matched_objs += 1
cache_entry = cache_entry.next
return CacheSizeRelevantForFrame(
num_cache_entries, num_cache_entries_with_same_id_matched_objs
)
def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
"""
If the frame (earlier parsed by compute_cache_size) has more than 1 cache
entry with the same ID_MATCH'd objects, then it's a recompilation.
"""
# Note that you can have multiple entries in the cache but still not a
# recompile, e.g., you can have 64 nn module instances, each one having an
# ID_MATCH guard, and each one having just 1 cache entry in the cache. In
# this case, we can have 64 entries in the cache, but no recompilation
# because there is only one entry for each id_matched_obj.
return cache_size.will_compilation_exceed(1)
def exceeds_cache_size_limit(
cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
) -> Tuple[bool, str]:
"""
Checks if we are exceeding the cache size limit.
"""
if cache_size.will_compilation_exceed_accumulated_limit():
return True, "accumulated_cache_size_limit"
if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
return True, "cache_size_limit"
# NOTE this check is needed in the case that the frame's cache doesn't grow
# and we keep recompiling. This can happen if the guard check_fn becomes invalidated,
# e.g. due to guarded objects being freed. This technically makes the
# will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
# check in case we have a better fix in the future.
if compile_id.frame_compile_id >= config.accumulated_cache_size_limit:
return True, "accumulated_cache_size_limit"
return False, ""

View File

@ -0,0 +1,83 @@
# mypy: allow-untyped-defs
class CompilationCallbackHandler:
def __init__(self):
self.start_callbacks = []
self.end_callbacks = []
def register_start_callback(self, callback):
"""
Register a callback function to be called when the compilation starts.
Args:
- callback (callable): The callback function to register.
"""
self.start_callbacks.append(callback)
return callback
def register_end_callback(self, callback):
"""
Register a callback function to be called when the compilation ends.
Args:
- callback (callable): The callback function to register.
"""
self.end_callbacks.append(callback)
return callback
def remove_start_callback(self, callback):
"""
Remove a registered start callback function.
Args:
- callback (callable): The callback function to remove.
"""
self.start_callbacks.remove(callback)
def remove_end_callback(self, callback):
"""
Remove a registered end callback function.
Args:
- callback (callable): The callback function to remove.
"""
self.end_callbacks.remove(callback)
def run_start_callbacks(self):
"""
Execute all registered start callbacks.
"""
for callback in self.start_callbacks:
callback()
def run_end_callbacks(self):
"""
Execute all registered end callbacks.
"""
for callback in self.end_callbacks:
callback()
def clear(self):
"""
Clear all registered callbacks.
"""
self.start_callbacks.clear()
self.end_callbacks.clear()
callback_handler = CompilationCallbackHandler()
def on_compile_start(callback):
"""
Decorator to register a callback function for the start of the compilation.
"""
callback_handler.register_start_callback(callback)
return callback
def on_compile_end(callback):
"""
Decorator to register a callback function for the end of the compilation.
"""
callback_handler.register_end_callback(callback)
return callback
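# Illustrative usage sketch of the decorators above (callback bodies are hypothetical):
#
#     @on_compile_start
#     def _log_compile_start():
#         print("dynamo compilation started")
#
#     @on_compile_end
#     def _log_compile_end():
#         print("dynamo compilation finished")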

View File

@ -0,0 +1,30 @@
# mypy: allow-untyped-defs
import types
from .utils import ExactWeakKeyDictionary
class CodeContextDict:
def __init__(self) -> None:
self.code_context = ExactWeakKeyDictionary()
def has_context(self, code: types.CodeType):
return code in self.code_context
def get_context(self, code: types.CodeType):
ctx = self.code_context.get(code)
if ctx is None:
ctx = {}
self.code_context[code] = ctx
return ctx
def pop_context(self, code: types.CodeType):
ctx = self.get_context(code)
self.code_context._remove_id(id(code))
return ctx
def clear(self):
self.code_context.clear()
code_context = CodeContextDict()
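# Illustrative usage sketch (the code object below is hypothetical):
#
#     def _f():
#         pass
#
#     ctx = code_context.get_context(_f.__code__)  # creates an empty dict on first use
#     ctx["marker"] = True
#     assert code_context.has_context(_f.__code__)
#     code_context.pop_context(_f.__code__)  # removes the entry and returns the dict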

View File

@ -0,0 +1,511 @@
# mypy: allow-untyped-defs
import collections
import dataclasses
import re
import sys
import types
from typing import Counter, Dict, List, Optional
import torch.nn
from . import utils
from .bytecode_transformation import (
add_push_null,
add_push_null_call_function_ex,
create_call_function,
create_call_method,
create_dup_top,
create_instruction,
create_load_method,
create_rot_n,
Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable
@dataclasses.dataclass
class GraphOutputEntry:
index: int
variable: VariableTracker
class PyCodegen:
"""
Helper class used for constructing Python bytecode
"""
def __init__(
self,
tx=None,
root: Optional[torch.nn.Module] = None,
graph_output_var: Optional[str] = None,
tempvars=None,
) -> None:
self.root = root
self.top_of_stack: Optional[VariableTracker] = None
self.uses: Counter[VariableTracker] = collections.Counter()
self.graph_outputs: Dict[int, GraphOutputEntry] = {}
self._output: List[Instruction] = []
self.tempvars = tempvars or {}
self.tx = tx
self.graph_output_var = graph_output_var
self.code_options = self.tx.output.code_options
self.cell_and_freevars = self.tx.cell_and_freevars
self.new_var = self.tx.output.new_var
self.mutable_side_effects_from_source = False
self.value_from_source: bool = True
def restore_stack(self, stack_values, *, value_from_source=True):
prior = self.mutable_side_effects_from_source
self.mutable_side_effects_from_source = True
prev = self.value_from_source
self.value_from_source &= value_from_source
try:
self.foreach(stack_values)
finally:
self.mutable_side_effects_from_source = prior
self.value_from_source = prev
def graph_output_vars(self):
return [x.variable for x in self.graph_outputs.values()]
def call_reconstruct(self, value):
res = value.reconstruct(self)
assert res is None, f"reconstruct!=None {value}"
def add_push_null(self, gen_fn, call_function_ex=False):
"""
`gen_fn` generates instructions via PyCodegen methods
that push a single callable to the stack.
`add_push_null` pushes a NULL to the stack before or after the
instructions generated by `gen_fn`, depending on Python version.
Will attempt to use the NULL push bit for instructions
with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
"""
old_len = len(self._output)
if sys.version_info < (3, 13):
# gen_fn may DUP_TOP instead if TOS is not cleared.
# Will cause problems since NULL will be pushed right
# before the generated instructions in <= 3.12
self.clear_tos()
gen_fn()
# inplace modify self._output
added_insts = self._output[old_len:]
del self._output[old_len:]
if call_function_ex:
self._output.extend(add_push_null_call_function_ex(added_insts))
else:
self._output.extend(add_push_null(added_insts))
if sys.version_info >= (3, 13):
# NULL will be at top of stack
self.clear_tos()
def __call__(self, value, allow_cache=True):
"""Generate code such that top-of-stack (TOS) is set to value"""
if isinstance(value, Source):
self.call_reconstruct(value)
self.clear_tos()
return
assert isinstance(value, VariableTracker)
output = self._output
graph_outputs = self.graph_outputs
if self.top_of_stack is value and allow_cache:
output.append(create_dup_top())
return
if self.mutable_side_effects_from_source:
# this is needed to get aliasing relationships right
# value.mutable_local.source will get mutated to hold `value`
# mutable_side_effects_from_source=False is used to codegen the mutation
# mutable_side_effects_from_source=True is used to codegen a reference
from .side_effects import MutableSideEffects
if isinstance(value.mutable_local, MutableSideEffects):
self(value.mutable_local.source)
return
if allow_cache:
if value.mutable_local and value.mutable_local in self.tempvars:
output.append(self.create_load(self.tempvars[value.mutable_local]))
self.top_of_stack = value
return
if self.tempvars.get(value) is not None:
output.append(self.create_load(self.tempvars[value]))
self.top_of_stack = value
return
if value.source is not None and allow_cache and self.value_from_source:
self.call_reconstruct(value.source)
elif value.is_python_constant() and is_safe_constant(
value.as_python_constant()
):
output.append(self.create_load_const(value.as_python_constant()))
elif isinstance(value, TensorWithTFOverrideVariable):
graph_outputs_key = self.add_graph_output(value)
self.add_push_null(
lambda: self.load_import_from(utils.__name__, "to_subclass")
)
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(
self.create_load_global(
value.global_mangled_class_name(self.tx), add=True
)
)
output.extend(create_call_function(2, False))
elif (
isinstance(value, SymNodeVariable)
and value.python_type() == float
and not self.tx.export
):
# This is a little unusual; force the output convention to be a
# Tensor here. Don't do this for export because this is
# apparently load bearing for export tests (but I am a bit
# doubtful it actually works in the real world)
# NB: It works to add_graph_output on a computed expression
# as_tensor here, because we memoize as_tensor calls on
# SymNodeVariable!
graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx))
def gen_fn():
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(self.create_load_attr("item"))
self.add_push_null(gen_fn)
output.extend(create_call_function(0, False))
elif isinstance(
value,
(
TensorVariable,
SymNodeVariable,
UnspecializedPythonVariable,
NumpyNdarrayVariable,
),
):
graph_outputs_key = self.add_graph_output(value)
if isinstance(value, NumpyNdarrayVariable):
self.add_push_null(
lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
)
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.extend(create_call_function(1, False))
elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
def gen_fn():
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(self.create_load_attr("item"))
self.add_push_null(gen_fn)
output.extend(create_call_function(0, False))
else:
self.load_graph_output(graph_outputs[graph_outputs_key].index)
elif isinstance(value, NNModuleVariable):
parts = value.module_key.split(".")
if parts[0] in self.code_options["co_varnames"]:
output.append(self.create_load(parts[0]))
parts = parts[1:]
else:
assert self.root is not None
output.append(self.create_load_output(self.root))
for part in parts:
output.append(self.create_load_attr(part))
else:
self.uses[value] += 1
try:
self.call_reconstruct(value)
except NotImplementedError:
unimplemented(f"reconstruct: {value}")
if allow_cache and value in self.tempvars:
self._output.append(create_dup_top())
self.add_cache(value)
self.top_of_stack = value
def add_graph_output(self, value):
graph_outputs_key = id(value.as_proxy())
if graph_outputs_key not in self.graph_outputs:
self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
len(self.graph_outputs), value
)
return graph_outputs_key
def load_graph_output(self, index):
output = self._output
output.append(self.create_load(self.graph_output_var))
output.append(self._create_load_const(index))
output.append(create_instruction("BINARY_SUBSCR"))
def add_cache(self, value):
var = self.new_var()
self.tempvars[value] = var
if value.mutable_local:
self.tempvars[value.mutable_local] = var
self._output.append(self.create_store(var))
def foreach(self, items):
for i in items:
self(i)
def setup_globally_cached(self, name, value):
"""Store value in a new global"""
name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
f_globals = self.tx.f_globals
if name in f_globals:
assert id(f_globals[name]) == id(value)
else:
f_globals[name] = value
return [self.create_load_global(name, add=True)]
def clear_tos(self):
self.top_of_stack = None
def append_output(self, inst):
assert isinstance(inst, Instruction)
self._output.append(inst)
self.clear_tos()
def extend_output(self, insts):
assert all(isinstance(x, Instruction) for x in insts)
self._output.extend(insts)
self.clear_tos()
def get_instructions(self) -> List[Instruction]:
return self._output
def create_load(self, name) -> Instruction:
if name in self.cell_and_freevars():
return create_instruction("LOAD_DEREF", argval=name)
assert name in self.code_options["co_varnames"], f"{name} missing"
return create_instruction("LOAD_FAST", argval=name)
def create_load_closure(self, name) -> Instruction:
assert name in self.cell_and_freevars()
inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
return create_instruction(inst_name, argval=name)
def create_store(self, name) -> Instruction:
if name in self.cell_and_freevars():
return create_instruction("STORE_DEREF", argval=name)
assert name in self.code_options["co_varnames"]
return create_instruction("STORE_FAST", argval=name)
def create_load_global(self, name, add=False) -> Instruction:
if add:
self.tx.output.update_co_names(name)
assert name in self.code_options["co_names"], f"{name} not in co_names"
return create_instruction("LOAD_GLOBAL", argval=name)
def create_load_const(self, value) -> Instruction:
assert is_safe_constant(value), f"unsafe constant {value}"
return self._create_load_const(value)
def _create_load_const(self, value) -> Instruction:
return create_instruction("LOAD_CONST", argval=value)
create_load_output = _create_load_const
def load_method(self, name):
self.tx.output.update_co_names(name)
self.append_output(create_load_method(name))
def call_method(self, nargs):
self.extend_output(create_call_method(nargs))
def create_load_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
return create_instruction("LOAD_ATTR", argval=name)
def load_attr(self, name):
self.append_output(self.create_load_attr(name))
def create_load_attrs(self, names):
return [self.create_load_attr(name) for name in names.split(".")]
def create_store_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
return create_instruction("STORE_ATTR", argval=name)
def store_attr(self, name):
self.append_output(self.create_store_attr(name))
def load_function_name(self, fn_name, push_null, num_on_stack=0):
"""Load the global fn_name on the stack num_on_stack down"""
output = []
if push_null and sys.version_info >= (3, 11):
output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
if num_on_stack > 0:
output.extend(
[
*self.rot_n(num_on_stack + 2),
*self.rot_n(num_on_stack + 2),
]
)
else:
output.extend(
[
self.create_load_global(fn_name, add=True),
*self.rot_n(num_on_stack + 1),
]
)
return output
def rot_n(self, n):
try:
return create_rot_n(n)
except AttributeError:
# desired rotate bytecode doesn't exist, generate equivalent bytecode
return [
create_instruction("BUILD_TUPLE", arg=n),
self._create_load_const(rot_n_helper(n)),
*create_rot_n(2),
create_instruction("CALL_FUNCTION_EX", arg=0),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
def pop_null(self):
# POP_TOP doesn't work for null, so we pop nulls by pushing in a
# nop function, calling it (which consumes the null), and popping the result.
assert sys.version_info >= (3, 11)
return [
self._create_load_const(lambda: None),
# 3.13 swapped NULL and callable
*(
(create_instruction("SWAP", arg=2),)
if sys.version_info >= (3, 13)
else ()
),
*create_call_function(0, False),
create_instruction("POP_TOP"),
]
def pop_top(self):
self.append_output(create_instruction("POP_TOP"))
def call_function(self, nargs: int, push_null: bool):
self.extend_output(create_call_function(nargs, push_null=push_null))
def dup_top(self):
self.append_output(create_dup_top())
def store(self, varname):
self.append_output(self.create_store(varname))
def make_function_with_closure(
self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
):
freevars = code.co_freevars
assert freevars
output = self._output
def gen_fn():
for var in freevars:
assert var in self.cell_and_freevars()
inst_name = (
"LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
)
output.append(create_instruction(inst_name, argval=var))
output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
if push_null and sys.version_info >= (3, 11):
self.add_push_null(gen_fn)
output.extend(self.rot_n(num_on_stack + 2))
output.extend(self.rot_n(num_on_stack + 2))
else:
gen_fn()
output.extend(self.rot_n(num_on_stack + 1))
self.clear_tos()
def create_load_python_module(self, mod) -> Instruction:
"""
Generate a LOAD_GLOBAL instruction to fetch a given python module.
"""
output = self.tx.output
global_scope = output.global_scope
name = re.sub(r"^.*[.]", "", mod.__name__)
if global_scope.get(name, None) is mod:
return self.create_load_global(name, add=True)
prefix = f"___module_{name}"
global_name = self.tx.output.install_global_by_id(prefix, mod)
return self.create_load_global(global_name, add=True)
def make_call_generated_code(self, fn_name: str) -> None:
"""Call the generated code function stored in fn_name"""
self.extend_output(self.load_function_name(fn_name, True))
graphargs = self.tx.output.graphargs
for arg in graphargs:
if arg.pass_arg_as_tensor:
self.add_push_null(
lambda: self.extend_output(
[
self.create_load_python_module(torch),
self.create_load_attr("as_tensor"),
]
)
)
self.call_reconstruct(arg)
self.extend_output(create_call_function(1, False))
else:
self.call_reconstruct(arg)
self.extend_output(create_call_function(len(graphargs), False))
def load_import_from(self, module_name, object_name) -> None:
self(AttrSource(self.tx.import_source(module_name), object_name))
def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]:
if sys.version_info >= (3, 13):
output = create_call_function(nargs, push_null)
assert output[-1].opname == "CALL"
output.insert(-1, self.create_load_const(kw_names))
output[-1] = create_instruction("CALL_KW", arg=nargs)
return output
elif sys.version_info >= (3, 11):
output = create_call_function(nargs, push_null)
if sys.version_info >= (3, 12):
idx = -1
expected_inst = "CALL"
else:
idx = -2
expected_inst = "PRECALL"
assert output[idx].opname == expected_inst
kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
output.insert(idx, kw_names_inst)
return output
return [
self.create_load_const(kw_names),
create_instruction("CALL_FUNCTION_KW", arg=nargs),
]
def create_delete(self, value) -> Instruction:
return create_instruction("DELETE_FAST", argval=value)

View File

@ -0,0 +1,533 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.external_utils import (
call_backward,
call_hook,
FakeCompiledAutogradEngine,
)
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
decompose,
disable_autocast_cache,
disable_proxy_modes_tracing,
fetch_object_proxy,
ProxyTorchDispatchMode,
PythonKeyTracer,
track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.traceback import preserve_node_meta, set_stack_trace
from torch.utils._traceback import CapturedTraceback
if TYPE_CHECKING:
from torch.fx.proxy import Proxy
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
def snapshot_verbose_logging_enabled():
return torch._logging._internal.log_state.is_artifact_enabled(
"compiled_autograd_verbose"
)
def cpp_verbose_log_fn(msg: str) -> None:
verbose_log.debug(msg)
def snapshot_cudagraph_enabled():
return torch._inductor.config.triton.cudagraphs
def maybe_clone(x):
if x is not None:
return clone_preserve_strides(x)
return x
class AutogradCompilerInstance:
def __init__(self, compiler_fn) -> None:
self.compiler_fn = compiler_fn
self.stack = contextlib.ExitStack()
self.close = self.stack.close
self.shape_env = ShapeEnv()
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=True,
allow_non_fake_inputs=True,
shape_env=self.shape_env,
)
self.fx_tracer = PythonKeyTracer()
self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
self.hooks_proxy: Optional[Proxy] = None
self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
def wrap_fake(self, x, source):
assert isinstance(x, torch.Tensor)
return self.fake_tensor_mode.from_tensor(x, source=source)
@staticmethod
def source(name, idx) -> GetItemSource:
return GetItemSource(LocalSource(name), idx)
def begin_capture(
self,
inputs: List[torch.Tensor],
sizes: List[int],
scalars: List[Union[int, float]],
):
counters["compiled_autograd"]["captures"] += 1
self.aot_graph_cls_name: Optional[str] = None
self.aot_graph_infos: Dict[int, Dict[str, Any]] = {}
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
self.fx_tracer.create_proxy("placeholder", name, (), {})
for name in self.graph_placeholders
)
# tensor inputs to fake tensors
inputs = [
self.wrap_fake(x, self.source("inputs", idx))
for idx, x in enumerate(inputs)
]
self.bind_tensors_to_proxies(inputs, args_proxy)
# size inputs to symints
sizes = [
self.shape_env.create_unspecified_symint_and_symbol(
val,
self.source("sizes", idx),
DimDynamic.DYNAMIC,
)
for idx, val in enumerate(sizes)
]
self.bind_tensors_to_proxies(sizes, sizes_proxy)
for idx, val in enumerate(scalars):
source = self.source("scalars", idx)
if isinstance(val, int):
scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol(
val,
source,
DimDynamic.DYNAMIC,
)
elif isinstance(val, float):
scalars[idx] = self.shape_env.create_symfloatnode(
self.shape_env.create_unspecified_symbol(
val,
source=source,
dynamic_dim=DimDynamic.DYNAMIC,
),
hint=val,
source=source,
)
else:
raise AssertionError("Unexpected scalar type: ", type(val))
self.bind_tensors_to_proxies(scalars, scalars_proxy)
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
self.stack.enter_context(self.fake_tensor_mode)
self.stack.enter_context(self.proxy_mode)
self.stack.enter_context(disable_autocast_cache())
self.stack.enter_context(preserve_node_meta())
return inputs, sizes, scalars
def proxy_call_backward(
self,
inputs,
output_metadatas,
saved_tensors,
backward_idx: int,
):
assert self.hooks_proxy is not None
backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index]
proxies = self.fx_tracer.create_proxy(
kind="call_function",
target=call_backward,
args=(
backward_c_function,
self.to_proxy(saved_tensors),
*self.to_proxy(inputs),
),
kwargs={},
)
with disable_proxy_modes_tracing():
# create fake Tensors
grad_ins: List[Optional[torch.Tensor]] = []
for output_metadata in output_metadatas:
if output_metadata is None:
grad_ins.append(None)
continue
layout, device, dtype, size = output_metadata
grad_ins.append(
torch.empty(size=size, dtype=dtype, layout=layout, device=device)
)
self.bind_tensors_to_proxies(grad_ins, proxies)
return tuple(grad_ins)
def proxy_call_hook(self, hook, *args, **kwargs):
return self.fx_tracer.create_proxy(
"call_function",
call_hook,
(
hook,
*[self.to_proxy(x) for x in args],
),
kwargs,
)
def tensor_pre_hook(self, inputs, hook_id, i: int):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
inputs[i],
hook_type="tensor_pre_hook",
)
with disable_proxy_modes_tracing():
inputs[i] = maybe_clone(inputs[i])
self.bind_tensors_to_proxies([inputs[i]], [proxy])
return inputs
def pre_hook(self, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
inputs,
hook_type="pre_hook",
)
with disable_proxy_modes_tracing():
inputs = [maybe_clone(x) for x in inputs]
self.bind_tensors_to_proxies(inputs, proxies)
return inputs
def post_hook(self, outputs, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
outputs,
inputs,
hook_type="post_hook",
)
with disable_proxy_modes_tracing():
outputs = [maybe_clone(x) for x in outputs]
self.bind_tensors_to_proxies(outputs, proxies)
return outputs
def post_acc_grad_hook(self, input, hook_id):
assert isinstance(input, torch.Tensor)
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
input,
hook_type="post_acc_grad_hook",
)
with disable_proxy_modes_tracing():
input = [maybe_clone(input)]
self.bind_tensors_to_proxies(input, [proxy])
return input
# Note: [Compiled autograd and cudagraphs]
# Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
# When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph
# with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the
# scalar tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too.
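# Illustrative sketch of the rewrite described above (tensor values hypothetical): a cpu
# 0-dim scalar captured by a node such as DivBackward0 is retargeted so cudagraphs apply.
#
#     before: node.meta["val"] = torch.tensor(2.0)                 # cpu, 0-dim
#     after:  node.meta["val"] = torch.tensor(2.0, device="cuda")  # moved by the pass below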
def move_graph_nodes_to_cuda(self, graph) -> List[int]:
to_move: Dict[int, torch.fx.Node] = {}
has_cuda_inputs = False
nodes = list(graph.nodes)
assert nodes[0].target == "inputs"
inputs = nodes[0]
inputs_users = list(inputs.users.keys())
# input access nodes should immediately follow placeholder nodes
first_getitem_idx = len(self.graph_placeholders)
assert nodes[first_getitem_idx] == inputs_users[0]
last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
assert nodes[last_getitem_idx] == inputs_users[-1]
for i, node in enumerate(inputs_users):
if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
has_cuda_inputs = True
continue
is_cpu = node.meta["val"].device.type == "cpu"
is_scalar = len(node.meta["val"].size()) == 0
if is_cpu and is_scalar:
node_users = list(node.users.keys())
if all(
isinstance(user.target, torch._ops.OpOverload)
and user.target.namespace in ("prims", "aten")
for user in node_users
):
# all users are prims/aten, can move safely
to_move[i] = node
# only move cpu scalars to cuda if there were cuda activations in this graph,
# this is to handle the case where cudagraphs is enabled on a cpu-only graph
if has_cuda_inputs:
for node in to_move.values():
node.meta["val"] = node.meta["val"].cuda()
# return runtime indices we need to move to cuda
return list(to_move.keys())
return []
def end_capture(self, outputs):
self.fx_tracer.create_proxy(
"call_function",
FakeCompiledAutogradEngine._exec_final_callbacks_stub,
(),
{},
)
self.stack.close()
self.fx_tracer.create_node(
"output",
"output",
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
{},
)
self.rename_aot_dispatcher_nodes()
self.reorder_accumulate_grad_nodes()
runtime_inputs_to_move: List[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
)
set_locals_to_steal(graph, ["inputs"])
lazy_graph_code = lazy_format_graph_code(
"Compiled autograd graph",
graph,
include_device=True,
include_stride=True,
colored=True,
)
compiled_autograd_log.info("%s", lazy_graph_code)
verbose_log.debug("%s", lazy_graph_code)
trace_structured(
"compiled_autograd_graph",
payload_fn=lambda: graph.print_readable(print_output=False),
)
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
global in_compiled_autograd_region
try:
in_compiled_autograd_region = True
for i in runtime_inputs_to_move:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
return compiled_fn(inputs, sizes, scalars, hooks)
finally:
in_compiled_autograd_region = False
return runtime_wrapper, self.compiler_fn(graph)
def rename_aot_dispatcher_nodes(self):
"""
Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id
e.g. AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y`
"""
if self.aot_graph_cls_name is None:
return
def is_similar(a: torch.fx.node.Node, b: torch.fx.node.Node):
target_match = a.target == b.target
if not target_match:
target_match = (
hasattr(a.target, "__name__")
and hasattr(b.target, "__name__")
and a.target.__name__ == b.target.__name__
)
return (
target_match
and a.op == b.op
and a.type == b.type
and len(a.all_input_nodes) == len(b.all_input_nodes)
)
for nodecall_index, info in self.aot_graph_infos.items():
ca_node_start_idx = info["ca_node_start_idx"]
aot_id = info["aot_id"]
aot_graph = info["aot_gm"].graph
# 1. Find the first op from user code in the AOT graph
aot_it = iter(aot_graph.nodes)
aot_node = next(aot_it)
assert aot_node is not None
try:
while aot_node.op != "call_function":
aot_node = next(aot_it)
except StopIteration:
continue
try:
# 2. Find the first op in the compiled autograd graph segment
ca_it = iter(self.fx_tracer.graph.nodes)
for _ in range(ca_node_start_idx):
next(ca_it)
ca_node = next(ca_it)
# Graphs should all end with output node
while ca_node.op != "output" and not is_similar(ca_node, aot_node):
# The compiled autograd graph may contain lazily inserted ops
# We skip those when aligning nodes
ca_node = next(ca_it)
# 3. Keep aligned and rename nodes
while aot_node.op != "output" and ca_node.op != "output":
if not ca_node.users:
# TODO: DCE for compiled autograd graph
ca_node = next(ca_it)
continue
if not is_similar(aot_node, ca_node):
# There should be no lazily inserted ops in the middle of a match
# So any deviation is an error
raise StopIteration
ca_node.name = f"aot{aot_id}_{aot_node.name}"
for i, inp in enumerate(aot_node.all_input_nodes):
ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}"
aot_node = next(aot_it)
ca_node = next(ca_it)
except StopIteration:
verbose_log.debug(
"Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes",
self.aot_graph_cls_name,
aot_id,
nodecall_index,
aot_id,
)
def reorder_accumulate_grad_nodes(self):
"""
Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
the graph. This differs from eager mode, which schedules them as soon as possible. This
pass attempts to reorder the graph to mimic eager behavior.
"""
for node in self.fx_tracer.graph.find_nodes(
op="call_function", target=torch.ops.inductor.accumulate_grad_.default
):
arg = max(node.args) # last arg
if arg is not node.prev and arg.op != "placeholder":
arg.append(node)
def to_proxy(self, t):
if t is None:
return None
if isinstance(t, list):
return [self.to_proxy(x) for x in t]
if isinstance(t, tuple):
return tuple(self.to_proxy(x) for x in t)
# can it be torch.SymInt as the code used to imply?
assert isinstance(t, torch.Tensor)
proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
return proxy_tensor.proxy
def bind_tensors_to_proxies(self, tensors, proxies):
if isinstance(proxies, torch.fx.Proxy):
proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index]
assert len(tensors) == len(proxies)
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
def bind_backward_state(self, index: int):
assert self.hooks_proxy is not None
proxy = self.hooks_proxy[index] # type: ignore[index]
bw_state = BackwardState()
track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
return bw_state
def set_node_origin(
self,
node_name: str,
nodecall_index: int,
pyobj: Optional[torch.autograd.Function],
):
maybe_aot_id = ""
if pyobj is not None:
forward_cls = pyobj._forward_cls # type: ignore[attr-defined]
if hasattr(forward_cls, "_aot_id"):
# backward was created by AOT Dispatcher
self.aot_graph_cls_name = node_name
maybe_aot_id = forward_cls._aot_id
self.aot_graph_infos[nodecall_index] = {
"ca_node_start_idx": len(self.fx_tracer.graph.nodes),
"aot_id": maybe_aot_id,
"aot_gm": forward_cls._lazy_backward_info.bw_module,
}
new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
raw_stack_trace = CapturedTraceback.extract().format()[-1]
new_stack_trace = raw_stack_trace.replace(
"raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
)
set_stack_trace(new_stack_trace)
# state of the autograd engine dispatch, kept in sync by enable/disable context managers
compiled_autograd_enabled = False
# global flag to check if we are processing graphs produced from a compiled autograd graph
in_compiled_autograd_region = False
@contextlib.contextmanager
def enable(compiler_fn):
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
functools.partial(AutogradCompilerInstance, compiler_fn)
)
if snapshot_verbose_logging_enabled():
torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)
global compiled_autograd_enabled
compiled_autograd_enabled = True
try:
with torch.autograd.set_multithreading_enabled(False):
yield
finally:
if not prior:
compiled_autograd_enabled = False
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
@contextlib.contextmanager
def disable():
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
global compiled_autograd_enabled
compiled_autograd_enabled = False
try:
yield
finally:
if prior:
compiled_autograd_enabled = True
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
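# Illustrative usage sketch (loss is hypothetical; any Dynamo backend callable works as
# the inner compiler handed to enable()):
def _compiled_autograd_example(loss):
    def compiler_fn(gm):
        return torch.compile(gm, backend="eager")

    with enable(compiler_fn):
        loss.backward()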
# return to starting state of a new process
def reset() -> None:
global compiled_autograd_enabled
compiled_autograd_enabled = False
assert not in_compiled_autograd_region
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)

View File

@ -0,0 +1,401 @@
# mypy: allow-untyped-defs
# This file establishes the public comptime interface to Dynamo.
# This allows Dynamo users to execute arbitrary Python code while
# Dynamo is symbolically evaluating their original programs.
#
# The goal of the public API is to give users rope, without actually
# leaking private implementation details of Dynamo.
import builtins
import dis
import time
import traceback
from typing import Optional, Union
import torch
from torch.fx.experimental.symbolic_shapes import free_symbols
from .exc import unimplemented
from .variables import NewCellVariable
from .variables.constant import ConstantVariable
from .variables.misc import ClosureVariable
from .variables.tensor import SymNodeVariable
class ComptimeVar:
"""
A ComptimeVar represents a Python value, at some particular point
in time, in the Python code we are symbolically evaluating with
torchdynamo. This must be distinguished from a runtime value, as
at compile-time there are some properties of the variable we
do not know (for example, if the ComptimeVar represents a Tensor,
we only know metadata about the tensor; we do NOT know what the
actual data in the Tensor is.)
"""
def __init__(self, v) -> None:
self.__variable = v
def as_proxy(self):
"""
Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
this variable in the FX graph we are assembling to pass
to the user compiler.
This method only works for variables we actually track in
the FX graph, aka Tensors (and ints, if you are compiling
with dynamic shapes). In particular, if you have a list
or tuple of tensors, you will get a list/tuple of proxies
(not a single proxy representing the entire list/tuple).
"""
return self.__variable.as_proxy()
def is_proxy(self):
"""
Returns True if as_proxy() would succeed.
"""
return self.__variable.is_proxy()
def as_fake(self):
"""
Returns a "fake" value (either a FakeTensor or a SymInt)
representing the variable in question. This only works
for variables that denote Tensor or int. You can use
this to query metadata; e.g., v.as_fake().size(0) will
tell you the compile-time known size of the tensor.
WARNING: Do NOT mutate the returned tensor.
"""
return self.__variable.as_proxy().node.meta["example_value"]
def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]:
"""
Returns the size of the tensor (if dim is None) or the size
at the dimension dim. The returned size may be a SymInt.
"""
return self.as_fake().size(dim)
def python_type(self):
"""
Returns what type(v) would have returned for the variable
at compile time.
"""
return self.__variable.python_type()
def as_python_constant(self):
"""
Returns the Python value this variable would have, but only if it is
completely known at compile-time (e.g., it is constant).
WARNING: Do NOT mutate the returned constant. The returned constant
may or may not correspond to the actual value this variable may take
on at runtime; for example, if the variable in question is a constant
list, we may return a copy of that list.
"""
return self.__variable.as_python_constant()
def is_python_constant(self):
"""
Returns True if as_python_constant would succeed.
"""
return self.__variable.is_python_constant()
def is_dynamic(self):
if isinstance(self.__variable, SymNodeVariable):
fs = free_symbols(self.__variable.sym_num)
return bool(fs)
return False
def force_static(self):
"""
Forces that a value is static, inducing a guard on its specific value
"""
if isinstance(self.__variable, SymNodeVariable):
self.__variable.evaluate_expr()
elif isinstance(self.__variable, ConstantVariable):
# TODO: Maybe complain if this isn't an int/bool/float variable
pass
else:
raise AssertionError(
f"cannot force {self.__variable} ({type(self.__variable)}) static"
)
def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
"""
Returns the internal data structure VariableTracker that Dynamo uses
to represent variables at compile time. There are no BC guarantees on
this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on
it.
"""
return self.__variable
def __repr__(self) -> str:
return self.__variable.debug_repr()
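# Illustrative usage sketch (f and x are hypothetical): a ComptimeVar is what a comptime
# callback receives when it asks the context for a local while Dynamo is tracing.
#
#     from torch._dynamo.comptime import comptime
#
#     @torch.compile
#     def f(x):
#         comptime(lambda ctx: ctx.print(ctx.get_local("x")))
#         return x * 2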
# TODO: API for adding a custom guard
class ComptimeContext:
"""
This context class provides access to a public API for Dynamo's internals.
If there is something here you would find useful that is missing, please
file a feature request at https://github.com/pytorch/pytorch/
"""
def __init__(self, tx) -> None:
self.__tx = tx
def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
"""
Retrieve the compile-time known information about a local.
"""
tx = self.__get_tx(stacklevel)
# This is analogous to LOAD_DEREF
if hasattr(tx, "closure_cells") and name in tx.closure_cells:
cell = tx.closure_cells[name]
if isinstance(cell, ClosureVariable):
return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name])
else:
return ComptimeVar(tx.output.side_effects.load_cell(cell))
else:
r = tx.symbolic_locals[name]
if isinstance(r, NewCellVariable):
return ComptimeVar(tx.output.side_effects.load_cell(r))
else:
return ComptimeVar(r)
def graph_break(self, msg="ComptimeContext.graph_break"):
"""
Manually trigger a graph break
"""
unimplemented(msg)
def graph(self):
"""
Retrieve the partially constructed FX graph that would be
passed to the user compiler after compilation.
"""
return self.__tx.output.graph
def assert_static(self, val):
"""
Asserts that the int is static (and not dynamic, per dynamic shapes)
"""
assert (
not val.is_dynamic()
), "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)"
def print_graph(self, *, verbose=True, file=None):
"""
Print the partially constructed FX graph that would be passed
to the user compiler after compilation.
"""
print(
self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
)
def parent(self):
return ComptimeContext(self.__tx.parent)
def __get_tx(self, stacklevel):
tx = self.__tx
for _ in range(stacklevel):
tx = tx.parent
return tx
def print(self, val, *, file=None):
print(repr(val), file=file)
def print_disas(self, *, file=None, stacklevel=0):
"""
Print the current series of opcodes being executed (not including
parent frames), including where you are in the particular opcode
stream.
"""
tx = self.__get_tx(stacklevel)
print(
dis.Bytecode(
tx.f_code,
current_offset=tx.instructions[tx.instruction_pointer].offset,
).dis(),
file=file,
)
def print_value_stack(self, *, file=None, stacklevel=0):
"""
Print the current Python value stack. Note that this is NOT the same
as the traceback; use print_bt() to print that. Note that at
stacklevel=0, this will typically be empty, as comptime cannot
currently be used in an expression context where there would be
intermediates on the stack. If you would find this useful, please
file a bug at https://github.com/pytorch/pytorch/
NB: Stack grows downwards in our print
"""
tx = self.__get_tx(stacklevel)
for s in tx.stack:
print(f"- {s.debug_repr()}", file=file)
def print_locals(self, *, file=None, stacklevel=0):
"""
Print all of the locals available in the current context.
By default this view is very limited; you can get more information
about any individual local using get_local().
"""
tx = self.__get_tx(stacklevel)
for k, v in tx.symbolic_locals.items():
print(f"{k} = {v.debug_repr()}", file=file)
def print_bt(self, *, file=None, stacklevel=0):
"""
Print the user code backtrace, starting at the beginning of the
frame Dynamo started evaluating. Note that this MAY NOT go all
the way to the torch.compile invocation, as we may have done
a graph break and are compiling an intermediate frame as the
starting point. If you think the other behavior would be better,
file a bug at https://github.com/pytorch/pytorch/
"""
stack = []
tx = self.__get_tx(stacklevel)
while tx is not None:
stack.append(tx.frame_summary())
tx = getattr(tx, "parent", None)
print(
"".join(traceback.StackSummary.from_list(reversed(stack)).format()),
file=file,
)
def print_guards(self, *, file=None):
"""
Print the currently installed guards for the Dynamo context.
This does NOT include guards associated with variables that
may or may not be installed in the future if those variables
are used.
"""
# TODO: improve print format, current guard format is extremely
# verbose
print(
"\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)),
file=file,
)
def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
"""
Returns the internal data structure InstructionTranslator that Dynamo
uses to track state of symbolic evaluation. There are no BC
guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if
you rely on it.
"""
return self.__tx
def sleep(self, sec):
time.sleep(sec)
class _Comptime:
@staticmethod
def __call__(fn, fallback_fn=lambda: None):
"""fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise"""
fallback_fn()
# Convenience wrappers that are more compact to use
@staticmethod
def graph_break():
comptime(lambda ctx: ctx.graph_break())
@staticmethod
def print(e):
comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e))
@staticmethod
def print_graph():
comptime(lambda ctx: ctx.print_graph())
@staticmethod
def print_disas(*, stacklevel=0):
comptime(
lambda ctx: ctx.print_disas(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
)
)
@staticmethod
def print_value_stack(*, stacklevel=0):
comptime(
lambda ctx: ctx.print_value_stack(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
)
)
# This is a more useful variant of print_value_stack that can be used
# in an expression context; e.g., x + print_value_stack_and_return(y + z),
# you will see x on the stack prior to the addition operation
@staticmethod
def print_value_stack_and_return(e, *, stacklevel=0):
comptime(
lambda ctx: ctx.print_value_stack(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
)
)
return e
@staticmethod
def print_locals(*, stacklevel=0):
comptime(
lambda ctx: ctx.print_locals(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
)
)
@staticmethod
def print_bt(*, stacklevel=0):
comptime(
lambda ctx: ctx.print_bt(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
)
)
@staticmethod
def print_guards():
comptime(lambda ctx: ctx.print_guards())
@staticmethod
def assert_static(val):
comptime(lambda ctx: ctx.assert_static(ctx.get_local("val")))
@staticmethod
def force_static(val):
comptime(lambda ctx: ctx.get_local("val").force_static())
@staticmethod
def breakpoint():
"""
Like pdb breakpoint(), but drop into pdb whenever this line
of code is compiled by dynamo. Use it by putting
this in your model code::
from torch._dynamo.comptime import comptime
comptime.breakpoint()
And then, inside pdb, you can access 'ctx' to query things
about the compilation context::
(Pdb) !ctx.print_bt()
(Pdb) !ctx.print_locals()
(Pdb) p ctx.get_local("attention").as_fake()
"""
def inner(inner_ctx):
ctx = inner_ctx.parent()
builtins.breakpoint()
comptime(inner)
@staticmethod
def sleep(sec):
comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant()))
comptime = _Comptime()
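# Illustrative usage sketch (not part of the original file; the function name and
# shapes below are assumptions). Inside a torch.compile'd region, comptime runs
# its callback at trace time; in eager mode it only runs the fallback:
#
#     from torch._dynamo.comptime import comptime
#
#     @torch.compile
#     def f(x):
#         y = x + 1
#         comptime.print_graph()   # prints the partial FX graph during compilation
#         comptime.print(y)        # prints Dynamo's compile-time view of y
#         return torch.sin(y)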

View File

@@ -0,0 +1,490 @@
# mypy: allow-untyped-defs
import getpass
import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import torch
def is_fbcode():
return not hasattr(torch.version, "git_version")
# to configure logging for dynamo, aot, and inductor
# use the following API in the torch._logging module
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
# see this design doc for more detailed info
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# the name of a file to write the logs to
# [@compile_ignored: debug]
log_file_name: Optional[str] = None
# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False
# need this many ops to create an FX graph
minimum_call_count = 1
# turn on/off DCE pass
dead_code_elimination = True
# disable (for a function) when cache reaches this size
# controls the maximum number of cache entries with a guard on same ID_MATCH'd
# object. It also controls the maximum size of cache entries if they don't have
# any ID_MATCH'd guards.
# [@compile_ignored: runtime_behaviour]
cache_size_limit = 8
# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
accumulated_cache_size_limit = 256
# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
skip_code_recursive_on_cache_limit_hit = True
# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs. Note that assume_static_by_default will also cause ints to get
# specialized, so this is mostly useful for export, where we want inputs
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
specialize_int = False
# Whether or not to specialize on float inputs. Dynamo will always promote
# float inputs into Tensor inputs, but at the moment, backends inconsistently
# support codegen on float (this is to be fixed).
specialize_float = True
# legacy config, does nothing now!
dynamic_shapes = True
use_lazy_graph_module = (
os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1"
)
# This is a temporary flag that changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default = True
# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes = True
# This flag changes how the shapes of parameters are treated.
# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic
# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
# while the shapes of torch.Tensor are assumed to be dynamic.
force_parameter_static_shapes = True
# This flag ensures that the shapes of a nn module are always assumed to be static
# If the flag is set to True, then the shapes of a nn.module are assumed to be static
# If the flag is set to False, then the shapes of a nn.module can be dynamic
force_nn_module_property_static_shapes = True
# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ended up getting specialized. This knob changes the behavior so
# that we don't error at all. This is helpful for our CI where I'm using a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic = False
# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules = True
# Uses CPython internal dictionary tags to detect mutation. There is some
# overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag.
# guard_nn_modules unspecializes the nn module instance and adds guard for each
# relevant member of the nn modules. On the other hand,
# guard_nn_modules_using_dict_tags specializes on each nn module instance but
# uses low overhead dict version matching to detect mutations, obviating the
# need to guard on members of the nn modules. With
# guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required
# but kept around for debugging and discussing unspecializing nn module
# variables.
# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules)
# once we have reached stability for the guard_nn_modules_using_dict_tags.
guard_nn_modules_using_dict_tags = True
# This feature doesn't really work. We offer this flag for experimental
# purposes / if you want to help us build out support.
#
# torchdynamo has limited support for tensor subclasses that implement
# __torch_function__ see [Note: __torch_function__] in torch_function.py.
# Our current support is limited to tensor subclasses
# that DO NOT store metadata on the tensor (in general, dynamo does not
# support Python code that stores extra attributes on tensors at present).
# If your tensor subclass purely changes function call behavior via
# __torch_function__, you can allow torchdynamo to trace into it by
# adding it to traceable_tensor_subclasses. We don't do any safety checks,
# so it is up to you to ensure that your subclass is well behaved. See also
# https://github.com/pytorch/torchdynamo/issues/1948
#
# We do NOT currently support __torch_dispatch__. The implementation is
# currently buggy, the main show stopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses: Set[Type[Any]] = set()
# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# Record and write an execution record of the current frame to a file
# if an exception is encountered
# [@compile_ignored: debug]
replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1"
# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert = True
# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)
# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)
# legacy config, does nothing now!
skipfiles_inline_module_allowlist: Dict[Any, Any] = {}
# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
"torch.distributions",
"torch.testing",
"torch._refs",
"torch._prims",
"torch._decomp",
}
# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
# [@compile_ignored: debug]
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
# [@compile_ignored: debug]
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
# 'dynamo'). This setting forces us to only test the forward graph and
# not the backward graph. This can be helpful if you're trying to debug
# an inference only problem, but the minifier seems to be choking on the
# backwards step
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
# [@compile_ignored: debug]
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
# [@compile_ignored: debug]
repro_tolerance = 1e-3
# Whether to ignore non-floating point values when checking accuracy.
# Checking accuracy of non-floating point values such as boolean tensors
# can lead to false positives.
# [@compile_ignored: debug]
repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1"
# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1"
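# Illustrative sketch (example only, not from the original file): with
# capture_scalar_outputs left at False, code like
#
#     @torch.compile
#     def f(x):
#         n = x.sum().item()   # graph break here
#         return x * n
#
# graph-breaks at .item(); setting
#     torch._dynamo.config.capture_scalar_outputs = True
# lets Dynamo capture the scalar as an unbacked symbolic value instead.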
# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique). When this flag is set to False, we introduce a graph
# break instead of capturing. This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
capture_dynamic_output_shape_ops = (
os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
)
# hybrid backed unbacked symints
prefer_deferred_runtime_asserts_over_guards = False
# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
# range constraints + dims + derived dims language, we raise constraint violation
# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
# and allows complex guards as runtime assertions in the graph.
allow_complex_guards_as_runtime_asserts = False
# By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int. When true, this knob
# forces dynamo to treat _length_per_key and _offset_per_key on
# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that
# they (1) generalize immediately and (2) unsoundly never compare equal to
# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently
# compile this code; however, this can be useful for export.
force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produces code with identical guards.
enforce_cond_guards_match = True
# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 4 modes.
# 1. "ddp_optimizer" (or True): with "ddp_ptimizer", Dynamo will automatically
# split model graph into pieces to match DDP bucket sizes to allow DDP
# comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
# and use the Python reducer to allow compiled_autograd to trace the
# communication and allow comm/compute overlap without graph-breaks.
# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
# similar to "python_reducer". One should only use this optimization mode
# when compiled_autograd is used but the DDP module is not compiled.
# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
# will Python reducer be used. With this mode, there will be no graph-breaks
# and the original DDP C++ reducer will be used. There will be no comm/compute
# overlap. This mode CANNOT be used with compiled_autograd.
# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be
# specified with a boolean value. True is using ddp_optimizer and False is
# no optimization.
optimize_ddp: Union[bool, str] = True
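# Illustrative sketch (example only): selecting one of the modes described above,
# e.g. the experimental Python reducer when tracing backward with compiled_autograd:
#
#     torch._dynamo.config.optimize_ddp = "python_reducer"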
# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph.
# In some cases those asserts could be performance costly
# E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync.
# Setting this to True keeps them as hints for the symbolic shapes engine,
# but they are not emitted in the graph.
do_not_emit_runtime_asserts: bool = (
os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1"
)
_ddp_optimization_mode = [
"ddp_optimizer",
"python_reducer", # experimental mode
"python_reducer_without_compiled_forward", # experimental mode
"no_optimization",
]
def _get_optimize_ddp_mode():
m = sys.modules[__name__]
if isinstance(m.optimize_ddp, bool):
if m.optimize_ddp:
mode = "ddp_optimizer"
else:
mode = "no_optimization"
elif isinstance(m.optimize_ddp, str):
mode = m.optimize_ddp
else:
raise ValueError(f"Invalid type, {type(optimize_ddp)=}")
assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
return mode
# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True
# No longer used
optimize_ddp_lazy_compile = False
# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards = True
# Whether to apply torch._dynamo.disable() to FSDP2 hooks.
# Defaults to True. If Traceable FSDP2 is used, set this to False.
skip_fsdp_hooks = True
# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards = True
# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False
# If true, error if you torch.jit.trace over a dynamo-optimized function.
# If false, silently suppress dynamo
error_on_nested_jit_trace = True
# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace = True
# Disables graph breaking on rnn. YMMV with backends.
allow_rnn = False
# If true, enables feature that captures PyTorch sparsity in the
# exported FX graph. This flag should become the default eventually
# and be removed, but currently provides a way to fall back to old
# graph breaking behavior.
capture_sparse_compute = False if is_fbcode() else True
# If true, error if we try to compile a function that has
# been seen before.
# [@compile_ignored: runtime_behaviour]
error_on_recompile = False
# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything)
report_guard_failures = True
# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))
# Trace through NumPy or graphbreak
trace_numpy = True
# Default NumPy dtypes when tracing with torch.compile
# We default to 64bits. For efficiency, one may want to change these to float32
numpy_default_float = "float64"
numpy_default_complex = "complex128"
numpy_default_int = "int64"
# use numpy's PRNG if True, pytorch otherwise
use_numpy_random_stream = False
# Use C++ guard manager
enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1"
# Inline inbuilt nn modules
inline_inbuilt_nn_modules = not is_fbcode()
# When set, total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False
def default_debug_dir_root():
# [@compile_ignored: debug]
DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
if DEBUG_DIR_VAR_NAME in os.environ:
return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
elif is_fbcode():
return os.path.join(
tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
)
else:
return os.path.join(os.getcwd(), "torch_compile_debug")
# [@compile_ignored: debug]
debug_dir_root = default_debug_dir_root()
# [@compile_ignored: debug]
_save_config_ignore = {
"repro_after",
"repro_level",
# workaround: "cannot pickle PyCapsule"
"constant_functions",
# workaround: "cannot pickle module"
"skipfiles_inline_module_allowlist",
}
# for backend="cudagraphs", mutations on input be sent to the cudagraph backend
# or replayed in aot_autograd epilogue. default is False because mutation on inputs
# can prevent cudagraphing.
cudagraph_backend_keep_input_mutation = False
# enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_backend_support_input_mutation = False
# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
# correctness of custom ops.
only_allow_pt2_compliant_ops = False
capture_autograd_function = True
# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = True
# Whether to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True
# A set of logging functions which will be reordered to the end of graph breaks,
# allowing dynamo to construct larger graphs. Note that there are some
# limitations to this, such as how it does not correctly print objects that were
# mutated after the print statement.
reorderable_logging_functions: Set[Callable[[Any], None]] = set()
# simulates what would happen if we didn't have support for BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
_autograd_backward_strict_mode_banned_ops = [
"stride",
"requires_grad",
"storage_offset",
"layout",
"data",
]
_autograd_backward_strict_mode_banned_ops.extend(
[name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)
# Enables caching of dispatches to fake tensors.
fake_tensor_cache_enabled = (
os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
)
# Enables cross checking between the fake tensor cache and dispatch.
fake_tensor_cache_crosscheck_enabled = (
os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
)
# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
# Note: AOT Autograd will still trace joint graphs.
compiled_autograd = False
# Enables use of collectives *during* compilation to synchronize behavior
# across ranks. Today, this is used solely to modify automatic_dynamic_shapes
# behavior, so that we infer whether an input is dynamic by inspecting
# whether or not its size varies across ranks. Because
# this synchronization uses collectives, all ranks must run compilation at
# the same time; ranks must not diverge with graph breaks. This can be most
# reliably achieved by ensuring PT2 is only run on SPMD programs. If this
# invariant is violated, you will likely deadlock NCCL and encounter a
# NCCL timeout.
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
def _make_closure_patcher(**changes):
...
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
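# Illustrative sketch (example only): install_config_module makes the flags above
# patchable, so they can be overridden temporarily via the config module, e.g.:
#
#     with torch._dynamo.config.patch(suppress_errors=True, cache_size_limit=16):
#         ...  # code compiled in this block sees the patched values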

File diff suppressed because it is too large

View File

@@ -0,0 +1,60 @@
# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager
import torch
doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()
class TracableCreateParameter(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, placeholder):
assert not tensor.requires_grad
return placeholder.set_(tensor)
@staticmethod
def backward(ctx, grad):
return None, grad # grad flows to placeholder
def tracable_create_parameter(tensor, placeholder):
with torch.set_grad_enabled(placeholder.requires_grad):
out = TracableCreateParameter.apply(tensor, placeholder)
return out
def new_parameter_placeholder(size, dtype, device, requires_grad):
"""Create a placeholder to be passed to the above functions"""
result = torch.nn.Parameter(
torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
)
# TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
# Allocating a zero tensor would cause assert failures in autograd.
result.untyped_storage().resize_(0)
return result
_TLS = threading.local()
@contextmanager
def do_not_convert_to_tracable_parameter():
old_flag = getattr(_TLS, "convert_tracable_parameter", True)
_TLS.convert_tracable_parameter = False
try:
yield False
finally:
_TLS.convert_tracable_parameter = old_flag
def can_convert_to_tracable_parameter():
return getattr(_TLS, "convert_tracable_parameter", True)
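# Illustrative usage sketch (example only): code that must keep a tensor from
# being promoted into a tracable parameter can wrap the region:
#
#     with do_not_convert_to_tracable_parameter():
#         ...  # can_convert_to_tracable_parameter() returns False in here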

View File

@@ -0,0 +1,25 @@
# mypy: allow-untyped-defs
import contextlib
import threading
# Global variable to identify which SubgraphTracer we are in.
# It is sometimes difficult to find an InstructionTranslator to use.
_current_scope_id = threading.local()
def current_scope_id():
global _current_scope_id
if not hasattr(_current_scope_id, "value"):
_current_scope_id.value = 1
return _current_scope_id.value
@contextlib.contextmanager
def enter_new_scope():
global _current_scope_id
try:
_current_scope_id.value = current_scope_id() + 1
yield
finally:
_current_scope_id.value = current_scope_id() - 1

View File

@@ -0,0 +1,824 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"
import atexit
import copy
import cProfile
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import sys
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar
import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)
T = TypeVar("T")
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()
if use_buck:
import libfb.py.build_info
extra_deps = []
extra_imports = ""
if use_buck:
extra_deps = [
"//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
"//caffe2/torch/fb/sparsenn:sparsenn_operators",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
]
cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined]
extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])
BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
class BuckTargetWriter:
def __init__(self, filename):
self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
self.target = self.py_file.replace(".py", "")
# Get main_module path from fbcode
self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
self.path = self.path[self.path.find("fbcode.") :]
self.path = self.path[7:]
# Get cmd line path
tmp = self.subdir
tmp = tmp[tmp.find("fbcode/") :][7:]
self.cmd_line_path = f"//{tmp}:{self.target}"
def build(self):
extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
return textwrap.dedent(
f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
python_binary(
name="{self.target}",
srcs = ["{self.py_file}"],
compile = False,
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch",
"//triton:triton",
"{cur_target}",
],
cpp_deps = [
{extra_cpp_deps}
],
main_module = "{self.path}",
par_style = "xar",
)
"""
)
def write(self, print_msg=True):
target_file = os.path.join(self.subdir, "TARGETS")
with open(target_file, "w") as fd:
fd.write(self.build())
# log.warning("Wrote isolation TARGETS file at %s", target_file)
cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
if print_msg:
log.warning(
"Found an example that reproduces the error. Run this cmd to repro - %s",
" ".join(cmd_split),
)
return cmd_split
def minifier_dir():
path = os.path.join(get_debug_dir(), "minifier")
if path is None:
path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
return path
MAX_CONSTANT_NUMEL_INLINE = 4
class NNModuleToString:
safe_reprs = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
torch.nn.Softmax,
torch.nn.ReLU,
torch.nn.GELU,
torch.nn.Identity,
torch.nn.MaxPool2d,
torch.nn.Embedding,
torch.nn.Tanh,
torch.nn.ConvTranspose1d,
torch.nn.GLU,
torch.nn.LSTM,
torch.nn.Flatten,
torch.nn.AdaptiveAvgPool2d,
]
@staticmethod
def can_convert_to_string(gm):
cant_convert = set()
for _, module in gm.named_children():
if type(module) not in NNModuleToString.safe_reprs:
cant_convert.add(module)
if len(cant_convert) > 0:
log.warning("We have not tested reprs of some modules - %s", cant_convert)
# TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
return True
@staticmethod
def convert(gm):
from torch.nn.modules.module import _addindent
tab = " " * 4
model_str = textwrap.dedent(
"""
from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
"""
)
for module_name, module in gm.named_children():
module_str = f"{module.__repr__()}"
# module should be a core torch.nn.Module, so all parameters
# should be on the same device.
example_param = next(module.parameters(), None)
if example_param is not None and example_param.is_cuda:
module_str = f"{module_str}.cuda()"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in gm._buffers.items():
if buffer is None:
continue
# Serialize full data for small buffers
if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
from torch._tensor_str import PRINT_OPTS
assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
tensor_str = repr(buffer)
elif torch.is_floating_point(buffer):
tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
else:
tensor_str = (
f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
)
if buffer.is_cuda:
tensor_str = f"{tensor_str}.cuda()"
model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"
for param_name, param in gm._parameters.items():
if param is None:
continue
maybe_device = ""
if param.is_cuda:
maybe_device = ', device="cuda"'
tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"
# TODO - Keep this code for now. But, I don't think we will need this.
# attrs = dir(gm)
# for attr in attrs:
# if "_tensor_constant" in attr:
# val = getattr(gm, attr)
# model_str += f" {attr} = {val!r}\n"
model_str += f"{_addindent(gm.code, 4)}\n"
return model_str
@functools.lru_cache(None) # subprocess is expensive
def _cuda_system_info_comment():
if not torch.cuda.is_available():
return "# torch.cuda.is_available()==False, no GPU info collected\n"
model_str = "# CUDA Info: \n"
try:
cuda_version_out = subprocess.check_output(["nvcc", "--version"])
cuda_version_lines = cuda_version_out.decode().split("\n")
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
model_str += f"{comment}\n"
except (FileNotFoundError, subprocess.CalledProcessError):
model_str += "# nvcc not found\n"
gpu_names = Counter(
torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
)
model_str += "# GPU Hardware Info: \n"
for name, count in gpu_names.items():
model_str += f"# {name} : {count} \n"
model_str += "\n"
return model_str
def generate_config_string(*, stable_output=False):
import torch._functorch.config
import torch._inductor.config
if stable_output:
return "# config omitted due to stable_output=True"
experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""
def get_minifier_repro_path():
return os.path.join(minifier_dir(), "minifier_launcher.py")
def helper_for_dump_minify(contents):
minified_repro_path = get_minifier_repro_path()
log.warning("Writing minified repro to:\n%s", minified_repro_path)
if use_buck:
BuckTargetWriter(minified_repro_path).write()
try:
with open(minified_repro_path, "w") as fd:
fd.write(contents)
except OSError as e:
log.exception("")
raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
class AccuracyError(Exception):
pass
def clone_inputs_retaining_gradness(example_inputs):
"""
This clone_inputs is different from utils clone_input. In the case of the minifier,
all the tensors are leaf tensors while creating a new graph. So, we set the
requires_grad field w/o checking the leafness of the tensor.
"""
cloned_inputs = clone_inputs(example_inputs)
for idx in range(len(example_inputs)):
if isinstance(cloned_inputs[idx], torch.Tensor):
cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
return cloned_inputs
def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
"""
Runs a forward and possibly backward iteration for a given mod and args.
When disable_clone is True, we will use args as-is without cloning.
This is higher fidelity but we may destroy the args in the process.
"""
from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
gm = copy.deepcopy(gm)
if not disable_clone:
args = clone_inputs_retaining_gradness(args)
if hasattr(gm, "zero_grad"):
gm.zero_grad(True)
# TorchInductor returned callable expects lists. So, may need a boxed calling convention.
out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)
if only_fwd:
return out
if requires_bwd_pass(out):
loss = reduce_to_scalar_loss(out)
loss.backward()
return collect_results(gm, out, None, args)
def same_two_models(
gm,
opt_gm,
example_inputs,
only_fwd=False,
*,
require_fp64=False,
ignore_non_fp=False,
):
"""
Check that two models have the same accuracy.
require_fp64: if True, raise an error if we are unable to calculate the fp64 reference
ignore_non_fp: if True, do not compare outputs which are not floating point. This
is mostly useful for the minifier (which wants to avoid quantizing floating point
error into integer/boolean error)
"""
from .utils import same
ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
fp64_ref = None
if config.same_two_models_use_fp64:
try:
fp64_model, fp64_examples = cast_to_fp64(
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
)
fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
except Exception:
if require_fp64:
raise RuntimeError( # noqa: B904
"Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False"
)
log.warning("Could not generate fp64 outputs")
try:
res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
except Exception as e:
# This means that the minified graph is bad/exposes a different problem.
# As we are checking accuracy here, let's log the exception and return True.
log.exception(
"While minifying the program in accuracy minification mode, "
"ran into a runtime exception which is likely an unrelated issue."
" Skipping this graph."
)
return True
passing = same(
ref,
res,
fp64_ref,
tol=config.repro_tolerance,
equal_nan=True,
ignore_non_fp=ignore_non_fp,
)
return passing
def cast_dtype_args_to_fp64(model):
for node in model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.prims.convert_element_type.default
):
assert len(node.args) == 2
if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
node.args = (node.args[0], torch.float64)
if node.op == "call_function":
dtype = node.kwargs.get("dtype")
if dtype is not None and is_float_dtype(dtype):
new_kwargs = dict(node.kwargs)
new_kwargs["dtype"] = torch.float64
node.kwargs = new_kwargs
model.graph.lint()
model.recompile()
return model
def cast_to(dtype, model, inputs):
from torch.utils._pytree import tree_map
model = model.to(dtype)
if dtype == torch.float64:
# If casting to fp64 for accuracy comparison, we need to
# replace dtype arguments embedded in the graph with fp64
model = cast_dtype_args_to_fp64(model)
inputs = tree_map(
lambda x: x.to(dtype)
if isinstance(x, torch.Tensor) and x.is_floating_point()
else x,
inputs,
)
return model, inputs
def cast_to_fp64(model, inputs):
return cast_to(torch.float64, model, inputs)
def backend_accuracy_fails(
gm,
example_inputs,
compiler_fn,
only_fwd=False,
*,
require_fp64=False,
ignore_non_fp=False,
):
try:
compiled_gm = compiler_fn(
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
)
return not same_two_models(
gm,
compiled_gm,
example_inputs,
only_fwd,
require_fp64=require_fp64,
ignore_non_fp=ignore_non_fp,
)
except Exception as e:
# This means that the minified graph is bad/exposes a different problem.
# As we are checking accuracy here, let's log the exception and return False.
log.exception(
"While minifying the program in accuracy minification mode, "
"ran into a runtime exception which is likely an unrelated issue."
" Skipping this graph"
)
return False
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Helper functions for computing what the default values of tensor
# values should be. These all coincide with factory functions, e.g., torch.empty
def _stride_or_default(
stride: Optional["torch._prims_common.StrideType"],
*,
shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
return stride if stride is not None else utils.make_contiguous_strides_for(shape)
def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
return lambda x: x if x is not None else d
_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
class NopInputReader:
def __init__(self) -> None:
self.total = 0
def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
self.total += 1
def tensor(self, *args, **kwargs):
pass
def symint(self, *args, **kwargs):
pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
def __init__(self, save_dir=None, *, pbar=None):
# If None, we will generate random data instead. It's important
# to natively support this use case as it will allow people to
# share repros without including the real data, if the problem
# reproduces even on random data.
if save_dir is None:
log.warning("no save_dir specified, will generate random data")
self.store = ContentStoreReader(save_dir) if save_dir is not None else None
self.args = []
self.pbar = pbar
def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
if self.pbar is not None:
self.pbar.update(1)
device = _device_or_default(device)
dtype_hint = _dtype_or_default(dtype_hint)
if self.store is not None and storage_hash is not None:
try:
storage = self.store.read_storage(storage_hash)
except FileNotFoundError:
pass
else:
if device != storage.device:
log.warning("device mismatch: %s != %s", device, storage.device)
# TODO: transfer it to the right device? But failing this
# way would be very mysterious! Would have been better
# not to store device in the serialized format...
return storage
log.warning("could not load %s, generating random data instead", storage_hash)
shape = (nbytes // dtype_hint.itemsize,)
stride = _stride_or_default(None, shape=shape)
return rand_strided(shape, stride, dtype_hint, device).untyped_storage()
def tensor(
self,
storage,
shape,
stride=None,
*,
storage_offset=None,
dtype=None,
requires_grad=None,
is_leaf=None,
**metadata,
):
stride = _stride_or_default(stride, shape=shape)
storage_offset = _storage_offset_or_default(storage_offset)
dtype = _dtype_or_default(dtype)
is_leaf = _is_leaf_or_default(is_leaf)
requires_grad = _requires_grad_or_default(requires_grad)
t = torch.tensor(
[], dtype=dtype, device=storage.device, requires_grad=requires_grad
)
with torch.no_grad():
t.set_(storage, storage_offset, shape, stride)
if not is_leaf:
# Fake up some autograd history in a very naughty way
with torch.enable_grad():
t = t.clone(memory_format=torch.preserve_format)
with torch.no_grad():
t.set_(storage, storage_offset, shape, stride)
assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
torch._utils.set_tensor_metadata(t, metadata)
self.args.append(t)
return t # for BC
def symint(self, val):
self.args.append(val)
return val # for BC
# Here is our writer strategy:
# 1. We will stream all of the inputs to disk
# 2. You can now deterministically randomize the inputs, or reload
# the inputs from disk
# 3. You can YOLO run the script without the inputs, in which case
# we'll fill the inputs with random data and pray. This is the
# legacy behavior, but it's also useful if you want to find out
# if we're so broken even random inputs trigger it
# 4. We could offer an in process "check if the randomized thing
# works too" but this is delicate so we don't do it
class InputWriter:
def __init__(self, save_dir, *, stable_hash=False):
self._lines = []
# TODO: consider ensuring tensor and storage counters line up?
self.storage_counter = itertools.count()
self.save_dir = save_dir
self.store = (
ContentStoreWriter(save_dir, stable_hash=stable_hash)
if save_dir is not None
else None
)
self.seen_storages = {}
def lines(self):
r = [
"def load_args(reader):",
]
r.extend(f" {l}" for l in self._lines)
# In case we need to change the internal format of load_args
# in an FC-breaking way
r.append("load_args._version = 0")
return r
# Storages are untyped, but we need to initialize them with data if
# we don't have the real data, so we give a hint saying what kind
# of initialization may be appropriate
#
# If we had a FakeTensor, device_hint tells us what device should be
def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
ws = StorageWeakRef(untyped_storage)
v = self.seen_storages.get(ws)
if v is not None:
return v
v = f"buf{next(self.storage_counter)}"
maybe_dtype_hint = ""
if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
# TODO: being optional on device is kind of pointless as the default
# is CPU but most repros we care about are CUDA
maybe_device = ""
device = untyped_storage.device
if device.type == "meta":
assert device_hint is not None
device = device_hint
if _device_or_default(None) != device:
maybe_device = f", device={device!r}"
nbytes = untyped_storage.nbytes()
storage_hash = None
if self.store is not None and untyped_storage.device.type != "meta":
storage_hash = self.store.write_storage(untyped_storage)
self._lines.append(
f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
)
self.seen_storages[ws] = v
return v
def tensor(self, name, t) -> None:
from torch.fx.experimental.symbolic_shapes import statically_known_true
storage = self.storage(
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
)
args = []
# NB: this is positional, must come first
if _stride_or_default(None, shape=t.shape) != t.stride():
args.append(str(tuple(t.stride())))
if _dtype_or_default(None) != t.dtype:
args.append(f"dtype={t.dtype!r}")
if not statically_known_true(
_storage_offset_or_default(None) == t.storage_offset()
):
args.append(f"storage_offset={t.storage_offset()!r}")
tensor_metadata = torch._utils.get_tensor_metadata(t)
if tensor_metadata:
args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
if _requires_grad_or_default(None) != t.requires_grad:
args.append(f"requires_grad={t.requires_grad!r}")
is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
if _is_leaf_or_default(None) != is_leaf:
args.append(f"is_leaf={is_leaf!r}")
self._lines.append(
"reader.tensor("
+ ", ".join([storage, str(tuple(t.shape)), *args])
+ f") # {name}"
)
# TODO: this doesn't actually symint atm
def symint(self, name, val) -> None:
if isinstance(val, torch.SymInt):
val = val.node.hint
self._lines.append(f"reader.symint({val!r}) # {name}")
def aot_graph_input_parser(
func: Callable[[List[Tensor]], List[Tensor]],
device: str = "cuda",
sym_shapes: Optional[Dict[str, int]] = None,
default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
"""
Takes in a function which has been printed with print_readable() and constructs kwargs to run it.
Handles Tensor inputs, Symints, and a graph module which might have tensor constants.
Consider a function `forward` defined as follows:
def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
_tensor_constant0: "i64[4190]" = self._tensor_constant0
# Further implementation
kwargs = aot_graph_input_parser(forward)
forward(**kwargs)
"""
from torch.fx.graph import dtype_abbrs
dtype_map = {value: key for key, value in dtype_abbrs.items()}
dtype_pattern = "|".join(dtype_abbrs.values())
# Extracting the source code from the function
source = inspect.getsource(func)
# Regular expressions
tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
sym_shape_regex = r"Sym\((s\d+)\)"
class TensorContainer:
"Container for tensors as attributes"
# Dictionary for tensors from annotations
kwargs: Dict[str, Any] = {}
sym_shapes = sym_shapes or {}
def get_sym_int(symint):
torch._check(
symint in sym_shapes or default_sym_shape is not None,
lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
)
return sym_shapes.get(symint, default_sym_shape)
def gen_tensor(shape, dtype) -> Tensor:
# Resolve symbolic shapes to concrete values
resolved_shape = []
dynamic_dims = []
for i, dim in enumerate(shape):
dim = dim.strip()
if "s" in dim:
s = get_sym_int(dim)
resolved_shape.append(s)
dynamic_dims.append(i)
else:
if dim:
resolved_shape.append(int(dim))
constructor = torch.randn if dtype.is_floating_point else torch.zeros
out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg]
for d in dynamic_dims:
torch._dynamo.mark_dynamic(out, d)
return out
# Parse function annotations for tensor generation
annotations = func.__annotations__
for param, annotation in annotations.items():
# Skip 'return' annotation
if param == "return":
continue
match = re.search(tensor_regex, annotation)
if match:
data_type, shape_str = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
kwargs[param] = gen_tensor(shape, dtype)
match = re.search(sym_shape_regex, annotation)
if match:
kwargs[param] = get_sym_int(match.group(1))
if "self" in inspect.signature(func).parameters:
container = TensorContainer()
kwargs["self"] = container
for match in re.finditer(tensor_assignment_regex, source):
attr_name, data_type, shape_str, _ = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
setattr(container, attr_name, gen_tensor(shape, dtype))
return kwargs
def profile_to_file(filename: str) -> Callable[[T], T]:
"""
Decorator to cProfile a given function and save the result to disk on process exit.
Args:
filename: filename to save profile to
"""
prof = cProfile.Profile()
filename = os.path.abspath(os.path.expanduser(filename))
def decorator(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
prof.enable()
try:
return fn(*args, **kwargs)
finally:
prof.disable()
return wrapper
def save_it():
prof.dump_stats(filename)
sys.stderr.write(
textwrap.dedent(
f"""\
Wrote profile to {filename}, view with:
snakeviz {filename}
"""
)
)
atexit.register(save_it)
return decorator
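# Illustrative usage sketch (example only; the path is hypothetical):
#
#     @profile_to_file("/tmp/dynamo_hot_path.prof")
#     def hot_function(x):
#         return x * 2
#
# The cProfile stats are dumped at interpreter exit and can be inspected with
# `snakeviz /tmp/dynamo_hot_path.prof`.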

View File

@@ -0,0 +1,580 @@
# mypy: allow-untyped-defs
# ruff: noqa: TCH004
import functools
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling
from .utils import is_function
if TYPE_CHECKING:
from types import FunctionType
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_eval_frame,
set_guard_error_hook,
skip_code,
unsupported,
)
from .variables import VariableTracker
else:
for name in dir(torch._C._dynamo.eval_frame):
if name.startswith("__"):
continue
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
_F = TypeVar("_F", bound=Callable[..., Any])
def run(fn=None):
"""Don't do any dynamic compiles, just use prior optimizations"""
if fn is not None:
fn = innermost_fn(fn)
assert callable(fn)
return RunOnlyContext()(fn)
return RunOnlyContext()
def disable(fn=None, recursive=True):
"""
Decorator and context manager to disable TorchDynamo
If recursive=True, Dynamo is completely skipped on the decorated function
frame as well as the recursively invoked functions.
If recursive=False, Dynamo skips frames associated with the function code,
but still processes recursively invoked frames.
"""
if recursive:
if fn is not None:
fn = innermost_fn(fn)
assert callable(fn)
return DisableContext()(fn)
return DisableContext()
else:
return skip(fn)
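# Illustrative usage sketch (example only; helper names are hypothetical):
#
#     @torch._dynamo.disable                     # skip this frame and its callees
#     def untraceable_helper(x):
#         return x.tolist()
#
#     @torch._dynamo.disable(recursive=False)    # skip only this frame
#     def shallowly_skipped(x):
#         return untraceable_helper(x)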
def skip(fn=None):
"""
Skip frames associated with the function code, but still process recursively
invoked frames
"""
if fn is None:
return skip
fn = innermost_fn(fn)
assert callable(fn)
skip_code(fn.__code__)
fn._torchdynamo_disable = True
return fn
def assume_constant_result(fn):
fn._dynamo_marked_constant = True
return fn
def allow_in_graph(fn):
"""
Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
and instead directly write it to the graph when encountered.
See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation
WARNING: this API can be a footgun, please read the documentation carefully.
"""
if isinstance(fn, (list, tuple)):
return [allow_in_graph(x) for x in fn]
assert callable(fn), "allow_in_graph expects a callable"
if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
trace_rules._disallowed_callable_ids.remove(id(fn))
trace_rules._allowed_callable_ids.add(id(fn))
return fn
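# Illustrative usage sketch (example only; my_black_box is a hypothetical function):
#
#     @torch._dynamo.allow_in_graph
#     def my_black_box(x):
#         return torch.relu(x) + 1   # body is not introspected by Dynamo
#
# Calls to my_black_box then appear directly as call_function nodes in the FX graph.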
def _disallow_in_graph_helper(throw_if_not_allowed):
def inner(fn):
if isinstance(fn, (list, tuple)):
return [disallow_in_graph(x) for x in fn]
assert callable(fn), "disallow_in_graph expects a callable"
if (
throw_if_not_allowed
and trace_rules.lookup_callable(fn)
!= variables.TorchInGraphFunctionVariable
and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
):
raise IncorrectUsage(
"disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "
"Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
)
trace_rules._allowed_callable_ids.remove(id(fn))
trace_rules._disallowed_callable_ids.add(id(fn))
return fn
return inner
def disallow_in_graph(fn):
"""
Customize which functions TorchDynamo will exclude in the generated
graph and force a graph break on.
::
torch._dynamo.disallow_in_graph(torch.sub)
@torch._dynamo.optimize(...)
def fn(a):
x = torch.add(x, 1)
x = torch.sub(x, 1)
x = torch.add(x, 1)
return x
fn(...)
Will break the graph on `torch.sub`, and give two graphs each with a
single `torch.add()` op.
"""
return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn)
@_disallow_in_graph_helper(throw_if_not_allowed=False)
def graph_break():
"""Force a graph break"""
def forbid_in_graph(fn):
"""
Customize which functions TorchDynamo will assert are not present while tracing.
If you want a graph break on this function instead, use disallow_in_graph.
TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust
documentation would not be amiss.
"""
if isinstance(fn, (list, tuple)):
return [forbid_in_graph(x) for x in fn]
assert callable(fn), "forbid_in_graph applies only to callables"
fn._dynamo_forbidden = True
return fn
def substitute_in_graph(
original_fn: _F,
*,
can_constant_fold_through: bool = False,
skip_signature_check: bool = False,
# type that is embedded in the Python interpreter
is_embedded_type: bool = False, # internal use only
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph.
.. note::
The polyfill handler is only used when inlining the original function. It is not used when
the original function is called directly. In the eager mode, the decorated function calls
the performant C function rather than the polyfill handler.
The polyfill handler is a function that will be called in place of the original function when
inlining the original function. The polyfill handler should have the same signature and the same
behavior as the original function.
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
folded through. That is, if the polyfill handler is a pure function and its arguments
are constant, the result of the polyfill handler can be constant folded during the
compilation. Defaults to ``False``.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
Returns:
A decorator that registers the polyfill handler for the original function.
Example::
>>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers")
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
... for i, item in enumerate(a):
... if item is b or item == b:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
"""
if not is_function(original_fn) and not (
is_embedded_type and inspect.isclass(original_fn)
):
raise TypeError(
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
)
if is_embedded_type:
if not inspect.isclass(original_fn):
raise TypeError(
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
)
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
if id(original_fn) in ITERTOOLS_TYPE_IDS:
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
def wrapper(traceable_fn: _F) -> _F:
if not is_function(traceable_fn):
raise TypeError(
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
)
if not skip_signature_check:
try:
original_sig = inspect.signature(original_fn)
except ValueError:
pass
else:
traceable_sig = inspect.signature(traceable_fn)
def sig_ident(sig):
# Ignore annotations for parameters and return type
return (
tuple(
p.name
for p in sig.parameters.values()
if (
p.kind
not in {
p.KEYWORD_ONLY,
# the name of *args and **kwargs is not important
p.VAR_POSITIONAL,
p.VAR_KEYWORD,
}
)
),
{
p.name
for p in sig.parameters.values()
if p.kind == p.KEYWORD_ONLY
},
{
p.name: p.default
for p in sig.parameters.values()
# the name of *args and **kwargs is not important
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
},
)
wildcard_sig = inspect.signature(lambda *args, **kwargs: None)
if (
sig_ident(original_sig) != sig_ident(traceable_sig)
and sig_ident(original_sig) != sig_ident(wildcard_sig)
and sig_ident(traceable_sig) != sig_ident(wildcard_sig)
):
raise TypeError(
f"Signature mismatch between {original_fn} and {traceable_fn}: "
f"{original_sig} != {traceable_sig}"
)
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.trace_rules import get_torch_obj_rule_map
from torch._dynamo.variables import PolyfilledFunctionVariable
from torch._dynamo.variables.builder import VariableBuilder
id_dispatch_map = VariableBuilder._id_dispatch()
if id(original_fn) in id_dispatch_map:
raise ValueError(
f"Duplicate dispatch rule for {original_fn}: "
"already registered in VariableBuilder's id dispatch map"
)
rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map()
if original_fn in rule_map:
raise ValueError(
f"Duplicate object {original_fn} with different rules: "
f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}"
)
polyfill_handlers: Dict[Callable[..., Any], FunctionType]
polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers()
if original_fn in polyfill_handlers:
raise ValueError(
f"Duplicate polyfill handlers for {original_fn}: "
f"already handled by {polyfill_handlers[original_fn]}"
)
# Need to wrap the function because we may not be able to assign __torch_dynamo_polyfill__ to a
# C++ function.
@functools.wraps(traceable_fn)
def wrapped(*args, **kwargs):
return original_fn(*args, **kwargs)
def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable:
return PolyfilledFunctionVariable(
value,
source=self.source,
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
)
id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn
rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable
polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment]
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined]
return wrapped # type: ignore[return-value]
return wrapper
# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__()
assert isinstance(t, torch.Tensor)
for attr in attrs:
inner = getattr(t, attr)
if inner.dim() == t.dim():
func(inner, *args, **kwargs)
@dataclass(frozen=True)
class _DimRange:
"""
This represents a dimension of a tensor and the corresponding
min and max values it can take. Don't create this
class directly; instead, use :func:`mark_dynamic`.
"""
dim: int
min: int
max: int
@forbid_in_graph
def mark_unbacked(t, index):
"""
Mark a tensor as having an unbacked dim. This changes the semantics of operations:
we will always report that the size does not equal zero/one, we will turn asserts
on this index into runtime asserts, and if you try to get the real value we will
raise an exception. In other words, we will treat this dimension as if it were
data dependent (we do not know anything about its value).
"""
# You could have copied the mark_dynamic behavior but I'm not convinced
# it's what you want
assert not is_traceable_wrapper_subclass(t), "not implemented yet"
if isinstance(index, int):
if not hasattr(t, "_dynamo_unbacked_indices"):
t._dynamo_unbacked_indices = set()
t._dynamo_unbacked_indices.add(index)
return
assert isinstance(index, (list, tuple))
for i in index:
mark_unbacked(t, i)
@forbid_in_graph
def mark_dynamic(t, index, *, min=None, max=None):
"""
Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
[Note - on the state of mark_dynamic]
The behavior of having a dynamic dimension on a tensor is governed by a few factors:
1) torch._dynamo.config dynamic_shapes True or False.
a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work.
b) dynamic_shapes=False - This config will raise an exception when used in conjunction with
mark_dynamic. We will eventually support this.
2) If the dimension is fully constrained - as in, it does not allow more than a single value
in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export),
we will raise an error
3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded
range of shapes, in eager we will pass it through, but export will raise an error.
4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
before torch.compile.
"""
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(
mark_dynamic, t, index, min=min, max=max
)
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
t._dynamo_dynamic_indices = set()
t._dynamo_dynamic_range = set()
# TODO(voz): Should we bounds check?
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max))
return
assert isinstance(index, (list, tuple))
for i in index:
mark_dynamic(t, i, min=min, max=max)
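# Illustrative sketch only (hypothetical helper, never called in this module): mark the
# batch dimension dynamic before compiling so that changing only that dimension does not
# force a recompile. `model` is an assumed user-provided callable.
def _example_mark_dynamic(model):
    import torch

    x = torch.randn(8, 3, 224, 224)
    mark_dynamic(x, 0, min=2, max=64)  # dim 0 is allowed to vary within [2, 64]
    compiled = torch.compile(model)
    return compiled(x)  # later batch sizes in that range should reuse the compiled code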
@forbid_in_graph
def maybe_mark_dynamic(t, index):
"""
Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
dimension ends up getting specialized, don't error).
"""
if is_traceable_wrapper_subclass(t):
# default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)
if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
t._dynamo_weak_dynamic_indices = set()
# TODO(voz): Should we bounds check?
t._dynamo_weak_dynamic_indices.add(index)
return
assert isinstance(index, (list, tuple))
for i in index:
maybe_mark_dynamic(t, i)
def mark_static(t, index=None):
"""
Mark a tensor as having a static dim or mark a nn module class as static.
For tensors
===========
This will prevent us from attempting to compile it dynamically
when dynamic=True; this can improve trace-time performance.
This has lower precedence than mark_dynamic.
Unlike mark_dynamic, this can be done inside a graph, in which case it
induces specialization on the tensor.
For nn.Module classes
=====================
For static nn.Module classes, TorchDynamo assumes that the module instance
attributes will not be modified after compilation. This will ensure that
TorchDynamo keeps integer attributes CONSTANT and not symints.
On the TorchDynamo implementation side, instances of a static-marked
nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable,
which has the same properties.
Note that we still have to guard on the attributes, because different
instances of the nn.Module can have different values of the attributes. The
key point here is that the attributes are static.
"""
if is_compiling():
if index is None:
for s in t.size():
comptime.force_static(s)
else:
comptime.force_static(t.size(index))
return
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_static() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
t._dynamo_marked_static = True
return t
if not isinstance(t, torch.Tensor):
raise TypeError(
f"mark_static expects a tensor/nn.Module class but recieved {type(t)}"
)
if isinstance(index, int):
if not hasattr(t, "_dynamo_static_indices"):
t._dynamo_static_indices = set() # type: ignore[attr-defined]
# TODO(voz): Should we bounds check?
t._dynamo_static_indices.add(index) # type: ignore[attr-defined]
elif index is None:
for i in range(t.dim()):
mark_static(t, i)
else:
assert isinstance(index, (list, tuple))
for i in index:
mark_static(t, i)
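# Illustrative sketch only (hypothetical helper and class, never called in this module):
# mark every dim of a tensor static, and mark an nn.Module class static so its integer
# attributes are treated as constants.
def _example_mark_static():
    import torch

    class ToyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.chunks = 4  # stays a constant int rather than becoming a symint

        def forward(self, x):
            return torch.cat(torch.chunk(x, self.chunks, dim=-1), dim=0)

    mark_static(ToyModule)  # class-level marking

    x = torch.randn(2, 8)
    mark_static(x)  # tensor-level marking: all dims static
    return torch.compile(ToyModule(), dynamic=True)(x)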
@forbid_in_graph
def mark_static_address(t, guard=True):
"""
Marks an input tensor whose data_ptr will not change across multiple calls
to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
is not needed for this input. The data_ptr will be guarded if guard=True. Note:
Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
"""
if not isinstance(t, torch.Tensor):
raise TypeError(f"mark_static_address expects a tensor but recieved {type(t)}")
if guard:
t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined]
else:
t._dynamo_static_input_type = "unguarded" # type: ignore[attr-defined]
# Note: this carefully avoids eagerly import einops.
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
def _allow_in_graph_einops():
import einops
try:
# requires einops > 0.6.1, torch >= 2.0
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
_ops_were_registered_in_torchdynamo,
)
# einops > 0.6.1 will call the op registration logic as it is imported.
except ImportError:
# einops <= 0.6.1
allow_in_graph(einops.rearrange)
allow_in_graph(einops.reduce)
if hasattr(einops, "repeat"):
allow_in_graph(einops.repeat) # available since einops 0.2.0
if hasattr(einops, "einsum"):
allow_in_graph(einops.einsum) # available since einops 0.5.0
if hasattr(einops, "pack"):
allow_in_graph(einops.pack) # available since einops 0.6.0
if hasattr(einops, "unpack"):
allow_in_graph(einops.unpack) # available since einops 0.6.0
trace_rules.add_module_init_func("einops", _allow_in_graph_einops)

View File

@ -0,0 +1,330 @@
# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
import torch
from torch._streambase import _EventBase, _StreamBase
get_cuda_stream: Optional[Callable[[int], int]]
if torch.cuda._is_compiled():
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
else:
get_cuda_stream = None
_device_t = Union[torch.device, str, int, None]
# Device properties are recorded in the main process but used in worker processes.
caching_worker_device_properties: Dict[str, Any] = {}
caching_worker_current_devices: Dict[str, int] = {}
class DeviceInterfaceMeta(type):
def __new__(metacls, *args, **kwargs):
class_member = args[2]
if "Event" in class_member:
assert inspect.isclass(class_member["Event"]) and issubclass(
class_member["Event"], _EventBase
), "DeviceInterface member Event should be inherit from _EventBase"
if "Stream" in class_member:
assert inspect.isclass(class_member["Stream"]) and issubclass(
class_member["Stream"], _StreamBase
), "DeviceInterface member Stream should be inherit from _StreamBase"
return super().__new__(metacls, *args, **kwargs)
class DeviceInterface(metaclass=DeviceInterfaceMeta):
"""
This is a simple device runtime interface for Inductor. It enables custom
backends to be integrated with Inductor in a device-agnostic way.
"""
class device:
def __new__(cls, device: _device_t):
raise NotImplementedError
class Worker:
"""
Worker API to query device properties that will work in multiprocessing
workers that cannot use the GPU APIs (due to process fork() and
initialization-time issues). Properties are recorded in the main process
before we fork the workers.
"""
@staticmethod
def set_device(device: int):
raise NotImplementedError
@staticmethod
def current_device() -> int:
raise NotImplementedError
@staticmethod
def get_device_properties(device: _device_t = None):
raise NotImplementedError
@staticmethod
def current_device():
raise NotImplementedError
@staticmethod
def set_device(device: _device_t):
raise NotImplementedError
@staticmethod
def maybe_exchange_device(device: int) -> int:
raise NotImplementedError
@staticmethod
def exchange_device(device: int) -> int:
raise NotImplementedError
@staticmethod
def device_count():
raise NotImplementedError
@staticmethod
def is_available() -> bool:
raise NotImplementedError
@staticmethod
def stream(stream: torch.Stream):
raise NotImplementedError
@staticmethod
def current_stream():
raise NotImplementedError
@staticmethod
def set_stream(stream: torch.Stream):
raise NotImplementedError
@staticmethod
def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
raise NotImplementedError
@staticmethod
def get_raw_stream(device_idx: int) -> int:
raise NotImplementedError
@staticmethod
def synchronize(device: _device_t = None):
raise NotImplementedError
@staticmethod
def get_device_properties(device: _device_t = None):
raise NotImplementedError
@staticmethod
def get_compute_capability(device: _device_t = None):
raise NotImplementedError
@staticmethod
def is_bf16_supported(including_emulation: bool = False):
raise NotImplementedError
class DeviceGuard:
"""
This class provides a context manager for device switching. This is a stripped
down version of torch.{device_name}.device.
The context manager changes the current device to the given device index
on entering the context and restores the original device on exiting.
The device is switched using the provided device interface.
"""
def __init__(
self, device_interface: Type[DeviceInterface], index: Optional[int]
) -> None:
self.device_interface = device_interface
self.idx = index
self.prev_idx = -1
def __enter__(self):
if self.idx is not None:
self.prev_idx = self.device_interface.exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
if self.idx is not None:
self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
return False
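# Illustrative sketch only (hypothetical helper, never called in this module): temporarily
# switch the current device using the interface registry defined later in this file.
def _example_device_guard():
    import torch

    cuda_interface = get_interface_for_device("cuda")
    with DeviceGuard(cuda_interface, 1):  # runs on cuda:1, previous device restored on exit
        y = torch.ones(4, device="cuda")
    return y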
class CudaInterface(DeviceInterface):
device = torch.cuda.device
# register Event and Stream class into the backend interface
# make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase
Event = torch.cuda.Event
Stream = torch.cuda.Stream
class Worker:
@staticmethod
def set_device(device: int):
caching_worker_current_devices["cuda"] = device
@staticmethod
def current_device() -> int:
if "cuda" in caching_worker_current_devices:
return caching_worker_current_devices["cuda"]
return torch.cuda.current_device()
@staticmethod
def get_device_properties(device: _device_t = None):
if device is not None:
if isinstance(device, str):
device = torch.device(device)
assert device.type == "cuda"
if isinstance(device, torch.device):
device = device.index
if device is None:
device = CudaInterface.Worker.current_device()
if "cuda" not in caching_worker_device_properties:
device_prop = [
torch.cuda.get_device_properties(i)
for i in range(torch.cuda.device_count())
]
caching_worker_device_properties["cuda"] = device_prop
return caching_worker_device_properties["cuda"][device]
current_device = staticmethod(torch.cuda.current_device)
set_device = staticmethod(torch.cuda.set_device)
device_count = staticmethod(torch.cuda.device_count)
stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
current_stream = staticmethod(torch.cuda.current_stream)
set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
synchronize = staticmethod(torch.cuda.synchronize)
get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type]
is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type]
# Can be mock patched by @patch decorator.
@staticmethod
def is_available() -> bool:
return torch.cuda.is_available()
@staticmethod
def get_compute_capability(device: _device_t = None):
if torch.version.hip is None:
major, minor = torch.cuda.get_device_capability(device)
return major * 10 + minor
else:
return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
else:
get_xpu_stream = None
class XpuInterface(DeviceInterface):
device = torch.xpu.device
Event = torch.xpu.Event
Stream = torch.xpu.Stream
class Worker:
@staticmethod
def set_device(device: int):
caching_worker_current_devices["xpu"] = device
@staticmethod
def current_device() -> int:
if "xpu" in caching_worker_current_devices:
return caching_worker_current_devices["xpu"]
return torch.xpu.current_device()
@staticmethod
def get_device_properties(device: _device_t = None):
if device is not None:
if isinstance(device, str):
device = torch.device(device)
assert device.type == "xpu"
if isinstance(device, torch.device):
device = device.index
if device is None:
device = XpuInterface.Worker.current_device()
if "xpu" not in caching_worker_device_properties:
device_prop = [
torch.xpu.get_device_properties(i)
for i in range(torch.xpu.device_count())
]
caching_worker_device_properties["xpu"] = device_prop
return caching_worker_device_properties["xpu"][device]
current_device = staticmethod(torch.xpu.current_device)
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count)
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
synchronize = staticmethod(torch.xpu.synchronize)
get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type]
# Can be mock patched by @patch decorator.
@staticmethod
def is_available() -> bool:
return torch.xpu.is_available()
@staticmethod
def get_compute_capability(device: _device_t = None):
cc = torch.xpu.get_device_capability(device)
return cc
@staticmethod
def is_bf16_supported(including_emulation: bool = False) -> bool:
return torch.xpu.is_bf16_supported()
device_interfaces: Dict[str, Type[DeviceInterface]] = {}
_device_initialized = False
def register_interface_for_device(
device: Union[str, torch.device], device_interface: Type[DeviceInterface]
):
if isinstance(device, torch.device):
device = str(device)
device_interfaces[device] = device_interface
def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]:
if isinstance(device, torch.device):
device = str(device)
if not _device_initialized:
init_device_reg()
if device in device_interfaces:
return device_interfaces[device]
raise NotImplementedError(f"No interface for device {device}")
def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
if not _device_initialized:
init_device_reg()
return device_interfaces.items()
def init_device_reg():
global _device_initialized
register_interface_for_device("cuda", CudaInterface)
for i in range(torch.cuda.device_count()):
register_interface_for_device(f"cuda:{i}", CudaInterface)
register_interface_for_device("xpu", XpuInterface)
for i in range(torch.xpu.device_count()):
register_interface_for_device(f"xpu:{i}", XpuInterface)
_device_initialized = True
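# Illustrative sketch only (hypothetical backend, never called in this module): an
# out-of-tree backend could register its own DeviceInterface subclass under its device
# string; "privateuse1" here is just a placeholder name, not an endorsed recipe.
def _example_register_custom_interface():
    class _FakeAcceleratorInterface(DeviceInterface):
        @staticmethod
        def is_available() -> bool:
            return False  # a real backend would query its runtime here

        @staticmethod
        def device_count() -> int:
            return 0

    register_interface_for_device("privateuse1", _FakeAcceleratorInterface)
    return get_interface_for_device("privateuse1")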

View File

@ -0,0 +1,25 @@
from typing import Optional
import torch.distributed as dist
from . import config
_COMPILE_PG: Optional[dist.ProcessGroup] = None
def get_compile_pg() -> Optional[dist.ProcessGroup]:
if (
config.enable_compiler_collectives
and dist.is_available()
and dist.is_initialized()
):
global _COMPILE_PG
if _COMPILE_PG is None:
# , timeout=datetime.timedelta(seconds=2)
_COMPILE_PG = dist.distributed_c10d._new_group_with_tag(
pg_tag="pt2_compile_pg"
)
return _COMPILE_PG
return None

File diff suppressed because it is too large

View File

@ -0,0 +1,454 @@
# mypy: allow-untyped-defs
import os
import textwrap
from enum import auto, Enum
from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING
import torch._guards
from . import config
from .utils import counters
if TYPE_CHECKING:
from torch._guards import CompileId
def exportdb_error_message(case_name):
return (
"For more information about this error, see: "
+ "https://pytorch.org/docs/main/generated/exportdb/index.html#"
+ case_name.replace("_", "-")
)
import logging
log = logging.getLogger(__name__)
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
class TorchDynamoException(RuntimeError):
pass
class InternalTorchDynamoError(TorchDynamoException):
pass
class RestartAnalysis(TorchDynamoException):
restart_reason: str
def __init__(self, *args, restart_reason=None) -> None:
self.restart_reason = restart_reason
super().__init__(*args)
class SpeculationRestartAnalysis(RestartAnalysis):
pass
class UnspecializeRestartAnalysis(RestartAnalysis):
pass
class CompileCollectiveRestartAnalysis(RestartAnalysis):
pass
class SkipFrame(TorchDynamoException):
pass
class TorchRuntimeError(TorchDynamoException):
pass
class InvalidBackend(TorchDynamoException):
def __init__(self, name) -> None:
super().__init__(
f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
)
class ResetRequired(TorchDynamoException):
def __init__(self) -> None:
super().__init__(
textwrap.dedent(
"""
Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
`torch.compile()` with a different backend compiler arguments.
"""
)
)
class BackendCompilerFailed(TorchDynamoException):
def __init__(self, backend_fn, inner_exception) -> None:
self.backend_name = getattr(backend_fn, "__name__", "?")
self.inner_exception = inner_exception
msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
super().__init__(msg)
class Unsupported(TorchDynamoException):
def __init__(self, msg, *, case_name=None) -> None:
super().__init__(msg)
self.real_stack = torch._guards.TracingContext.extract_stack()
self.msg = msg
self.category: Optional[str] = None
self.add_to_stats()
self.case_name: Optional[str] = case_name
def remove_from_stats(self):
assert self.category is not None
counters[self.category][self.msg] -= 1
if counters[self.category][self.msg] <= 0:
del counters[self.category][self.msg]
def add_to_stats(self, category="unimplemented"):
self.category = category
counters[category][self.msg] += 1
class RecompileError(TorchDynamoException):
pass
class ArgsMismatchError(Unsupported):
def __init__(self, msg) -> None:
super().__init__(msg)
class AttributeMutationError(Unsupported):
def __init__(self, msg) -> None:
super().__init__(msg)
class CondOpArgsMismatchError(ArgsMismatchError):
"""
Internal error from cond() due to arguments mismatch.
"""
def __init__(self, msg) -> None:
super().__init__(msg)
class UserErrorType(Enum):
DYNAMIC_CONTROL_FLOW = auto()
ANTI_PATTERN = auto()
STANDARD_LIBRARY = auto()
CONSTRAINT_VIOLATION = auto()
DYNAMIC_DIM = auto()
INVALID_INPUT = auto()
INVALID_OUTPUT = auto()
class UserError(Unsupported):
def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None:
"""
Type of errors that would be valid in Eager, but not supported in TorchDynamo.
The error message should tell the user about next actions.
error_type: Type of user error
msg: Actionable error message
case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
"""
if case_name is not None:
assert isinstance(case_name, str)
if msg.endswith("."):
msg += " "
else:
msg += "\n"
msg += exportdb_error_message(case_name)
super().__init__(msg)
self.error_type = error_type
self.message = msg
class SkipCodeRecursiveException(TorchDynamoException):
pass
class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported):
pass
class UnsafeScriptObjectError(TorchDynamoException):
pass
class UncapturedHigherOrderOpError(TorchDynamoException):
pass
class IncorrectUsage(Exception):
pass
class ObservedException(TorchDynamoException):
# An exception observed during tracing. This exception is used by Dynamo to handle exceptions.
pass
class ObservedUserStopIteration(ObservedException):
# A user StopIteration exception observed during Dynamo tracing (e.g. Dynamo tracing __next__)
value: Optional[Any]
# Reference `StopIteration_init` in CPython
# https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
def __init__(self, *args, **kwargs) -> None:
super().__init__("unhandled `raise StopIteration`")
if len(args) > 0:
self.value = args[0]
else:
self.value = None
class ObservedKeyError(ObservedException):
# A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__
pass
class ObservedAttributeError(ObservedException):
# An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
pass
observed_exception_map = {
StopIteration: ObservedUserStopIteration,
KeyError: ObservedKeyError,
AttributeError: ObservedAttributeError,
}
def raise_observed_exception(e, tx, vt):
from .variables import BuiltinVariable
# CPython here raises an exception. Since there is no Python code, we have to manually set up the exception
# stack and raise the exception.
exception_vt = BuiltinVariable(e).call_function(vt, [], {})
tx.exn_vt_stack.append(exception_vt)
raise observed_exception_map[e]
def handle_observed_exception(tx):
# This is essentially exception handling code, equivalent of this pseudo code
#
# try:
# ... somebody raising StopIteration
# except StopIteration
# pass
#
# If this were going through Python code, we would have called the exception_handler method, but FOR_ITER
# handles the exception completely in CPython. For example for 3.11, the resulting bytecode is
#
#
# 6 46 LOAD_GLOBAL 2 (StopIteration)
# 58 RAISE_VARARGS 1
# >> 60 PUSH_EXC_INFO
# 7 62 LOAD_GLOBAL 2 (StopIteration)
# 74 CHECK_EXC_MATCH
# 76 POP_JUMP_FORWARD_IF_FALSE 3 (to 84)
# 78 POP_TOP
# 8 80 POP_EXCEPT
#
# Fortunately this translates to a simple pop from the exn_vt_stack
tx.exn_vt_stack.pop()
# These exceptions are ok to fallback to eager/graph_break.
exceptions_allowed_to_be_fallback = (
torch._subclasses.fake_tensor.DataDependentOutputException,
torch._subclasses.fake_tensor.DynamicOutputShapeException,
torch._subclasses.fake_tensor.UnsupportedOperatorException,
torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
# This function calls unimplemented internally and eventually graph breaks
# or falls back to eager. unimplemented itself does not print any user warnings,
# i.e., it's very silent. This helper function is intended for when an error
# encountered in the torch.compile stack is worth showing to the user as a
# warning. For example, if the AOT Autograd backend fails with a fake tensor
# exception, it's OK to fall back to eager, but not silently. Here, we can use
# this function to log the message and the stack trace.
graph_break_msg = format_error_msg_verbose(e, code)
graph_breaks_log.debug("%s", graph_break_msg)
log.warning(msg)
unimplemented(msg, from_exc=e)
_NOTHING = object()
def unimplemented(
msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
assert msg != os.environ.get("BREAK", False)
if from_exc is not _NOTHING:
raise Unsupported(msg, case_name=case_name) from from_exc
raise Unsupported(msg, case_name=case_name)
def warning(msg: str) -> None:
counters["warnings"][msg] += 1
assert msg != os.environ.get("BREAK", False)
# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
def __init__(self, value) -> None:
self.value = value
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return self.__str__()
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
import traceback
exc.innermost_user_frame_summary = None # type: ignore[attr-defined]
real_stack = get_real_stack(exc)
if real_stack is not None and len(real_stack) > 0:
exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined]
msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}"
if config.replay_record_enabled and hasattr(exc, "record_filename"):
msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
torch._dynamo.replay('{exc.record_filename}').\n"
if not config.verbose and hasattr(exc, "real_stack"):
msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'
if hasattr(exc, "inner_exception") and hasattr(
exc.inner_exception, "minifier_path"
):
if hasattr(exc.inner_exception, "buck_command"):
msg += (
f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
f"this buck command to find the smallest traced graph "
f"which reproduces this error: {exc.inner_exception.buck_command}\n"
)
else:
msg += (
f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
"this script to find the smallest traced graph which reproduces this error.\n"
)
if not config.suppress_errors and not export:
msg += (
"\n\n"
"You can suppress this exception and fall back to eager by setting:\n"
" import torch._dynamo\n"
" torch._dynamo.config.suppress_errors = True\n"
)
old_msg = "" if len(exc.args) == 0 else str(exc.args[0])
if isinstance(exc, KeyError):
exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
else:
new_msg = old_msg + msg
exc.args = (new_msg,) + exc.args[1:]
def get_exc_message(
e: Exception, compile_id: "CompileId"
) -> Tuple[Optional[str], Optional[int]]:
filename = None
lineno = None
if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined]
filename = e.innermost_user_frame_summary.filename # type: ignore[attr-defined]
lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined]
e.compile_id = compile_id # type: ignore[attr-defined]
return filename, lineno
def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
real_stack = getattr(exc, "real_stack", None)
if real_stack is None:
return None
# NB: it's possible for real_stack to be []; we still attempt to
# report a stack anyway because the stack_above_dynamo may still
# be useful for debugging
stack_above_dynamo = []
if frame is not None:
# NB: frame is PyInterpreterFrame on Python 3.11 and later,
# not a TRUE frame object. You can't actually feed it
# to traceback because it doesn't have enough information.
# To solve this problem, we technically should just materialize
# the frame, the same way _PyFrame_GetFrameObject would do
# (but we cannot actually do this, because this populates
# frame_obj field, which default eval frame doesn't like).
#
# Fortunately, in this case, we can hack it: there's no need
# to actually use the truly top frame, we can just extract
# from where we are right now and rely on filter_stack to
# get rid of all the dynamo frames. For ease of testing
# we apply this behavior to ALL Python versions
stack_above_dynamo = filter_stack(extract_stack())
return cast(StackSummary, stack_above_dynamo + real_stack)
# filter out all frames after entering dynamo
def filter_stack(stack):
user_stack = []
for frame in stack:
if "convert_frame" in frame.filename:
break
if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line:
continue
user_stack.append(frame)
return user_stack
def format_error_msg_verbose(
exc: Exception, code, record_filename=None, frame=None
) -> str:
msg = (
f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n"
)
msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
msg += format_exc()
real_stack = get_real_stack(exc, frame)
if real_stack is not None:
msg += (
"\n"
+ "=" * 10
+ " The above exception occurred while processing the following code "
+ "=" * 10
+ "\n\n"
)
msg += "".join(format_list(real_stack))
msg += "\n"
msg += "=" * 10
return msg
def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
msg = os.linesep * 2
if config.verbose:
msg = format_error_msg_verbose(exc, code, record_filename, frame)
else:
msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
line {code.co_firstlineno} \ndue to: \n{format_exc()}"
return msg

View File

@ -0,0 +1,144 @@
# mypy: allow-untyped-defs
# This module contains functions that *will be allowed* by dynamo
import functools
import warnings
from typing import List
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
If you need to check specifically that TorchDynamo is being used, use
torch.compiler.is_dynamo_compiling().
TODO(khabinov): we should deprecate this function and use one of these two:
* torch.compiler.is_compiling(),
* torch.compiler.is_dynamo_compiling().
Which one to use will depend on the context.
"""
return torch.compiler.is_compiling()
def wrap_inline(fn):
"""
Create an extra frame around fn that is not in skipfiles
"""
@functools.wraps(fn)
def inner(*args, **kwargs):
return fn(*args, **kwargs)
return inner
def call_hook(hook, *args, **kwargs):
"""
Used by compiled autograd to handle hook returning None
"""
result = hook(*args)
if result is None:
return args[0]
elif kwargs["hook_type"] == "post_acc_grad_hook":
raise RuntimeError("Tensor post accumulate grad hooks should return None.")
return result
def wrap_numpy(f):
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
from ``torch.Tensor``s to ``torch.Tensor``s.
"""
if not np:
return f
@functools.wraps(f)
def wrap(*args, **kwargs):
args, kwargs = pytree.tree_map_only(
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
)
out = f(*args, **kwargs)
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
return wrap
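# Illustrative sketch only (hypothetical helper, never called in this module): a
# NumPy-based function wrapped so it accepts and returns torch.Tensors.
def _example_wrap_numpy():
    import torch

    @wrap_numpy
    def double(a):  # `a` arrives as an np.ndarray when NumPy is available
        return a * 2  # an ndarray result is mapped back to a torch.Tensor

    return double(torch.arange(4.0))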
class FakeBackwardCFunction:
def __init__(
self,
real: torch.autograd.function.BackwardCFunction,
saved_tensors: List[torch.Tensor],
) -> None:
self.real = real
self.saved_tensors = saved_tensors
def __getattr__(self, name):
if name == "saved_variables":
warnings.warn(
"'saved_variables' is deprecated; use 'saved_tensors'",
DeprecationWarning,
)
return self.saved_tensors
# route any attribute that isn't defined on this obj
return getattr(self.real, name)
# This function corresponds to the "eager" implementation of a lifted autograd.Function.backward
def call_backward(backward_c_function, saved_tensors, *args):
fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
grads = fake._forward_cls.backward(fake, *args) # type: ignore[attr-defined]
# in eager, we wrap in a tuple when there's only one grad output
if type(grads) is not tuple:
grads = (grads,)
return grads
def untyped_storage_size(x: torch.Tensor):
return x.untyped_storage().size()
class FakeCompiledAutogradEngine:
@staticmethod
def queue_callback(final_callbacks, cb):
final_callbacks.append(cb)
@staticmethod
def exec_final_callbacks(final_callbacks):
i = 0
while i < len(final_callbacks):
cb = final_callbacks[i]
cb()
i += 1
final_callbacks.clear()
@staticmethod
def _exec_final_callbacks_stub():
pass
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
return getattr(bw_state, hook_name)(*args, **kwargs)
def call_module_hooks_from_backward_state(
_, result, *args, bw_state, hooks_name: str, module_name: str
):
module = getattr(bw_state, module_name)
hooks = getattr(bw_state, hooks_name)
for hook in hooks:
new_result = hook(module, result, *args)
if new_result is not None:
result = new_result
return result

View File

@ -0,0 +1,57 @@
import tokenize
from typing import Dict, List, Optional
cache: Dict[str, Dict[int, str]] = {}
def clearcache() -> None:
cache.clear()
def _add_file(filename: str) -> None:
try:
with tokenize.open(filename) as f:
tokens = list(tokenize.generate_tokens(f.readline))
except OSError:
cache[filename] = {}
return
# NOTE: undefined behavior if file is not valid Python source,
# since tokenize will have undefined behavior.
result: Dict[int, str] = {}
# current full funcname, e.g. xxx.yyy.zzz
cur_name = ""
cur_indent = 0
significant_indents: List[int] = []
for i, token in enumerate(tokens):
if token.type == tokenize.INDENT:
cur_indent += 1
elif token.type == tokenize.DEDENT:
cur_indent -= 1
# possible end of function or class
if significant_indents and cur_indent == significant_indents[-1]:
significant_indents.pop()
# pop the last name
cur_name = cur_name.rpartition(".")[0]
elif (
token.type == tokenize.NAME
and i + 1 < len(tokens)
and tokens[i + 1].type == tokenize.NAME
and (token.string == "class" or token.string == "def")
):
# name of class/function always follows class/def token
significant_indents.append(cur_indent)
if cur_name:
cur_name += "."
cur_name += tokens[i + 1].string
result[token.start[0]] = cur_name
cache[filename] = result
def get_funcname(filename: str, lineno: int) -> Optional[str]:
if filename not in cache:
_add_file(filename)
return cache[filename].get(lineno, None)
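# Illustrative sketch only (hypothetical path and line number): look up the dotted
# class/function name enclosing a source line, e.g. something like "MyModel.forward".
def _example_get_funcname() -> Optional[str]:
    return get_funcname("/tmp/example_model.py", 42)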

File diff suppressed because it is too large

View File

@ -0,0 +1,12 @@
import dataclasses
from typing import Callable, Optional
from torch._guards import GuardsSet
from .types import GuardFail
@dataclasses.dataclass
class Hooks:
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None

View File

@ -0,0 +1,59 @@
# mypy: allow-untyped-defs
import itertools
import logging
from torch.hub import _Faketqdm, tqdm
# Disable progress bar by default; not in dynamo config because otherwise we get a circular import
disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
return [
logging.getLogger("torch.fx.experimental.symbolic_shapes"),
logging.getLogger("torch._dynamo"),
logging.getLogger("torch._inductor"),
]
# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:
# @functools.lru_cache(None)
# def _step_logger():
# return get_step_logger(logging.getLogger(...))
# def fn():
# _step_logger()(logging.INFO, "msg")
_step_counter = itertools.count(1)
# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
try:
import triton # noqa: F401
num_steps = 3
except ImportError:
num_steps = 2
pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
def get_step_logger(logger):
if not disable_progress:
pbar.update(1)
if not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(f"{logger.name}")
step = next(_step_counter)
def log(level, msg, **kwargs):
logger.log(level, "Step %s: %s", step, msg, **kwargs)
return log

View File

@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"
import functools
import weakref
import torch.nn
from torch.nn import Module
from . import config
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks
unpatched_nn_module_init = torch.nn.Module.__init__
class MutationTracker:
db = ExactWeakKeyDictionary()
def __init__(self):
self.mutation_count = 0
self.watchers = []
def on_mutation(self, name):
self.mutation_count += 1
tmp = self.watchers
self.watchers = []
for ref in tmp:
guarded = ref()
if guarded is not None:
guarded.invalidate(ref)
def track(self, guarded_code):
self.watchers.append(weakref.ref(guarded_code))
def watch(obj, guarded_code):
"""invalidate guarded_code when obj is mutated"""
ensure_patched(type(obj))
if obj not in MutationTracker.db:
MutationTracker.db[obj] = MutationTracker()
tracker = MutationTracker.db[obj]
tracker.track(guarded_code)
def ensure_patched(cls):
if getattr(cls, "___needs_mutation_patch", True):
cls.___needs_mutation_patch = False
original_setattr = cls.__setattr__
@functools.wraps(original_setattr)
def custom_setattr(self, key, value):
try:
MutationTracker.db[self].on_mutation(key)
except KeyError:
pass
return original_setattr(self, key, value)
cls.__setattr__ = custom_setattr
class GenerationTracker:
generation = 0
dynamic_classes = ExactWeakKeyDictionary()
generation_values = ExactWeakKeyDictionary()
@classmethod
def tag(cls, obj):
cls.generation_values[obj] = cls.generation
@staticmethod
def mark_class_dynamic(cls):
assert issubclass(cls, torch.nn.Module)
GenerationTracker.dynamic_classes[cls] = True
@classmethod
def get_generation_value(cls, obj):
if obj not in cls.generation_values:
return -1
return cls.generation_values[obj]
@classmethod
def check(cls, obj):
return (
obj in cls.generation_values
and cls.generation_values[obj] == cls.generation
)
@classmethod
def clear(cls):
cls.generation = 0
cls.dynamic_classes = ExactWeakKeyDictionary()
cls.generation_values = ExactWeakKeyDictionary()
def is_dynamic_nn_module(obj, is_export):
"""Check for nn.Modules() created dynamically or mutated"""
if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__:
# A monkey patched `.forward` indicates something wacky is going on
return True
if hasattr(obj, "torchdynamo_force_dynamic"):
return obj.torchdynamo_force_dynamic
if is_lazy_module(obj):
return False
# For export, we will have to fix
# 1) Input signature problem because params are lifted as inputs
# 2) nn module stack info changes
# 3) adjust failing tests
if (
isinstance(obj, torch.nn.Module)
and config.inline_inbuilt_nn_modules
and not is_export
):
return True
if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks():
return True
dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
obj
)
return dyn
def install_generation_tagging_init():
"""
Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
so we can detect nn.Module instances created dynamically inside forward methods.
"""
if getattr(Module, "___needs_generation_tag_patch", True):
init = Module.__init__
def patched_init(self, *args, **kwargs):
init(self, *args, **kwargs)
GenerationTracker.tag(self)
Module.__init__ = patched_init
setstate = Module.__setstate__
def patched_setstate(self, state):
setstate(self, state)
GenerationTracker.tag(self)
Module.__setstate__ = patched_setstate
Module.___needs_generation_tag_patch = False # type: ignore[attr-defined]
GenerationTracker.generation += 1

File diff suppressed because it is too large

View File

@ -0,0 +1,166 @@
"""
Python polyfills for common builtins.
"""
# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports.
# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py.
# Add it in the TYPE_CHECKING block below as well.
# mypy: allow-untyped-defs
from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch
if TYPE_CHECKING:
# Load by torch._dynamo.polyfills.loader
# See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py
# Put the submodules here to avoid circular imports
from . import (
builtins as builtins,
functools as functools,
itertools as itertools,
os as os,
sys as sys,
)
def index(iterator, item, start=0, end=None):
from itertools import islice
for i, elem in islice(enumerate(iterator), start, end):
if item == elem:
return i
# This will not run in dynamo
raise ValueError(f"{item} is not in {type(iterator)}")
def repeat(item, count):
for i in range(count):
yield item
def radians(x):
import math
return math.pi / 180.0 * x
def accumulate_grad(x, new_grad):
new_grad = torch.clone(new_grad)
if x.grad is None:
x.grad = new_grad
else:
x.grad.add_(new_grad)
def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]):
"""emulate `(1,2,3) > (1,2)` etc"""
for a, b in zip(left, right):
if a != b:
return op(a, b)
return op(len(left), len(right))
def set_isdisjoint(set1, set2):
for x in set1:
if x in set2:
return False
return True
def set_intersection(set1, set2):
intersection_set = set()
for x in set1:
if x in set2:
intersection_set.add(x)
return intersection_set
def set_union(set1, set2):
union_set = set1.copy()
for x in set2:
if x not in union_set:
union_set.add(x)
return union_set
def set_difference(set1, set2):
difference_set = set()
for x in set1:
if x not in set2:
difference_set.add(x)
return difference_set
def dropwhile(predicate, iterable):
# dropwhile(lambda x: x<5, [1,4,6,4,1]) -> 6 4 1
iterable = iter(iterable)
for x in iterable:
if not predicate(x):
yield x
break
yield from iterable
def zip_longest(*iterables, fillvalue=None):
# Create a list of iterators from the input iterables
iterators = [iter(it) for it in iterables]
result = []
while True:
row = []
active = False
for it in iterators:
try:
# Try to get the next item from the iterator
value = next(it)
row.append(value)
active = True
except StopIteration:
# If the iterator is exhausted, use the fillvalue
row.append(fillvalue)
if not active:
break
result.append(tuple(row))
return result
def getattr_and_trace(*args, **kwargs):
wrapper_obj = args[0]
attr_name = args[1]
fn = getattr(wrapper_obj, attr_name)
return fn(*args[2:], **kwargs)
def mapping_get(obj, key, value=None):
try:
return obj.__getitem__(key)
except KeyError:
return value
def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
obj = cls.__new__(cls, *args, **kwargs)
# Only call __init__ if the object is an instance of the class
# Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673
if isinstance(obj, cls):
obj.__init__(*args, **kwargs)
return obj
def foreach_lerp_inplace(self, end, weight):
# decompose foreach lerp into constituent ops, prevents a graph break due to
# converting a value to a scalar when arg[2] is a single tensor
result = torch._foreach_sub(end, self)
result = torch._foreach_mul(result, weight)
return torch._foreach_add_(self, result)
def foreach_pow_scalar(scalar, exps):
return torch._foreach_pow([scalar for _ in exps], exps)
def addcmul_inplace(self, tensor1, tensor2, value):
return self.add_(tensor1 * tensor2 * value)

View File

@ -0,0 +1,48 @@
"""
Python polyfills for builtins
"""
from __future__ import annotations
import builtins
from typing import Iterable, TypeVar
from ..decorators import substitute_in_graph
__all__ = [
"all",
"any",
"enumerate",
]
_T = TypeVar("_T")
@substitute_in_graph(builtins.all, can_constant_fold_through=True)
def all(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if not elem:
return False
return True
@substitute_in_graph(builtins.any, can_constant_fold_through=True)
def any(iterable: Iterable[object], /) -> bool:
for elem in iterable:
if elem:
return True
return False
@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type]
def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]:
if not isinstance(start, int):
raise TypeError(
f"{type(start).__name__!r} object cannot be interpreted as an integer"
)
for x in iterable:
yield start, x
start += 1

View File

@ -0,0 +1,6 @@
"""
Python polyfills for functools
"""
__all__ = [] # type: ignore[var-annotated]

View File

@ -0,0 +1,85 @@
"""
Python polyfills for itertools
"""
from __future__ import annotations
import itertools
from typing import Iterable, Iterator, TypeVar
from ..decorators import substitute_in_graph
__all__ = [
"chain",
"chain_from_iterable",
"islice",
"tee",
]
_T = TypeVar("_T")
# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
for iterable in iterables:
yield from iterable
@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
return itertools.chain(*iterable)
chain.from_iterable = chain_from_iterable # type: ignore[method-assign]
# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
s = slice(*args)
start = 0 if s.start is None else s.start
stop = s.stop
step = 1 if s.step is None else s.step
if start < 0 or (stop is not None and stop < 0) or step <= 0:
raise ValueError(
"Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.",
)
if stop is None:
# TODO: use indices = itertools.count() and merge implementation with the else branch
# when we support infinite iterators
next_i = start
for i, element in enumerate(iterable):
if i == next_i:
yield element
next_i += step
else:
indices = range(max(start, stop))
next_i = start
for i, element in zip(indices, iterable):
if i == next_i:
yield element
next_i += step
# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
@substitute_in_graph(itertools.tee)
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
iterator = iter(iterable)
shared_link = [None, None]
def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
return tuple(_tee(shared_link) for _ in range(n))
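# Illustrative sketch only (hypothetical helper, never called here; assumes the chain
# polyfill registered above is active): itertools.chain can be traced without a graph
# break because Dynamo inlines the Python polyfill instead of the C implementation.
def _example_polyfilled_chain():
    import torch

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        return torch.stack(list(itertools.chain([x], [x + 1], [x + 2])))

    return fn(torch.randn(3))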

Some files were not shown because too many files have changed in this diff