I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,317 @@
# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import io
import json
import logging
import os
import re
import sys
import types
import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import log_export_usage
from torch.export._tree_utils import reorder_kwargs
from torch.export.graph_signature import (
ArgumentSpec,
ConstantArgument,
ExportGraphSignature,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
SymIntArgument,
TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from .wrappers import _wrap_submodules
log = logging.getLogger(__name__)
@dataclasses.dataclass
class ExportDynamoConfig:
"""
Manage Export-specific configurations of Dynamo.
"""
allow_rnn: bool = True
# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
# is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
from torch._inductor import config
log.warning("+============================+")
log.warning("| !!! WARNING !!! |")
log.warning("+============================+")
log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
log.warning("Please switch to use torch.export.export_for_training instead.")
if config.is_fbcode():
log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950
@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
f: torch.nn.Module,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
"""
A helper function that is intended to trace a module before any pre-autograd
decomposition is run. The produced module will be "non-functional" and
composed of aten operators. Later, this API will be deleted in favor of the
more general torch.export API.
Args:
f: nn.Module to be traced
args: example positional inputs.
kwargs: optional example keyword inputs.
dynamic_shapes: Should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
Returns:
An nn.Module containing the traced method.
"""
from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
from torch._export.non_strict_utils import make_constraints
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._unlift import _create_stateful_graph_module
from torch.export.dynamic_shapes import _combine_args
capture_pre_autograd_graph_warning()
if sys.platform == "win32":
raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")
assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
if kwargs is None:
kwargs = {}
if capture_pre_autograd_graph_using_training_ir():
@lru_cache
def print_export_warning():
log.warning("Using torch.export.export_for_training(...,strict=True)")
print_export_warning()
module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
else:
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
# Do not decompose dropout for exported models, because in eval mode the dropout
# op disappears from the graph, which makes it difficult to switch to train mode.
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
decomp_table = {
op: op.decompose
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
assume_static_by_default=True,
tracing_mode="symbolic",
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
_log_export_usage=False,
)(
*args,
**kwargs,
)[0]
_, _, fake_mode = _extract_fake_inputs(m, args, kwargs)
m.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}
if isinstance(f, torch.nn.Module):
from torch.export._trace import _restore_state_dict
_restore_state_dict(f, m)
flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
combined_args = _combine_args(f, args, kwargs)
range_constraints = make_constraints(
fake_mode,
m,
combined_args,
dynamic_shapes,
0,
)
module = _create_stateful_graph_module(
m,
range_constraints=range_constraints,
)
error_message = \
"""
Calling train() or eval() is not supported for exported models.
Alternatively, you may override these methods to do custom user behavior as follows:
def _my_train(self, mode: bool = True):
...
def _my_eval(self):
...
model.train = types.MethodType(_my_train, model)
model.eval = types.MethodType(_my_eval, model)
"""
def _train(self, mode: bool = True):
raise NotImplementedError(error_message)
def _eval(self, mode: bool = True):
raise NotImplementedError(error_message)
module.train = types.MethodType(_train, module) # type: ignore[method-assign]
module.eval = types.MethodType(_eval, module) # type: ignore[method-assign]
# Remove Proxy because they cannot be deepcopied or pickled.
if hasattr(module, "_buffers"):
torch._export.utils.remove_proxy_from_state_dict(
module._buffers, in_place=True
)
return module
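# A minimal, hedged usage sketch for capture_pre_autograd_graph(). The demo
# module, input shapes, and Dim name are illustrative assumptions, not part
# of the API; prefer torch.export.export_for_training in new code.
def _demo_capture_pre_autograd_graph():
    from torch.export import Dim

    class Demo(torch.nn.Module):
        def forward(self, x):
            return x.sin() + x.cos()

    example_inputs = (torch.randn(4, 3),)
    # Mark dim 0 of "x" as dynamic; all other dims stay static.
    gm = capture_pre_autograd_graph(
        Demo(), example_inputs, dynamic_shapes={"x": {0: Dim("batch")}}
    )
    return gm(*example_inputs)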
def aot_compile(
f: Callable,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
same_signature: bool = True,
) -> str:
"""
Note: this function is not stable yet.
Traces either an nn.Module's forward function or just a callable with PyTorch
operations inside, generates executable C++ code from the program, and returns
the path to the generated shared library.
Args:
f: the `nn.Module` or callable to trace.
args: example positional inputs.
kwargs: optional example keyword inputs.
dynamic_shapes: Should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
options: A dictionary of options to control Inductor.
disable_constraint_solver: Whether the dim constraint solver must be disabled.
Returns:
Path to the generated shared library
"""
from torch.export._trace import _export_to_torch_ir
from torch._inductor.decomposition import select_decomp_table
from torch._inductor import config
if config.is_predispatch:
gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
else:
# We want to export to Torch IR here to utilize the pre_grad passes in
# inductor, which run on Torch IR.
gm = _export_to_torch_ir(
f,
args,
kwargs,
dynamic_shapes,
disable_constraint_solver=disable_constraint_solver,
same_signature=same_signature,
# Disabling this flag, because instead we can rely on the mapping
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
restore_fqn=False,
)
with torch.no_grad():
so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type]
return so_path
def aot_load(so_path: str, device: str) -> Callable:
"""
Loads a shared library generated by aot_compile and returns a callable
Args:
so_path: Path to the shared library
Returns:
A callable
"""
if device == "cpu":
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)
def optimized(*args, **kwargs):
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
return optimized
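# A hedged end-to-end sketch pairing aot_compile() with aot_load(). The demo
# module, shapes, and device are assumptions; a working Inductor toolchain
# (C++ compiler) is required to actually build the shared library.
def _demo_aot_compile_and_load():
    class Demo(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) * 2

    example_inputs = (torch.randn(8, 8),)
    so_path = aot_compile(Demo(), example_inputs)  # path to the generated .so
    optimized = aot_load(so_path, device="cpu")
    return optimized(*example_inputs)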

File diff suppressed because it is too large

View File

@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,174 @@
# mypy: allow-untyped-defs
import inspect
import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from types import ModuleType
import torch
_TAGS: Dict[str, Dict[str, Any]] = {
"torch": {
"cond": {},
"dynamic-shape": {},
"escape-hatch": {},
"map": {},
"dynamic-value": {},
"operator": {},
"mutation": {},
},
"python": {
"assert": {},
"builtin": {},
"closure": {},
"context-manager": {},
"control-flow": {},
"data-structure": {},
"standard-library": {},
"object-model": {},
},
}
class SupportLevel(Enum):
"""
Indicates at what stage the feature
used in the example is handled in export.
"""
SUPPORTED = 1
NOT_SUPPORTED_YET = 0
ArgsType = Tuple[Any, ...]
def check_inputs_type(args, kwargs):
if not isinstance(args, tuple):
raise ValueError(
f"Expecting args type to be a tuple, got: {type(args)}"
)
if not isinstance(kwargs, dict):
raise ValueError(
f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
)
for key in kwargs:
if not isinstance(key, str):
raise ValueError(
f"Expecting kwargs keys to be a string, got: {type(key)}"
)
def _validate_tag(tag: str):
parts = tag.split(".")
t = _TAGS
for part in parts:
assert set(part) <= set(
string.ascii_lowercase + "-"
), f"Tag contains invalid characters: {part}"
if part in t:
t = t[part]
else:
raise ValueError(f"Tag {tag} is not found in registered tags.")
@dataclass(frozen=True)
class ExportCase:
example_args: ArgsType
description: str # A description of the use case.
model: torch.nn.Module
name: str
example_kwargs: Dict[str, Any] = field(default_factory=dict)
extra_args: Optional[ArgsType] = None # For testing graph generalization.
# Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
tags: Set[str] = field(default_factory=set)
support_level: SupportLevel = SupportLevel.SUPPORTED
dynamic_shapes: Optional[Dict[str, Any]] = None
def __post_init__(self):
check_inputs_type(self.example_args, self.example_kwargs)
if self.extra_args is not None:
check_inputs_type(self.extra_args, {})
for tag in self.tags:
_validate_tag(tag)
if not isinstance(self.description, str) or len(self.description) == 0:
raise ValueError(f'Invalid description: "{self.description}"')
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
_MODULES: Set[ModuleType] = set()
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
def register_db_case(case: ExportCase) -> None:
"""
Registers a user provided ExportCase into example bank.
"""
if case.name in _EXAMPLE_CASES:
if case.name not in _EXAMPLE_CONFLICT_CASES:
_EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
_EXAMPLE_CONFLICT_CASES[case.name].append(case)
return
_EXAMPLE_CASES[case.name] = case
def to_snake_case(name):
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
def _make_export_case(m, name, configs):
if not isinstance(m, torch.nn.Module):
raise TypeError("Export case class should be a torch.nn.Module.")
if "description" not in configs:
# Fallback to docstring if description is missing.
assert (
m.__doc__ is not None
), f"Could not find description or docstring for export case: {m}"
configs = {**configs, "description": m.__doc__}
return ExportCase(**{**configs, "model": m, "name": name})
def export_case(**kwargs):
"""
Decorator for registering a user provided case into example bank.
"""
def wrapper(m):
configs = kwargs
module = inspect.getmodule(m)
if module in _MODULES:
raise RuntimeError("export_case should only be used once per example file.")
assert module is not None
_MODULES.add(module)
module_name = module.__name__.split(".")[-1]
case = _make_export_case(m, module_name, configs)
register_db_case(case)
return case
return wrapper
def export_rewrite_case(**kwargs):
def wrapper(m):
configs = kwargs
parent = configs.pop("parent")
assert isinstance(parent, ExportCase)
key = parent.name
if key not in _EXAMPLE_REWRITE_CASES:
_EXAMPLE_REWRITE_CASES[key] = []
configs["example_args"] = parent.example_args
case = _make_export_case(m, to_snake_case(m.__name__), configs)
_EXAMPLE_REWRITE_CASES[key].append(case)
return case
return wrapper
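# A hedged sketch of how an example file is expected to use the export_case
# decorator above; the tag, args, and model are illustrative assumptions.
# Wrapped in a function so importing this module does not register the case.
def _demo_export_case():
    class DemoCase(torch.nn.Module):
        """The docstring doubles as the case description."""

        def forward(self, x):
            return x + 1

    # Decorate a module *instance*, mirroring how the example DB applies it.
    return export_case(
        example_args=(torch.randn(3, 2),),
        tags={"python.control-flow"},
    )(DemoCase())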

View File

@ -0,0 +1,61 @@
# mypy: allow-untyped-defs
import dataclasses
import glob
import inspect
from os.path import basename, dirname, isfile, join
import torch
from torch._export.db.case import (
_EXAMPLE_CASES,
_EXAMPLE_CONFLICT_CASES,
_EXAMPLE_REWRITE_CASES,
SupportLevel,
export_case,
ExportCase,
)
def _collect_examples():
case_names = glob.glob(join(dirname(__file__), "*.py"))
case_names = [
basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py")
]
case_fields = {f.name for f in dataclasses.fields(ExportCase)}
for case_name in case_names:
case = __import__(case_name, globals(), locals(), [], 1)
variables = [name for name in dir(case) if name in case_fields]
export_case(**{v: getattr(case, v) for v in variables})(case.model)
_collect_examples()
def all_examples():
return _EXAMPLE_CASES
if len(_EXAMPLE_CONFLICT_CASES) > 0:
def get_name(case):
model = case.model
if isinstance(model, torch.nn.Module):
model = type(model)
return model.__name__
msg = "Error on conflict export case name.\n"
for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
msg += f"Case name {case_name} is associated with multiple cases:\n "
msg += f"[{','.join(map(get_name, cases))}]\n"
raise RuntimeError(msg)
def filter_examples_by_support_level(support_level: SupportLevel):
return {
key: val
for key, val in all_examples().items()
if val.support_level == support_level
}
def get_rewrite_cases(case):
return _EXAMPLE_REWRITE_CASES.get(case.name, [])
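# Hedged usage sketch for the query helpers above; the printed fields are
# assumptions about what a caller typically wants to inspect.
def _demo_list_supported_examples():
    supported = filter_examples_by_support_level(SupportLevel.SUPPORTED)
    for name, case in sorted(supported.items()):
        print(name, sorted(case.tags), len(get_rewrite_cases(case)))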

View File

@ -0,0 +1,20 @@
# mypy: allow-untyped-defs
import torch
import torch._dynamo as torchdynamo
class AssumeConstantResult(torch.nn.Module):
"""
Applying the `assume_constant_result` decorator to burn non-traceable code in as a constant.
"""
@torchdynamo.assume_constant_result
def get_item(self, y):
return y.int().item()
def forward(self, x, y):
return x[: self.get_item(y)]
example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()

View File

@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
import torch
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, grad_output):
return grad_output + 1
class AutogradFunction(torch.nn.Module):
"""
TorchDynamo does not keep track of backward() on autograd functions. We recommend
using `allow_in_graph` to mitigate this problem.
"""
def forward(self, x):
return MyAutogradFunction.apply(x)
example_args = (torch.randn(3, 2),)
model = AutogradFunction()

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
import torch
class ClassMethod(torch.nn.Module):
"""
Class methods are inlined during tracing.
"""
@classmethod
def method(cls, x):
return x + 1
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 2)
def forward(self, x):
x = self.linear(x)
return self.method(x) * self.__class__.method(x) * type(self).method(x)
example_args = (torch.randn(3, 4),)
model = ClassMethod()

View File

@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class MySubModule(torch.nn.Module):
def foo(self, x):
return x.cos()
def forward(self, x):
return self.foo(x)
class CondBranchClassMethod(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
- both branches must take the same args, which must also match the branch args passed to cond.
- both branches must return a single tensor
- the returned tensor must have the same tensor metadata, e.g. shape and dtype
- a branch function can be a free function, a nested function, a lambda, or a class method
- a branch function cannot have closure variables
- no in-place mutations on inputs or global variables
This example demonstrates using a class method in cond().
NOTE: If `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self) -> None:
super().__init__()
self.subm = MySubModule()
def bar(self, x):
return x.sin()
def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
example_args = (torch.randn(3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchClassMethod()

View File

@ -0,0 +1,41 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondBranchNestedFunction(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
- both branches must take the same args, which must also match the branch args passed to cond.
- both branches must return a single tensor
- the returned tensor must have the same tensor metadata, e.g. shape and dtype
- a branch function can be a free function, a nested function, a lambda, or a class method
- a branch function cannot have closure variables
- no in-place mutations on inputs or global variables
This example demonstrates using a nested function in cond().
NOTE: If `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def forward(self, x):
def true_fn(x):
def inner_true_fn(y):
return x + y
return inner_true_fn(x)
def false_fn(x):
def inner_false_fn(y):
return x - y
return inner_false_fn(x)
return cond(x.shape[0] < 10, true_fn, false_fn, [x])
example_args = (torch.randn(3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchNestedFunction()

View File

@ -0,0 +1,59 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondBranchNonlocalVariables(torch.nn.Module):
"""
The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
- both branches must take the same args, which must also match the branch args passed to cond.
- both branches must return a single tensor
- the returned tensor must have the same tensor metadata, e.g. shape and dtype
- a branch function can be a free function, a nested function, a lambda, or a class method
- a branch function cannot have closure variables
- no in-place mutations on inputs or global variables
This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
The code below will not work because capturing closure variables is not supported.
```
my_tensor_var = x + 100
my_primitive_var = 3.14
def true_fn(y):
nonlocal my_tensor_var, my_primitive_var
return y + my_tensor_var + my_primitive_var
def false_fn(y):
nonlocal my_tensor_var, my_primitive_var
return y - my_tensor_var - my_primitive_var
return cond(x.shape[0] > 5, true_fn, false_fn, [x])
```
NOTE: If `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def forward(self, x):
my_tensor_var = x + 100
my_primitive_var = 3.14
def true_fn(x, y, z):
return x + y + z
def false_fn(x, y, z):
return x - y - z
return cond(
x.shape[0] > 5,
true_fn,
false_fn,
[x, my_tensor_var, torch.tensor(my_primitive_var)],
)
example_args = (torch.randn(6),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondBranchNonlocalVariables()

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondClosedOverVariable(torch.nn.Module):
"""
torch.cond() supports branches closed over arbitrary variables.
"""
def forward(self, pred, x):
def true_fn(val):
return x * 2
def false_fn(val):
return x - 2
return cond(pred, true_fn, false_fn, [x + 1])
example_args = (torch.tensor(True), torch.randn(3, 2))
tags = {"torch.cond", "python.closure"}
model = CondClosedOverVariable()

View File

@ -0,0 +1,36 @@
# mypy: allow-untyped-defs
import torch
from torch.export import Dim
from functorch.experimental.control_flow import cond
x = torch.randn(3, 2)
y = torch.randn(2)
dim0_x = Dim("dim0_x")
class CondOperands(torch.nn.Module):
"""
The operands passed to cond() must be:
- a list of tensors
- matching the arguments of `true_fn` and `false_fn`
NOTE: If `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def forward(self, x, y):
def true_fn(x, y):
return x + y
def false_fn(x, y):
return x - y
return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
example_args = (x, y)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
extra_inputs = (torch.randn(2, 2), torch.randn(2))
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
model = CondOperands()

View File

@ -0,0 +1,25 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondPredicate(torch.nn.Module):
"""
The conditional statement (aka predicate) passed to cond() must be one of the following:
- a torch.Tensor with a single element
- a boolean expression
NOTE: If `pred` is a test on a dim with batch size < 2, it will be specialized.
"""
def forward(self, x):
pred = x.dim() > 2 and x.shape[2] > 10
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
example_args = (torch.randn(6, 4, 3),)
tags = {
"torch.cond",
"torch.dynamic-shape",
}
model = CondPredicate()

View File

@ -0,0 +1,25 @@
# mypy: allow-untyped-defs
import torch
class ConstrainAsSizeExample(torch.nn.Module):
"""
If the value is not known at tracing time, you can provide a hint so that we
can trace further. Please look at the torch._check and torch._check_is_size APIs.
torch._check_is_size is used for values that NEED to be used for constructing
a tensor.
"""
def forward(self, x):
a = x.item()
torch._check_is_size(a)
torch._check(a <= 5)
return torch.zeros((a, 5))
example_args = (torch.tensor(4),)
tags = {
"torch.dynamic-value",
"torch.escape-hatch",
}
model = ConstrainAsSizeExample()

View File

@ -0,0 +1,28 @@
# mypy: allow-untyped-defs
import torch
class ConstrainAsValueExample(torch.nn.Module):
"""
If the value is not known at tracing time, you can provide a hint so that we
can trace further. Please look at the torch._check and torch._check_is_size APIs.
torch._check is used for values that don't need to be used for constructing
a tensor.
"""
def forward(self, x, y):
a = x.item()
torch._check(a >= 0)
torch._check(a <= 5)
if a < 6:
return y.sin()
return y.cos()
example_args = (torch.tensor(4), torch.randn(5, 5))
tags = {
"torch.dynamic-value",
"torch.escape-hatch",
}
model = ConstrainAsValueExample()

View File

@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
import functools
import torch
def test_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs) + 1
return wrapper
class Decorator(torch.nn.Module):
"""
Decorator calls are inlined into the exported function during tracing.
"""
@test_decorator
def forward(self, x, y):
return x + y
example_args = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch
class Dictionary(torch.nn.Module):
"""
Dictionary structures are inlined and flattened during tracing.
"""
def forward(self, x, y):
elements = {}
elements["x2"] = x * x
y = y * elements["x2"]
return {"y": y}
example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"python.data-structure"}
model = Dictionary()

View File

@ -0,0 +1,18 @@
# mypy: allow-untyped-defs
import torch
class DynamicShapeAssert(torch.nn.Module):
"""
A basic usage of Python assertions.
"""
def forward(self, x):
# assertion with error message
assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
# assertion without error message
assert x.shape[0] > 1
return x
example_args = (torch.randn(3, 2),)
tags = {"python.assert"}
model = DynamicShapeAssert()

View File

@ -0,0 +1,15 @@
# mypy: allow-untyped-defs
import torch
class DynamicShapeConstructor(torch.nn.Module):
"""
Tensor constructors should be captured with dynamic shape inputs rather
than being baked in with a static shape.
"""
def forward(self, x):
return torch.zeros(x.shape[0] * 2)
example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeConstructor()

View File

@ -0,0 +1,19 @@
# mypy: allow-untyped-defs
import torch
class DynamicShapeIfGuard(torch.nn.Module):
"""
An `if` statement with a backed dynamic-shape predicate will be specialized
into one particular branch and generate a guard. However, export will fail if
the dimension is marked as dynamic via a higher-level API.
"""
def forward(self, x):
if x.shape[0] == 3:
return x.cos()
return x.sin()
example_args = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()

View File

@ -0,0 +1,19 @@
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import map
class DynamicShapeMap(torch.nn.Module):
"""
functorch map() maps a function over the first tensor dimension.
"""
def forward(self, xs, y):
def body(x, y):
return x + y
return map(body, xs, y)
example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"torch.dynamic-shape", "torch.map"}
model = DynamicShapeMap()

View File

@ -0,0 +1,21 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
from torch.export import Dim
class DynamicShapeRound(torch.nn.Module):
"""
Calling round on dynamic shapes is not supported.
"""
def forward(self, x):
return x[: round(x.shape[0] / 2)]
x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()

View File

@ -0,0 +1,15 @@
# mypy: allow-untyped-defs
import torch
class DynamicShapeSlicing(torch.nn.Module):
"""
Slices with dynamic shape arguments should be captured into the graph
rather than being baked in.
"""
def forward(self, x):
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeSlicing()

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch
class DynamicShapeView(torch.nn.Module):
"""
Dynamic shapes should be propagated to view arguments instead of being
baked into the exported graph.
"""
def forward(self, x):
new_x_shape = x.size()[:-1] + (2, 5)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1)
example_args = (torch.randn(10, 10),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeView()

View File

@ -0,0 +1,30 @@
# mypy: allow-untyped-defs
import torch
class FnWithKwargs(torch.nn.Module):
"""
Keyword arguments are not supported at the moment.
"""
def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
out = pos0
for arg in tuple0:
out = out * arg
for arg in myargs:
out = out * arg
out = out * mykw0
out = out * mykwargs["input0"] * mykwargs["input1"]
return out
example_args = (
torch.randn(4),
(torch.randn(4), torch.randn(4)),
*[torch.randn(4), torch.randn(4)]
)
example_kwargs = {
"mykw0": torch.randn(4),
"input0": torch.randn(4),
"input1": torch.randn(4),
}
tags = {"python.data-structure"}
model = FnWithKwargs()

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch
class ListContains(torch.nn.Module):
"""
List containment relations can be checked on dynamic shapes or constants.
"""
def forward(self, x):
assert x.size(-1) in [6, 2]
assert x.size(0) not in [4, 5, 6]
assert "monkey" not in ["cow", "pig"]
return x + x
example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
model = ListContains()

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
from typing import List
import torch
class ListUnpack(torch.nn.Module):
"""
Lists are treated as a static construct; therefore, unpacking should be
erased after tracing.
"""
def forward(self, args: List[torch.Tensor]):
"""
Lists are treated as a static construct; therefore, unpacking should be
erased after tracing.
"""
x, *y = args
return x + y[0]
example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()

View File

@ -0,0 +1,26 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
class ModelAttrMutation(torch.nn.Module):
"""
Attribute mutation is not supported.
"""
def __init__(self) -> None:
super().__init__()
self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]
def recreate_list(self):
return [torch.zeros(3, 2), torch.zeros(3, 2)]
def forward(self, x):
self.attr_list = self.recreate_list()
return x.sum() + self.attr_list[0].sum()
example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()

View File

@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
import torch
class NestedFunction(torch.nn.Module):
"""
Nested functions are traced through. However, side effects on captured
variables are not supported.
"""
def forward(self, a, b):
x = a + b
z = a - b
def closure(y):
nonlocal x
x += 1
return x * y + z
return closure(x)
example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()

View File

@ -0,0 +1,21 @@
# mypy: allow-untyped-defs
import contextlib
import torch
class NullContextManager(torch.nn.Module):
"""
Null context manager in Python will be traced out.
"""
def forward(self, x):
"""
Null context manager in Python will be traced out.
"""
ctx = contextlib.nullcontext()
with ctx:
return x.sin() + x.cos()
example_args = (torch.randn(3, 2),)
tags = {"python.context-manager"}
model = NullContextManager()

View File

@ -0,0 +1,20 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
class OptionalInput(torch.nn.Module):
"""
Tracing through optional inputs is not supported yet.
"""
def forward(self, x, y=torch.randn(2, 3)):
if y is not None:
return x + y
return x
example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()

View File

@ -0,0 +1,16 @@
# mypy: allow-untyped-defs
import torch
from torch.utils import _pytree as pytree
class PytreeFlatten(torch.nn.Module):
"""
Pytree from PyTorch can be captured by TorchDynamo.
"""
def forward(self, x):
y, spec = pytree.tree_flatten(x)
return y[0] + 1
example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()

View File

@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
import torch
from torch.export import Dim
x = torch.randn(3, 2)
dim1_x = Dim("dim1_x")
class ScalarOutput(torch.nn.Module):
"""
Returning scalar values from the graph is supported, in addition to Tensor
outputs. Symbolic shapes are captured and rank is specialized.
"""
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x.shape[1] + 1
example_args = (x,)
tags = {"torch.dynamic-shape"}
dynamic_shapes = {"x": {1: dim1_x}}
model = ScalarOutput()

View File

@ -0,0 +1,26 @@
# mypy: allow-untyped-defs
from enum import Enum
import torch
class Animal(Enum):
COW = "moo"
class SpecializedAttribute(torch.nn.Module):
"""
Model attributes are specialized.
"""
def __init__(self) -> None:
super().__init__()
self.a = "moo"
self.b = 4
def forward(self, x):
if self.a == Animal.COW.value:
return x * x + self.b
else:
raise ValueError("bad")
example_args = (torch.randn(3, 2),)
model = SpecializedAttribute()

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch
class StaticForLoop(torch.nn.Module):
"""
A for loop with a constant number of iterations should be unrolled in the exported graph.
"""
def forward(self, x):
ret = []
for i in range(10): # constant
ret.append(i + x)
return ret
example_args = (torch.randn(3, 2),)
tags = {"python.control-flow"}
model = StaticForLoop()

View File

@ -0,0 +1,18 @@
# mypy: allow-untyped-defs
import torch
class StaticIf(torch.nn.Module):
"""
An `if` statement with a static predicate value should be traced through with
the taken branch.
"""
def forward(self, x):
if len(x.shape) == 3:
return x + torch.ones(1, 1, 1)
return x
example_args = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()

View File

@ -0,0 +1,15 @@
# mypy: allow-untyped-defs
import torch
class TensorSetattr(torch.nn.Module):
"""
setattr() calls on tensors are not supported.
"""
def forward(self, x, attr):
setattr(x, attr, torch.randn(3, 2))
return x + 4
example_args = (torch.randn(3, 2), "attr")
tags = {"python.builtin"}
model = TensorSetattr()

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
import torch
class A:
@classmethod
def func(cls, x):
return 1 + x
class TypeReflectionMethod(torch.nn.Module):
"""
type() calls on custom objects followed by attribute accesses are not allowed
due to their overly dynamic nature.
"""
def forward(self, x):
a = A()
return type(a).func(x)
example_args = (torch.randn(3, 4),)
tags = {"python.builtin"}
model = TypeReflectionMethod()

View File

@ -0,0 +1,18 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
class TorchSymMin(torch.nn.Module):
"""
The torch.sym_min operator is not supported in export.
"""
def forward(self, x):
return x.sum() + torch.sym_min(x.size(0), 100)
example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch
class UserInputMutation(torch.nn.Module):
"""
Directly mutates the user input in forward.
"""
def forward(self, x):
x.mul_(2)
return x.cos()
example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()

View File

@ -0,0 +1,21 @@
import os
import sys
import torch._export.db.examples as examples
TEMPLATE = '''import torch
def {case_name}(x):
"""
"""
return
'''
if __name__ == "__main__":
assert len(sys.argv) == 2
root_dir = examples.__name__.replace(".", "/")
assert os.path.exists(root_dir)
with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f:
print("Writing to", f.name, "...")
f.write(TEMPLATE.format(case_name=sys.argv[1]))

View File

@ -0,0 +1,47 @@
# mypy: allow-untyped-defs
def exportdb_error_message(case_name: str):
from .examples import all_examples
from torch._utils_internal import log_export_usage
ALL_EXAMPLES = all_examples()
# Detect whether case_name is really registered in exportdb.
if case_name in ALL_EXAMPLES:
url_case_name = case_name.replace("_", "-")
return f"See {case_name} in exportdb for unsupported case. \
https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}"
else:
log_export_usage(
event="export.error.casenotregistered",
message=case_name,
)
return f"{case_name} is unsupported."
def get_class_if_classified_error(e):
"""
Returns a string case name if the export error e is classified.
Returns None otherwise.
"""
from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
ALWAYS_CLASSIFIED = "always_classified"
DEFAULT_CLASS_SIGIL = "case_name"
# add error types that should be classified, along with any attribute name
# whose presence acts like a sigil to further distinguish which errors of
# that type should be classified. If the attribute name is None, then the
# error type is always classified.
_ALLOW_LIST = {
Unsupported: DEFAULT_CLASS_SIGIL,
UserError: DEFAULT_CLASS_SIGIL,
TorchRuntimeError: None,
}
if type(e) in _ALLOW_LIST:
attr_name = _ALLOW_LIST[type(e)]
if attr_name is None:
return ALWAYS_CLASSIFIED
return getattr(e, attr_name, None)
return None
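# A hedged caller-side sketch for the helpers above; the try/except wrapper
# is an illustrative assumption, not part of the export pipeline.
def _demo_classify_export_error(export_fn, *args):
    try:
        return export_fn(*args)
    except Exception as e:
        case_name = get_class_if_classified_error(e)
        if isinstance(case_name, str) and case_name != "always_classified":
            # Point the user at the matching exportdb entry.
            print(exportdb_error_message(case_name))
        raise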

View File

@ -0,0 +1,56 @@
from enum import Enum
class ExportErrorType(Enum):
# User provided invalid inputs to either the tracer or other public-facing APIs.
INVALID_INPUT_TYPE = 1
# User returning values from their models that we don't support.
INVALID_OUTPUT_TYPE = 2
# Generated IR does not conform to Export IR Specification.
VIOLATION_OF_SPEC = 3
# User's code contains types and functionalities we don't support.
NOT_SUPPORTED = 4
# User's code didn't provide necessary details for us to successfully trace and export.
# For example, we use a lot of decorators and ask users to annotate their model.
MISSING_PROPERTY = 5
# User is using an API without a proper initialization step.
UNINITIALIZED = 6
def internal_assert(pred: bool, assert_msg: str) -> None:
"""
This is exir's custom assert method. It internally just throws InternalError.
Note that the sole purpose is to throw our own error while maintaining a
syntax similar to Python's assert.
"""
if not pred:
raise InternalError(assert_msg)
class InternalError(Exception):
"""
Raised when an internal invariant is violated in the EXIR stack.
It should hint users to report a bug to the devs and expose the original
error message.
"""
def __init__(self, message: str) -> None:
super().__init__(message)
class ExportError(Exception):
"""
This type of exception is raised for errors that are directly caused by the user
code. In general, user errors happen during model authoring, tracing, using our
public-facing APIs, and writing graph passes.
"""
def __init__(self, error_code: ExportErrorType, message: str) -> None:
prefix = f"[{error_code}]: "
super().__init__(prefix + message)
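# A hedged usage sketch for the exception types above; the validation rule
# is an illustrative assumption.
def _demo_validate_example_inputs(args):
    if not isinstance(args, tuple):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Expected example inputs as a tuple, got {type(args)}.",
        )
    internal_assert(len(args) > 0, "example inputs must be non-empty")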

View File

@ -0,0 +1,523 @@
# mypy: allow-untyped-defs
import contextlib
import inspect
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
from torch._dynamo.source import (
AttrSource,
GetItemSource,
LocalSource,
TensorProperty,
TensorPropertySource,
)
from torch._dynamo.variables.builder import TrackedFake
from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
_check_dynamic_shapes,
_combine_args,
_DimHint,
_process_dynamic_shapes,
_transform_shapes_for_default_dynamic,
_tree_map_with_path,
)
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental import _config as config
from torch.fx.experimental.symbolic_shapes import (
_find_user_code_frame,
_suggest_fixes_for_data_dependent_error_non_strict,
ConstraintViolationError,
DimDynamic,
EqualityConstraint,
GuardOnDataDependentSymNode,
ShapeEnv,
StatelessSymbolicContext,
ValueRanges,
)
from torch.utils._pytree import (
GetAttrKey,
KeyPath,
MappingKey,
SequenceKey,
tree_map_with_path,
)
if TYPE_CHECKING:
from sympy import Symbol
log = logging.getLogger(__name__)
def key_path_to_source(kp: KeyPath) -> Source:
"""
Given a key path, return the source for the key path.
"""
source: Source = LocalSource("args")
for k in kp:
if isinstance(k, SequenceKey):
source = GetItemSource(source, k.idx)
elif isinstance(k, MappingKey):
source = GetItemSource(source, k.key)
elif isinstance(k, GetAttrKey):
source = AttrSource(source, k.name)
else:
raise ValueError(f"Unknown KeyEntry {k}")
return source
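# A hedged sketch of the key-path-to-source mapping above; the key path and
# the .name() rendering are illustrative assumptions.
def _demo_key_path_to_source():
    kp: KeyPath = (SequenceKey(idx=0), MappingKey(key="x"))
    source = key_path_to_source(kp)  # roughly: args[0]["x"]
    return source.name()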
def _is_constant_argument(t):
return t is None or isinstance(t, (int, float, bool, str))
def fakify(
mode: FakeTensorMode,
kp: KeyPath,
t: Any,
t_constraints: Dict[int, Dict[int, Constraint]],
sources: Dict[Tuple[int, int], List[Source]],
):
source = key_path_to_source(kp)
if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
return t
if not isinstance(t, torch.Tensor):
raise ValueError(f"Unsupported input type {type(t)}")
n_dims = len(t.shape)
symbolic_context = StatelessSymbolicContext(
dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims,
constraint_sizes=[None] * n_dims,
)
t_id = id(t)
assert mode.shape_env is not None
if t_id in t_constraints:
for i, constraint in t_constraints[t_id].items():
symbolic_context.constraint_sizes[i] = constraint.constraint_range
src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
sources[(t_id, i)].append(src)
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment]
fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr]
return fake
def make_fake_inputs(
nn_module,
args,
kwargs,
dynamic_shapes,
_is_torch_jit_trace=False,
allow_complex_guards_as_runtime_asserts=False,
):
"""
Given an nn module, example inputs, and constraints, return a new fake mode,
fake inputs created in that mode whose dynamic shape dimensions are constrained
by the given ranges, and sources for pairs of dynamic shape dimensions that are
constrained to be equal.
"""
# TODO(avik): refactor Dynamo to avoid duplication of the following code
# between non-strict and strict.
# Specifically, here (non-strict) we do the following pre-tracing steps:
# - Fakify inputs.
# - Process input shape equalities.
# In strict, these steps are spread across multiple files:
# - output_graph.py fakifies inputs.
# - [post-tracing] guards.py processes input shape equalities.
combined_args = _combine_args(nn_module, args, kwargs)
_check_dynamic_shapes(combined_args, dynamic_shapes)
transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
combined_args, dynamic_shapes
)
constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes)
t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
for constraint in constraints:
t_constraints[constraint.t_id][constraint.dim] = constraint
context = torch._guards.TracingContext.try_get()
if context is not None:
# This occurs when we are exporting within dynamo. There already exists
# a toplevel TracingContext with a fake mode, so we do not want to
# create another fake mode.
fake_mode = context.fake_mode
elif not _is_torch_jit_trace:
code = nn_module.forward.__code__
co_fields = {
"co_name": code.co_name,
"co_filename": code.co_filename,
"co_firstlineno": code.co_firstlineno,
}
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(
tracked_fakes=[],
co_fields=co_fields,
prefer_deferred_runtime_asserts_over_guards=True,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
),
allow_non_fake_inputs=True,
export=True,
)
else:
fake_mode = FakeTensorMode(
shape_env=ShapeEnv(
tracked_fakes=[],
prefer_deferred_runtime_asserts_over_guards=True,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
),
allow_non_fake_inputs=True,
)
if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
raise ValueError(
"Detected fake_mode does not have a shape_env with tracked fakes. "
"If you constructed the module under a FakeTensorMode, "
"please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
)
with fake_mode:
# FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock
if not _is_torch_jit_trace:
original_signature = inspect.signature(nn_module.forward)
else:
original_signature = None
sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
fake_args, fake_kwargs = tree_map_with_path(
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
(args, kwargs),
)
names: Dict[str, Tuple[int, int]] = {}
source_pairs: List[Tuple[Source, Source]] = []
derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
phantom_symbols: Dict[str, Symbol] = {}
for constraint in constraints:
torch.export.dynamic_shapes._process_equalities(
constraint,
lambda t_id, dim: sources[(t_id, dim)],
fake_mode.shape_env,
names,
source_pairs,
derived_equalities,
phantom_symbols,
)
equalities_inputs = EqualityConstraint(
source_pairs=source_pairs,
derived_equalities=derived_equalities,
phantom_symbols=list(phantom_symbols.values()),
warn_only=False,
)
return (
fake_mode,
fake_args,
fake_kwargs,
equalities_inputs,
original_signature,
transformed_dynamic_shapes,
)
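# A hedged sketch of calling make_fake_inputs() directly; the demo module and
# shapes are assumptions, and dynamic_shapes=None means everything is static.
def _demo_make_fake_inputs():
    class Demo(torch.nn.Module):
        def forward(self, x):
            return x + 1

    fake_mode, fake_args, _, equalities, signature, _ = make_fake_inputs(
        Demo(), (torch.randn(2, 3),), {}, None
    )
    return fake_mode, fake_args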
def _flatten_dynamic_shapes(
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> List[Any]:
flat_shapes = []
def _tree_map_helper(path, t, shape):
nonlocal flat_shapes
flat_shapes.append(shape)
_tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes)
return flat_shapes
def produce_guards_and_solve_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
equalities_inputs: EqualityConstraint,
original_signature: inspect.Signature,
_is_torch_jit_trace=False,
):
"""
Given a fake mode, source pairs corresponding to equal dynamic shape dimensions,
and a graph module, produce guards on the fake mode's shape env (raising
constraint violations, if any) and solve them (to suggest simplifications or
fixes). Dynamo already performs this, so this is for non-strict mode.
Additional inputs:
equalities_inputs: the equality constraints to use for guards
original_signature: the signature of the forward method
"""
shape_env = fake_mode.shape_env
assert shape_env is not None
assert shape_env.tracked_fakes is not None
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
sources = [tf.source for tf in shape_env.tracked_fakes]
input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
constraint_violation_error = None
try:
shape_env.produce_guards(
placeholders,
sources,
input_contexts=input_contexts,
equalities_inputs=equalities_inputs,
ignore_static=False,
)
except ConstraintViolationError as e:
constraint_violation_error = e
shape_env.frozen = True
dim_constraints = shape_env.dim_constraints
if dim_constraints is None:
# Expected when shape_env.produce_guards throws an early constraint violation error.
# There is nothing to solve for in this case.
# TODO(avik): Maybe record the constraint violation error instead and replay later?
assert constraint_violation_error
raise constraint_violation_error
dim_constraints.solve()
forced_specializations = dim_constraints.forced_specializations()
if not _is_torch_jit_trace:
msg = dim_constraints.prettify_results(
original_signature,
dynamic_shapes,
constraint_violation_error,
forced_specializations,
)
else:
# FIXME(ycao): This is a hack to get around missing signature from ScriptMethod
msg = "dummy constraint violation message"
if constraint_violation_error:
constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
elif forced_specializations:
constraint_violation_error = ConstraintViolationError(msg)
if constraint_violation_error:
raise constraint_violation_error
def make_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
num_lifted_inputs: int,
):
"""
Given a fake mode's shape env and user-specified dynamic shapes,
return the resulting range constraints and equality constraints.
Additional args:
num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
(used only to enumerate the user-input nodes)
"""
shape_env = fake_mode.shape_env
assert shape_env is not None
inline_constraints = gm.meta.get("inline_constraints", [])
range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints
}
if not dynamic_shapes:
return range_constraints
# get individual dynamic shapes spec for each input
if not isinstance(dynamic_shapes, dict):
assert isinstance(dynamic_shapes, (tuple, list))
combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)
# check number of shapes vs. number of inputs
num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs
input_dims = defaultdict(list)
free_symbols = set()
for input_index, node in enumerate(gm.graph.nodes):
if input_index < num_lifted_inputs or node.op != "placeholder":
continue
if _is_constant_argument(node.meta["val"]) or isinstance(
node.meta["val"], CustomObjArgument
):
continue
shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
for i, d in enumerate(node.meta["val"].shape):
if isinstance(d, torch.SymInt) and not d.node.expr.is_number:
# Look up the range constraint for the symbol corresponding to this shape dimension
# and store it indexed by the symbolic expression corresponding to it.
# NOTE(avik): Use node._expr instead of node.expr for the lookup here because
# we want the symbol, not its replacement, which could be an expression. Maybe
# there's a better way to do this, e.g., by (re)computing value ranges for expressions?
dim = shape_spec[i] if shape_spec else None
if dim is None or isinstance(dim, _DimHint):
range_constraints[d.node.expr] = shape_env.var_to_range[
d.node._expr
]
else:
range_constraints[d.node.expr] = ValueRanges(
lower=dim.min, upper=dim.max
)
input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
free_symbols.update(d.node.expr.free_symbols)
for symbol in free_symbols:
if symbol not in range_constraints:
# Placeholders can have symbolic shapes that are derived expressions.
# The above code will record direct range constraints for them
# so that we can do runtime assertions. In addition, for serde checks
# we want to record range constraints for their root symbols.
range_constraints[symbol] = shape_env.var_to_range[symbol]
return range_constraints
def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
"""Search the module hierarchy, gathering up all tensor and ScriptObject constants.
Returns a dictionary mapping hash(value) to the name of the constant. We
have to abuse `hash` here unfortunately, see: [ScriptObject hash].
"""
constants = ConstantAttrMap()
buffers_parameters = set(m.buffers())
buffers_parameters.update(m.parameters())
def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
for k, v in m.__dict__.items():
if isinstance(
v,
(
torch.Tensor,
torch.ScriptObject,
FakeScriptObject,
),
):
if v in buffers_parameters:
# filter out buffers and parameters, leaving only constants
continue
fqn = ".".join(prefix_atoms + [k])
constants.add(v, fqn)
for k, v in m.named_children():
inner(v, prefix_atoms + [k], constants)
inner(m, [], constants)
return constants
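# A hedged sketch of _gather_constant_attrs(); the demo module is an
# assumption. Plain tensor attributes are collected as constants, while
# registered buffers and parameters are filtered out.
def _demo_gather_constant_attrs():
    class Demo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.table = torch.arange(4)  # plain attr -> gathered as constant
            self.register_buffer("buf", torch.zeros(2))  # buffer -> skipped

        def forward(self, x):
            return x + self.table[0] + self.buf[0]

    return _gather_constant_attrs(Demo())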
@contextlib.contextmanager
def _fakify_script_objects(
mod: torch.nn.Module,
args: Tuple[Any],
kwargs: Dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
# This context manager is used to fakify script objects into FakeScriptObject.
# Inputs:
# mod: the module to be exported; its (and its recursive submodules') script object attrs haven't been fakified.
# args, kwargs: the args and kwargs inputs for mod; script object inputs haven't been fakified.
# fake_mode: the fake mode to be used for fakifying script objects. It's the same mode used to fakify input tensors.
#
# Returns:
# mod: the patched module; its (and its recursive submodules') script object attrs have been fakified.
# fake_args, fake_kwargs: the new fakified args and kwargs.
# Script object inputs have been fakified. Don't touch the tensors.
# fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object.
# fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching.
constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod)
assert not any(
isinstance(obj, FakeScriptObject) for obj in constant_attrs.values()
), "Mod shouldn't contain any FakeScriptObject."
assert not pytree.tree_any(
lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
), "args and kwargs shouldn't contain any FakeScriptObject."
patched_attr = {}
fake_constant_attrs = ConstantAttrMap()
fake_to_real = {}
def _maybe_fakify_obj(obj):
fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
fake_to_real[fake_obj] = obj
return fake_obj
def _leaf_mod_and_attr(
mod: torch.nn.Module, attr_fqn: str
) -> Tuple[torch.nn.Module, str]:
*prefix_attr, last_attr = attr_fqn.split(".")
cur_mod = mod
for attr in prefix_attr:
cur_mod = getattr(cur_mod, attr)
return cur_mod, last_attr
try:
for obj, fqns in constant_attrs.items():
if isinstance(obj, torch.ScriptObject):
fake_script_obj = _maybe_fakify_obj(obj)
for fqn in fqns:
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
assert obj is getattr(cur_mod, attr)
setattr(cur_mod, attr, fake_script_obj)
fake_constant_attrs.add(fake_script_obj, fqn)
patched_attr[fqn] = obj
else:
for fqn in fqns:
fake_constant_attrs.add(obj, fqn)
fake_args, fake_kwargs = pytree.tree_map_only(
torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
)
yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real)
finally:
for fqn, orig_obj in patched_attr.items():
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
setattr(cur_mod, attr, orig_obj)
class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
"""
1. Handles data-dependent errors raised by torch function calls in non-strict.
Any data-dependent error is due to some condition on unbacked symints
that cannot be resolved. A mechanical way of fixing the error is to use
a torch._check() call to assert either that condition or its negation.
The handler suggests these options as code and points to the location
of the torch function call that raised the error as part of the error
message shown to the user, who can then simply select and copy-paste
a suggested fix at that location.
NOTE: Not all data-dependent errors are raised by torch function calls.
In particular, conditions on unbacked symints can appear outside such
calls, and as such are not handled here.
2. Handles line-of-code logging for each torch function call in non-strict.
Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
frame = _find_user_code_frame()
if frame is not None:
log.debug(
"%s called at %s:%s in %s",
func.__qualname__,
frame.f_code.co_filename,
frame.f_lineno,
frame.f_code.co_name,
)
try:
return func(*args, **kwargs)
except GuardOnDataDependentSymNode as e:
_suggest_fixes_for_data_dependent_error_non_strict(e)
raise
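# A hedged usage sketch: run user code under the handler so data-dependent
# guard errors surface suggested torch._check() fixes; the wrapped call is
# an illustrative assumption.
def _demo_run_with_non_strict_handler(fn, *args, **kwargs):
    with _NonStrictTorchFunctionHandler():
        return fn(*args, **kwargs)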

View File

@ -0,0 +1,441 @@
# mypy: allow-untyped-defs
import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
Argument = Any
Value = Any
Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
_TORCH_SYM_OPS: Set[Callable] = {
torch.sym_int,
torch.sym_float,
torch.sym_ite,
torch.sym_max,
torch.sym_min,
torch.sym_not,
torch.sym_sqrt,
}
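
# Calls to the helpers above (and to functions from `_operator` and `math`)
# are routed to `call_sym` by `ExportInterpreter.call_function` below.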
class ExportPassBaseError(RuntimeError):
pass
class _ExportPassBaseDeprecatedDoNotUse(PassBase):
"""
Interpreter-based pass class to help users maintain the IR spec while writing
transformations.
"""
@staticmethod
def _create_dummy_node_metadata():
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
class ExportTracer(PythonKeyTracer):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
super().__init__()
self.callback = callback
self.root = torch.nn.Module()
self.graph = torch.fx.Graph()
self.graph.set_codegen(codegen)
self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment]
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.submodules: Dict[torch.nn.Module, str] = {}
def trace(self) -> None: # type: ignore[override]
raise ExportPassBaseError("ExportTracer doesn't support trace().")
def create_arg(self, a: Argument) -> torch.fx.Node:
if isinstance(a, torch.nn.Module):
if a not in self.submodules:
name_submodule = f"submodule_{len(self.submodules)}"
self.root.add_module(name_submodule, a)
self.submodules[a] = name_submodule
elif isinstance(a, FakeTensor):
if not hasattr(a, "constant") or a.constant is None:
raise ExportPassBaseError(f"Cannot add {a} to graph.")
a = a.constant
node = super().create_arg(a)
if (
isinstance(a, torch.Tensor)
and isinstance(node, torch.fx.Node)
and node.op == "get_attr"
):
self.set_metadata(node, a)
self.callback.on_attr(ProxyValue(a, node))
return node
def set_metadata(
self, node: torch.fx.Node, value: Argument,
) -> None:
# propagate the fake tensor or sym nodes
def make_val(
x: Argument,
) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
if isinstance(x, FakeTensor):
return x
elif isinstance(x, torch.Tensor):
if x.is_quantized:
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
x = torch.dequantize(x)
try:
assert self.fake_tensor_mode is not None
# TODO we should allocate static shapes
# for param/buffer values
if isinstance(x, torch.nn.Parameter):
fake_tensor = self.fake_tensor_mode.from_tensor(
x, static_shapes=True
)
else:
fake_tensor = self.fake_tensor_mode.from_tensor(x)
except UnsupportedFakeTensorException:
# TODO: This is just a workaround to get over the
# x.as_subclass error
                        print(
                            "Fakifying a Tensor subclass is not supported "
                            "right now. Instead a TensorMetadata is used."
                        )
fake_tensor = None
return fake_tensor
elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
return x
else:
return None
node.meta["val"] = pytree.tree_map(make_val, value)
# Set the tensor_metadata for values that do not have a corresponding FakeTensor
def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
if x.is_quantized:
# TODO (tmanlaibaatar) properly support Quantized FakeTensor
x = torch.dequantize(x)
try:
assert self.fake_tensor_mode is not None
_ = self.fake_tensor_mode.from_tensor(x)
tensor_meta = None
except UnsupportedFakeTensorException:
# TODO: This is just a workaround to get over the
# x.as_subclass error
tensor_meta = _extract_tensor_metadata(x)
return tensor_meta
else:
return None
node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
class ExportInterpreter(fx.Interpreter):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
super().__init__(gm)
self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
def placeholder(
self,
target: str, # type: ignore[override]
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> ProxyValue:
arg = super().placeholder(target, args, kwargs)
return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
def output(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> ProxyValue:
return self.callback.output(args[0], NodeMetadata(self.node.meta)).data
def call_function(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> ProxyValue:
meta = NodeMetadata(self.node.meta)
if target == operator.getitem:
value, key = args
return self.callback.call_getitem(value, key, meta)
elif getattr(target, "__module__", None) in {"_operator", "math"}:
assert callable(target)
return self.callback.call_sym(target, args, meta)
elif target in _TORCH_SYM_OPS:
assert callable(target)
return self.callback.call_sym(target, args, meta)
elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
return self.callback.call_operator(
target,
args,
kwargs,
meta,
)
elif target == torch.ops.higher_order.cond:
pred, true_fn, false_fn, inputs = args
return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
elif target == torch.ops.higher_order.map_impl:
f, mapped_args, operands = args # type: ignore[assignment]
return self.callback.call_map(f, mapped_args, operands, meta)
# For other unregistered HigherOrderOps, just interpret them blindly
elif isinstance(target, torch._ops.HigherOrderOperator):
return self.callback._fx(
"call_function",
target,
args,
kwargs,
meta,
)
else:
raise ExportPassBaseError(f"Unsupported target type: {target}")
def get_attr(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> Argument:
return super().get_attr(target, args, kwargs)
def call_module(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> None:
raise ExportPassBaseError("call_module is not supported.")
def call_method(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> None:
raise ExportPassBaseError("call_method is not supported.")
def run_node(self, n: torch.fx.Node) -> Argument:
self.node = n
self.callback.node_debug_str = n.format_node()
return super().run_node(n)
def __init__(self) -> None:
self.interpreter = PropagateUnbackedSymInts(
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
)
self.tracer = self.ExportTracer(self, CodeGen())
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self._initialized = True
self.node_debug_str: typing.Optional[str] = None
def _fx(
self,
kind: str,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
        # Unwrap ProxyValues to their underlying data and run the real
        # computation through the interpreter (propagating unbacked symints).
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
        # Re-trace the call with the proxies so the new graph records it.
        args_proxy, kwargs_proxy = pytree.tree_map_only(
            ProxyValue, lambda x: x.proxy, (args, kwargs)
        )

        name = None
        if isinstance(target, torch._ops.OpOverload):
            # Name the node after the op's overload packet for readability.
            name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)

        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
        res_proxy.node.meta.update(meta.data)
        # Record any unbacked symbols produced by this call so later passes
        # can rebind them.
        if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
            if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
                res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
        self.tracer.set_metadata(res_proxy.node, res_data)
        return ProxyValue(res_data, res_proxy)
def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
# TODO(angelayi): Update this with what we decide to do for metadata in
# the exported graph module
if (args := graph_module.meta.get("args", None)) is not None:
return list(args)
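
        # Otherwise, reconstruct example inputs from the metadata stored on
        # each placeholder node.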
def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
if "val" in node.meta:
fake = node.meta["val"]
if hasattr(fake, "constant") and fake.constant is not None:
return fake.constant
return fake
elif tensor_meta := node.meta.get("tensor_meta"):
assert self.fake_tensor_mode is not None
return FakeTensor(
self.fake_tensor_mode,
torch.empty(
tensor_meta.shape,
dtype=tensor_meta.dtype,
device="meta",
requires_grad=tensor_meta.requires_grad,
memory_format=tensor_meta.memory_format,
),
torch.device("cpu"),
)
elif len(node.users) == 0:
return None
raise ExportPassBaseError(
f"Cannot construct an input for graph module: {graph_module}.",
)
return [
extract_input(node)
for node in graph_module.graph.nodes
if node.op == "placeholder"
]
def on_attr(self, attr: ProxyValue) -> None:
pass
def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
arg_proxy.node.meta = meta.data
self.tracer.set_metadata(arg_proxy.node, arg)
return ProxyValue(arg, arg_proxy)
def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
return self._fx("call_function", op, args, kwargs, meta)
def call_sym(
self,
target: Fn,
args: Tuple[Argument, ...],
meta: NodeMetadata,
) -> ProxyValue:
return self._fx("call_function", target, args, {}, meta)
def call_cond(
self,
pred: ProxyValue,
true_fn: torch.fx.GraphModule,
false_fn: torch.fx.GraphModule,
inputs: List[Argument],
meta: NodeMetadata,
) -> ProxyValue:
true_branch = self.call_submodule(true_fn, tuple(inputs))
false_branch = self.call_submodule(false_fn, tuple(inputs))
assert true_branch is not None
assert false_branch is not None
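        # Emit a single `cond` higher-order-op node whose branches are the
        # freshly traced submodule graphs.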
return self._fx(
"call_function",
torch.ops.higher_order.cond,
(pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
{},
meta,
)
def call_map(
self,
f: torch.fx.GraphModule,
mapped_args: List[ProxyValue],
operands: List[ProxyValue],
meta: NodeMetadata,
) -> ProxyValue:
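        # Trace `f` on a single example slice: `_unstack_pytree` splits the
        # batched (mapped) args, and the first slice stands in for the body input.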
xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
assert f_branch is not None
return self._fx(
"call_function",
torch.ops.higher_order.map_impl,
(f_branch.graph_module, mapped_args, operands),
{},
meta,
)
def call_getitem(
self, value: ProxyValue, key: int, meta: NodeMetadata
) -> ProxyValue:
return self._fx("call_function", operator.getitem, (value, key), {}, meta)
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
return self._fx("output", "output", (results,), {}, meta)
def call_submodule(
self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
) -> PassResult:
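        # Swap in a fresh tracer (and a plain fx.Interpreter) for the
        # submodule, reusing the current fake tensor mode; the previous
        # tracer/interpreter are restored after tracing.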
prev_tracer, self.tracer = self.tracer, self.ExportTracer(
self, graph_module.graph._codegen
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
)
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
with fx_traceback.preserve_node_meta():
interpreter.run(*inputs_data)
new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
self.tracer = prev_tracer
self.interpreter = prev_interpreter
return PassResult(
new_graph_module,
True,
)
def call(self, graph_module: fx.GraphModule) -> PassResult:
if not getattr(self, "_initialized", False):
raise ExportPassBaseError(
"ExportPass is not initialized with __init__().",
)
inputs = self.inputs(graph_module)
fake_tensor_mode = None
for i in inputs:
if isinstance(i, FakeTensor):
assert (
fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
), "Multiple fake tensor mode detected."
fake_tensor_mode = i.fake_mode
if fake_tensor_mode is None:
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
fake_tensor_mode = nullcontext() # type: ignore[assignment]
dispatcher_mode = nullcontext() # type: ignore[assignment]
else:
fake_tensor_mode.allow_non_fake_inputs = True
self.tracer.fake_tensor_mode = fake_tensor_mode
dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
self.fake_tensor_mode = self.tracer.fake_tensor_mode
with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
result = self.call_submodule(graph_module, tuple(inputs))
return result
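
# Driver sketch (assumed usage): `PassBase.__call__` routes to `call()`, so a
# concrete subclass can be applied directly to a graph module.
#
#     my_pass = MyExportPass()            # subclass of _ExportPassBaseDeprecatedDoNotUse
#     pass_result = my_pass(exported_gm)  # a torch.fx.passes.infra.pass_base.PassResult
#     new_gm = pass_result.graph_module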

Some files were not shown because too many files have changed in this diff.