I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,89 @@
r'''
FX is a toolkit for developers to use to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer,**
an **intermediate representation**, and **Python code generation**. A
demonstration of these components in action:
::
    import torch
    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    module = MyModule()

    from torch.fx import symbolic_trace
    # Symbolic tracing frontend - captures the semantics of the module
    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

    # High-level intermediate representation (IR) - Graph representation
    print(symbolic_traced.graph)
    """
    graph():
        %x : [num_users=1] = placeholder[target=x]
        %param : [num_users=1] = get_attr[target=param]
        %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        return clamp
    """

    # Code generation - valid Python code
    print(symbolic_traced.code)
    """
    def forward(self, x):
        param = self.param
        add = x + param; x = param = None
        linear = self.linear(add); add = None
        clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
        return clamp
    """
The **symbolic tracer** performs "symbolic execution" of the Python
code. It feeds fake values, called Proxies, through the code. Operations
on these Proxies are recorded. More information about symbolic tracing
can be found in the :func:`symbolic_trace` and :class:`Tracer`
documentation.

The **intermediate representation** is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or :class:`torch.nn.Module` instances), and return values. More information
about the IR can be found in the documentation for :class:`Graph`. The
IR is the format on which transformations are applied.

**Python code generation** is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph's semantics. This
functionality is wrapped up in :class:`GraphModule`, which is a
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
``forward`` method generated from the Graph.

Taken together, this pipeline of components (symbolic tracing ->
intermediate representation -> transforms -> Python code generation)
constitutes the Python-to-Python transformation pipeline of FX. In
addition, these components can be used separately. For example,
symbolic tracing can be used in isolation to capture a form of
the code for analysis (and not transformation) purposes. Code
generation can be used for programmatically generating models, for
example from a config file. There are many uses for FX!

Several example transformations can be found at the
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
repository.
'''
from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg, has_side_effect
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern
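
As a quick illustration of the pipeline described in the module docstring above, here is a minimal sketch of a Graph-level transformation (the module and the relu-to-sigmoid rewrite are made up for illustration): trace a module, mutate its Graph, and regenerate the Python code::

    import torch
    from torch.fx import GraphModule, symbolic_trace

    class AddRelu(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x + 1.0)

    traced: GraphModule = symbolic_trace(AddRelu())

    # Rewrite every call to torch.relu into torch.sigmoid directly on the Graph IR,
    # then regenerate the module's Python code from the modified Graph.
    for node in traced.graph.nodes:
        if node.op == "call_function" and node.target == torch.relu:
            node.target = torch.sigmoid
    traced.recompile()

    print(traced.code)  # the generated forward now calls torch.sigmoid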


@ -0,0 +1,15 @@
from torch.fx._symbolic_trace import (
symbolic_trace as symbolic_trace,
Tracer as Tracer,
wrap as wrap,
)
from torch.fx.graph import Graph as Graph
from torch.fx.graph_module import GraphModule as GraphModule
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
from torch.fx.node import (
has_side_effect as has_side_effect,
map_arg as map_arg,
Node as Node,
)
from torch.fx.proxy import Proxy as Proxy
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern


@ -0,0 +1,36 @@
from typing import Any, Dict, Callable, TypeVar
import textwrap
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
_T = TypeVar("_T")
def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
if is_backward_compatible:
def mark_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring += """
.. note::
Backwards-compatibility for this API is guaranteed.
"""
fn.__doc__ = docstring
_BACK_COMPAT_OBJECTS.setdefault(fn)
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
return fn
return mark_back_compat
else:
def mark_not_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring += """
.. warning::
This API is experimental and is *NOT* backward-compatible.
"""
fn.__doc__ = docstring
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
return fn
return mark_not_back_compat
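
For context, a minimal sketch of how the decorator above is meant to be applied (the decorated functions are hypothetical, and the import path torch.fx._compatibility is assumed from the relative import `from ._compatibility import compatibility` used later in this commit)::

    from torch.fx._compatibility import compatibility  # assumed path

    @compatibility(is_backward_compatible=True)
    def stable_helper(x):
        """Hypothetical helper; the decorator appends a backward-compatibility note."""
        return x

    @compatibility(is_backward_compatible=False)
    def experimental_helper(x):
        """Hypothetical helper; the decorator appends an experimental-API warning."""
        return x

    print(stable_helper.__doc__)        # ends with the ".. note::" block added above
    print(experimental_helper.__doc__)  # ends with the ".. warning::" block added above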


@ -0,0 +1,185 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from torch.fx import GraphModule
from torch.fx.graph_module import (
_format_import_block,
reduce_graph_module,
reduce_package_graph_module,
)
from torch.package import PackageExporter, sys_importer
from ._compatibility import compatibility
_use_lazy_graph_module_flag = False
_force_skip_lazy_graph_module_flag = False
@compatibility(is_backward_compatible=False)
@contextmanager
def _force_skip_lazy_graph_module():
"""
Skip using the lazy graph module regardless of the setting of _use_lazy_graph_module.
Used to skip _LazyGraphModule when testing the inductor torchscript-related backend.
Calling torch.jit.script on a _LazyGraphModule results in the following error:
https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
"""
try:
global _force_skip_lazy_graph_module_flag
prior = _force_skip_lazy_graph_module_flag
_force_skip_lazy_graph_module_flag = True
yield
finally:
_force_skip_lazy_graph_module_flag = prior
@compatibility(is_backward_compatible=False)
@contextmanager
def _use_lazy_graph_module(should_use: bool):
try:
global _use_lazy_graph_module_flag
prior = _use_lazy_graph_module_flag
_use_lazy_graph_module_flag = (
should_use and not _force_skip_lazy_graph_module_flag
)
yield
finally:
_use_lazy_graph_module_flag = prior
@compatibility(is_backward_compatible=False)
def _get_graph_module_cls():
return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule
def _make_graph_module(*args, graph_module_cls=None, **kwargs):
if graph_module_cls is None:
graph_module_cls = _get_graph_module_cls()
return graph_module_cls(*args, **kwargs)
@compatibility(is_backward_compatible=False)
class _LazyGraphModule(GraphModule):
"""
The main difference between _LazyGraphModule and GraphModule is how recompilation happens.
GraphModule will do a 'recompile' call to generate Python code and the forward method when it's
constructed. Later on, if the graph gets updated, the recompile method can be called again to refresh
the saved Python code and forward method.

However, in some cases, especially in inductor, the recompilation can be a waste since we never
check the Python code for the graph module or call its forward method. A few concrete
examples regarding pattern-matching fx passes in inductor:
1. Some passes will update the graph to be compiled and then call recompile on the GraphModule.
2. Some passes will trace small pattern functions to search for them in the graph being compiled and
   replace each match with the traced graph of a replacement function. The pattern graphs and
   replacement graphs are quite small, but there are a large number of them. Doing GraphModule.recompile
   for them in GraphModule.__init__ is also a waste of time.

However, simply skipping the GraphModule.recompile call in these scenarios is also dangerous:
people may want to check the Python code or call the GraphModule's forward method for debugging purposes.

The way _LazyGraphModule solves this is to override the recompile method to just mark the
need for recompilation without doing the actual recompilation. Later on, if someone actually
accesses the compiled Python code or calls the GraphModule's forward method, we do the real
recompilation.
"""
@classmethod
def from_graphmodule(cls, gm: GraphModule):
if isinstance(gm, _LazyGraphModule):
return gm
else:
return _LazyGraphModule(gm, gm.graph)
@staticmethod
def force_recompile(gm):
"""
Sometimes we need to force a recompile as a workaround:
- we want to do the real recompilation before symbolic_trace to avoid the error:
  https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
"""
if isinstance(gm, _LazyGraphModule):
gm.real_recompile()
def real_recompile(self):
if self._needs_recompile():
self._real_recompile()
@classmethod
def _needs_recompile(cls):
return cls.forward is cls._lazy_forward
def _lazy_forward(self, *args, **kwargs):
# Call self.real_recompile() rather than self._real_recompile() here.
# The _lazy_forward method may be saved and called repeatedly.
# Calling self.real_recompile makes sure we skip recompilation if
# we have already done it.
self.real_recompile()
assert not self._needs_recompile()
# call `__call__` rather than 'forward' since recompilation may
# install a wrapper for `__call__` to provide a customized error
# message.
return self(*args, **kwargs)
forward = _lazy_forward
# TODO: we should handle __reduce_deploy__ the same way as __reduce_package__,
# or __reduce__ by calling _real_recompile. But I haven't found a good way
# to test __reduce_deploy__. Also it's very unlikely that _LazyGraphModule
# will be used in torch::deploy, so it's skipped for now.
def __reduce_package__(self, exporter: PackageExporter):
"""
Follow GraphModule.__reduce_package__ but call 'self._real_recompile' rather
than 'self.recompile' since for a _LazyGraphModule, self.recompile just
marks the need for recompilation and does not return the PythonCode object.
"""
python_code = self._real_recompile()
dict_without_graph = self.__dict__.copy()
dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
del dict_without_graph["_graph"]
generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
import_block = _format_import_block(python_code.globals, exporter.importer)
module_code = import_block + self.code
exporter.save_source_string(generated_module_name, module_code)
return (
reduce_package_graph_module,
(dict_without_graph, generated_module_name),
)
def __reduce__(self):
"""
Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
than 'self.recompile' since for a _LazyGraphModule, self.recompile just
marks the need for recompilation and does not return the PythonCode object.
"""
python_code = self._real_recompile()
dict_without_graph = self.__dict__.copy()
import_block = _format_import_block(python_code.globals, sys_importer)
del dict_without_graph["_graph"]
return (reduce_graph_module, (dict_without_graph, import_block))
def _real_recompile(self):
return super().recompile()
@classmethod
def recompile(cls):
cls.forward = cls._lazy_forward
@property
def code(self) -> str:
self.real_recompile()
return super().code
def __str__(self) -> str:
"""
str(GraphModule) will access the _code attribute. Make sure recompile
happens so the _code attribute is available.
"""
self.real_recompile()
return super().__str__()
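
A rough sketch of the deferred recompilation described in the class docstring above (this is a private API, and the import path torch.fx._lazy_graph_module is assumed)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx._lazy_graph_module import _LazyGraphModule  # assumed path

    gm = symbolic_trace(torch.nn.Linear(4, 4))
    lazy_gm = _LazyGraphModule.from_graphmodule(gm)

    # recompile() on a _LazyGraphModule only marks the module as needing recompilation...
    lazy_gm.recompile()
    # ...the real recompilation happens lazily, when the code or forward is actually used.
    print(lazy_gm.code)               # triggers the real recompile
    out = lazy_gm(torch.randn(2, 4))  # would also trigger it if it had not happened yet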


@ -0,0 +1,103 @@
# mypy: allow-untyped-defs
from collections import namedtuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
import torch.return_types
from torch.utils._pytree import PyTree, TreeSpec
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
def register_pytree_flatten_spec(
cls: Type[Any],
flatten_fn_spec: FlattenFuncSpec,
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
) -> None:
SUPPORTED_NODES[cls] = flatten_fn_spec
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
def tree_flatten_spec(
pytree: PyTree,
spec: TreeSpec,
exact_structural_match=False,
) -> List[Any]:
if spec.is_leaf():
return [pytree]
if spec.type not in SUPPORTED_NODES:
raise RuntimeError(
f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
"torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
"sure that any custom pytrees have been registered before loading it.",
)
flatten_fn_spec = SUPPORTED_NODES[spec.type]
child_pytrees = flatten_fn_spec(pytree, spec)
if exact_structural_match:
flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
pytree,
spec,
):
raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
result = []
for child, child_spec in zip(child_pytrees, spec.children_specs):
flat = tree_flatten_spec(child, child_spec, exact_structural_match)
result += flat
return result
def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
return [d[k] for k in spec.context]
def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
return [d[i] for i in range(spec.num_children)]
def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
return [d[i] for i in range(spec.num_children)]
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
return [d[i] for i in range(spec.num_children)]
def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
return len(d) == spec.num_children
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
register_pytree_flatten_spec(
tuple,
_tuple_flatten_spec,
_tuple_flatten_spec_exact_match,
)
for return_type in torch.return_types.all_return_types:
register_pytree_flatten_spec(
return_type,
_tuple_flatten_spec,
_tuple_flatten_spec_exact_match,
)
register_pytree_flatten_spec(
namedtuple, # type: ignore[arg-type]
_namedtuple_flatten_spec,
_namedtuple_flatten_spec_exact_match,
)
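
A small sketch of what tree_flatten_spec above is for: flattening a runtime value according to a previously recorded TreeSpec, so that leaves come back in the spec's order (the import path torch.fx._pytree is assumed)::

    import torch
    from torch.utils._pytree import tree_flatten
    from torch.fx._pytree import tree_flatten_spec  # assumed path

    # Record a TreeSpec from one dict, then flatten another dict according to it:
    # the values come back in the spec's key order, not the runtime dict's order.
    _, spec = tree_flatten({"a": torch.tensor(1), "b": torch.tensor(2)})
    flat = tree_flatten_spec({"b": torch.tensor(20), "a": torch.tensor(10)}, spec)
    print(flat)  # [tensor(10), tensor(20)]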

File diff suppressed because it is too large


@ -0,0 +1,63 @@
# mypy: allow-untyped-defs
import sys
from typing import Dict, Optional
import torch
from torch._logging import LazyString
def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
"""
Returns a LazyString that formats the graph code.
"""
def format_name():
if maybe_id is not None:
return f"{name} {maybe_id}"
else:
return name
if "print_output" not in kwargs:
kwargs["print_output"] = False
if "colored" in kwargs and not sys.stdout.isatty():
kwargs["colored"] = False
return LazyString(
lambda: _format_graph_code(
f"===== {format_name()} =====\n",
gm.forward.__code__.co_filename,
gm.print_readable(**kwargs),
)
)
def _format_graph_code(name, filename, graph_str):
"""
Returns a string that formats the graph code.
"""
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
"""
Returns the nn_module_stack of the first call_function node.
"""
for node in graph.nodes:
if node.op == "call_function" and "nn_module_stack" in node.meta:
return node.meta["nn_module_stack"]
return None
def get_node_context(node, num_nodes=2) -> str:
"""
Returns a string of the last num_nodes nodes in the graph.
"""
node_contexts = []
cur = node
for i in range(num_nodes):
node_contexts.append(cur.format_node())
if cur.op == "root":
break
cur = cur.prev
return "\n".join(node_contexts[::-1])


@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
from torch.fx.proxy import Proxy
from ._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def annotate(val, type):
"""
Annotates a Proxy object with a given type.
This function annotates val with the given type if val is a torch.fx.Proxy object.
Args:
    val (object): An object to be annotated if its type is torch.fx.Proxy.
    type (object): The type to assign to val's underlying node.
Returns:
    The given val.
Raises:
    RuntimeError: If val is a Proxy whose node already has a type.
"""
if isinstance(val, Proxy):
if val.node.type:
raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
f" Existing type is {val.node.type} "
f"and new type is {type}. "
f"This could happen if you tried to annotate a function parameter "
f"value (in which case you should use the type slot "
f"on the function signature) or you called "
f"annotate on the same value twice")
else:
val.node.type = type
return val
else:
return val
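
A minimal sketch of using annotate during symbolic tracing, where the value is a Proxy and the type is recorded on its node (TensorType and Dyn come from torch.fx.tensor_type; the import path torch.fx.annotate is assumed)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.tensor_type import Dyn, TensorType
    from torch.fx.annotate import annotate  # assumed path

    class M(torch.nn.Module):
        def forward(self, x):
            # During tracing x is a Proxy, so this records the type on its node;
            # with a real tensor input, annotate just returns x unchanged.
            x = annotate(x, TensorType((Dyn, 4)))
            return x + 1

    gm = symbolic_trace(M())
    placeholder = next(iter(gm.graph.nodes))
    print(placeholder.type)  # TensorType[(Dyn, 4)]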


@ -0,0 +1,6 @@
# Whether to disable showing progress on compilation passes
# We need a separate config here; importing the dynamo config from this module would cause a circular import
disable_progress = True
# If True, this also shows the node names in each pass; for small models this is great, but for larger models it's quite noisy
verbose_progress = False


@ -0,0 +1,27 @@
import torch.fx
class BackwardState:
"""
BackwardState is used to pass Python hooks from the forwards pass
into the backwards pass in Dynamo+Compiled Autograd.
It is created by TorchDynamo and has special handling there.
Dynamo will pass an empty BackwardState to the forwards, then populate
members on it (via setattr) only after the forwards graph is finished.
Later on, in CompiledAutograd we will inline and add the needed guards
on the BackwardState.
BackwardState is identified and has special handling in AOTAutograd.
During AOTAutograd:
1) BackwardState is an input to the forwards graph
2) It must only be used in the backwards
3) It will be empty in the forwards
4) In the forwards we add a wrapper to save it
5) In the backwards it becomes an input
6) There can only be one per graph
BackwardState requires CompiledAutograd.
"""
proxy: torch.fx.Proxy


@ -0,0 +1,88 @@
import os
import sys
from typing import Optional
# [@compile_ignored: debug] Uses Z3 for validating the guard optimization transformations.
translation_validation = (
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
)
# Timeout (in milliseconds) for z3 finding a solution.
# [@compile_ignored: debug]
translation_validation_timeout = int(
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
)
# Disables bisection for translation validation.
#
# Translation validation bisection is enabled by default, if translation validation
# is also enabled. This should help find guard simplification issues. However,
# since validation uses Z3 for bisecting, it might take a lot of time.
#
# Set this configuration option to avoid bisecting.
# [@compile_ignored: debug]
translation_validation_no_bisect = (
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
)
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
# a ShapeEnv with the same state. This should be used only in testing.
check_shape_env_recorded_events = False
# TODO: Perhaps consider allowing unions for the configs below (so you can hit
# multiple reps at the same time)
# Give extended debug information if the string representation of a guard
# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue
# this guard, we will generate a full Python and C++ backtrace
# [@compile_ignored: debug]
extended_debug_guard_added = os.environ.get(
"TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
)
# Give extended debug information when a particular symbol is allocated. For
# example, set this to "u2" and whenever we create this symbol, we will
# generate a full Python and C++ backtrace
# [@compile_ignored: debug]
extended_debug_create_symbol = os.environ.get(
"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
)
# Give extended debug information (C++ backtrace) for all extended debug
# settings as well as errors. The C++ backtrace is slow and very spammy so we
# don't include it by default even when you're requesting extended debug.
# [@compile_ignored: debug]
extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
# Give extended debug information (line of code) when a torch function
# is called during export. This is useful for showing progress and detecting
# where export might be stuck. Currently only works for strict=False.
# [@compile_ignored: debug]
extended_debug_current_loc = (
os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
)
# [@compile_ignored: debug] Show a warning for every specialization
print_specializations = False
# Wraps (un)equalities with the 'Not' class after recording the correct expression
# in the FX graph. This deliberately constructs the divisible and replacement
# lists incorrectly and issues incorrect guards (testing only).
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
validate_shape_env_version_key = False
# If we produce more than this many guards on a symbol, force the symbol to
# get specialized and bail out if this many guards mention this particular
# symbol. This may be slightly more aggressive than the true number of guards
# issued (as we test if we've hit the limit on-the-fly, whereas we may
# do further simplifications at final guard issuance time that make guards
# irrelevant.)
symbol_guard_limit_before_specialize: Optional[int] = None
# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
use_duck_shape = True
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
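
These flags are read from the environment variables named above when this config module is first imported. A small sketch of enabling translation validation for a compiled function (requires the z3 solver to be installed; the function and values are made up for illustration)::

    import os

    # Set before torch is imported/used so the config module above sees the values.
    os.environ["TORCHDYNAMO_TRANSLATION_VALIDATION"] = "1"
    os.environ["TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT"] = "60000"

    import torch

    @torch.compile(dynamic=True)
    def f(x):
        return x * 2

    f(torch.randn(8))  # guard simplifications are now validated with Z3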

File diff suppressed because it is too large


@ -0,0 +1,292 @@
# mypy: allow-untyped-defs
import re
from typing import Callable, Dict, Optional, Set, Union
import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module
__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
class FoldedGraphModule(torch.fx.GraphModule):
"""
FoldedGraphModule is a GraphModule which also contains another
`const_subgraph_module` representing a subgraph which has all const attr
inputs and which can be run once before running the main standard
`graph`. The `const_output_names` are the ordered names of the attributes
onto which each respective output of the const_subgraph should be set.
"""
def __init__(
self,
root: torch.nn.Module,
graph: torch.fx.Graph,
const_subgraph: Optional[torch.fx.Graph] = None,
fx_const_folded_attrs_name: Optional[str] = None,
device_for_folded_attrs: str = "cuda",
):
super().__init__(root, graph)
self.const_subgraph_module = (
None
if const_subgraph is None
else torch.fx.GraphModule(root, const_subgraph)
)
self.has_folding_been_run = False
self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
self.device_for_folded_attrs = device_for_folded_attrs
def __call__(self, *args, **kwargs):
if not self.has_folding_been_run:
self.run_folding()
return super().__call__(*args)
def run_folding(self):
# If there's no const subgraph module or attr output names to use, return
# early as there is no const folding to perform.
if (
self.const_subgraph_module is None
or self.fx_const_folded_attrs_name is None
):
return
assert not self.has_folding_been_run
self.has_folding_been_run = True
# Actually run const folding subgraph. Note that single attr const fold
# subgraphs output a single Tensor while multiple outputs are returned as
# Tuple[Tensor,].
folded_attrs = self.const_subgraph_module()
def _create_param(i):
return torch.nn.Parameter(
i.detach().clone()
if not isinstance(i, int)
else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
)
params = (
torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
if isinstance(folded_attrs, tuple)
else _create_param(folded_attrs)
)
setattr(self, self.fx_const_folded_attrs_name, params)
def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
"""
Given `gm` and some graph module which is called with target name `inline_mod_name`,
this helper will inline all of the nodes from that called graph module into `gm`.
"""
# Fetch the inner graph module that we want to inline inside `gm`.
inline_mod = dict(gm.named_modules())[inline_mod_name]
assert isinstance(inline_mod, torch.fx.GraphModule)
call_mod_node_to_replace = None
for node in gm.graph.nodes:
if node.op == "call_module" and node.target == inline_mod_name:
call_mod_node_to_replace = node
break
assert call_mod_node_to_replace is not None
# Now actually do the swap. Note that we have to keep track of new nodes that are
# copied into `gm` -- we do this via replacement_mapping.
call_mod_args = call_mod_node_to_replace.args
replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
ph_count = 0
def replacement_fn(node):
new_node = replacement_mapping[node]
new_node.meta = node.meta.copy()
return new_node
for inline_node in inline_mod.graph.nodes:
if inline_node.op == "placeholder":
replacement_mapping[inline_node] = call_mod_args[ph_count]
ph_count += 1
continue
if inline_node.op == "output":
outputs = inline_node.args[0]
output_replacements = map_arg(outputs, replacement_fn)
call_mod_node_to_replace.replace_all_uses_with(output_replacements)
continue
with gm.graph.inserting_before(call_mod_node_to_replace):
new_node = gm.graph.node_copy(inline_node, replacement_fn)
replacement_mapping[inline_node] = new_node
gm.graph.eliminate_dead_code()
def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
"""
Make sure the name is unique (in a module) and can represent an attr.
"""
# Delete all characters that are illegal in a Python identifier.
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = f"_{name}"
# Now make sure it is in fact unique to the module by incrementing suffix value.
while hasattr(mod_traced, name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
else:
base, num = match.group(1, 2)
name = f"{base}_{int(num) + 1}"
return name
def split_const_subgraphs(
module: Union[torch.nn.Module, torch.fx.GraphModule],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
"""
Looks through `module` for any nodes that have all constant attribute inputs
and separates them out into their own constant subgraph, and returns a
FoldedGraphModule which runs that constant subgraph on the first run to set
attributes on the module prior to running the non-constant portion of the
graph.
"""
if not isinstance(module, torch.fx.GraphModule):
mod_traced = torch.fx.symbolic_trace(module)
else:
mod_traced = module
# Build up a list of const_nodes, defined as nodes that are themselves
# get_attrs, or have all get_attr or other constant node inputs.
const_nodes: Set[torch.fx.Node] = set()
found_const_folding = False
for node in mod_traced.graph.nodes:
# Skip over placeholders/outputs because they can't be const folded and
# we don't want to add tags to them.
if node.op in {"placeholder", "output"}:
continue
# If the node itself is constant, or all of its inputs are constant,
# then tag it as constant.
if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
const_nodes
):
continue
# If provided skip folding function says to skip, then skip.
if skip_folding_node_fn and skip_folding_node_fn(node):
continue
# Skip folding side-effectful functions
if node.is_impure():
continue
# Must be a constant foldable node at this point.
const_nodes.add(node)
if node.op != "get_attr":
found_const_folding = True
# If we did not find any const folding then return early without a const fold subgraph.
if not found_const_folding:
return FoldedGraphModule(mod_traced, mod_traced.graph)
# Partition the module into two: submod_0 for constant folding subgraph, and
# submod_1 for the rest.
def mod_partition(node: torch.fx.Node):
return 0 if node in const_nodes else 1
split = split_module(mod_traced, module, mod_partition)
const_mod_name, non_const_mod_name = "submod_0", "submod_1"
# Safely get submod_1 in case there are no non-const nodes
const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)
# The module that a call_module node refers to gets copied to submodules during split.
# The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
# attach inlined modules to `split` as it's the owning module now.
for node in non_const_gm.graph.nodes if non_const_gm else []:
if node.op == "call_module":
setattr(split, node.target, getattr(non_const_gm, node.target))
for node in const_gm.graph.nodes:
if node.op == "call_module":
setattr(split, node.target, getattr(const_gm, node.target))
# split_module currently does not use get_attrs for attrs. Instead it passes
# them in as args from the parent module, which used get_attrs. Here we set
# them as get_attrs inside const_gm, allowing for running folding without
# somehow a priori knowing the attrs that should be passed as args. We can
# unconditionally do this for all placeholders because we know all
# placeholders to const_gm must be constants accessible via get_attr.
call_const_gm_args = None
for node in split.graph.nodes:
if node.op == "call_module":
if node.target == const_mod_name:
call_const_gm_args = node.args
break
assert call_const_gm_args is not None
# Here we do the actual replacement of placeholders to get_attrs. Note that here we
# set the const_gm.graph into a new root_const_gm with split as the root module,
# because we are fetching attributes directly from the root module, instead of
# fetching them from const_gm. Example: The const_gm must have some format like:
# graph():
# %inp : [num_users=1] = placeholder[target=const_inp]
# %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
# return add
# We replace that with the following, which does not have any placeholders:
# graph():
# %inp_1 : [num_users=1] = get_attr[target=const_inp]
# %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
# return add
root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
for node in root_const_gm.graph.nodes:
if node.op == "output":
multiple_outputs = isinstance(node.args[0], tuple)
continue
if node.op != "placeholder":
continue
in_node = next(n for n in call_const_gm_args if n.name == node.target)
assert in_node.op == "get_attr"
with root_const_gm.graph.inserting_before(node):
new_node = root_const_gm.graph.get_attr(in_node.target)
new_node.meta = node.meta.copy()
node.replace_all_uses_with(new_node)
root_const_gm.graph.erase_node(node)
assert "multiple_outputs" in locals()
# Now find the call to const_gm inside split, and replace it with a getattr to the
# folded tensor(s) that result from constant folding. Note that we don't need to
# worry about whether this is one or more tensors because the original graph
# correctly uses getitem to extract individual tensors if there are multiple folded.
fx_const_folded_attrs_name = get_unique_attr_name_in_module(
mod_traced, "_FX_CONST_FOLDED_ATTRS"
)
setattr(
split,
fx_const_folded_attrs_name,
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined]
)
for node in split.graph.nodes:
if node.op == "call_module" and node.target == const_mod_name:
with node.graph.inserting_before(node):
folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
folded_attrs.meta = node.meta.copy()
node.replace_all_uses_with(folded_attrs)
break
# Finally, inline the non-constant submod (if it exists) into the split submod.
# This is so that the original caller who may have passed in a graph module will
# get back out a graph module whose graph is traced to the same granularity.
if hasattr(split, non_const_mod_name):
_inline_module(split, non_const_mod_name)
split.graph.eliminate_dead_code()
return FoldedGraphModule(
split,
split.graph,
root_const_gm.graph,
fx_const_folded_attrs_name,
device_for_folded_attrs,
)
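
A minimal sketch of running the constant-folding pass defined above on a small module (the module is made up for illustration; the import path torch.fx.experimental.const_fold is assumed from the upstream layout)::

    import torch
    from torch.fx.experimental.const_fold import split_const_subgraphs  # assumed path

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            # `self.w + 1` depends only on a constant attribute, so it can be folded.
            return x @ (self.w + 1)

    folded = split_const_subgraphs(M())
    folded.run_folding()              # runs the constant subgraph once
    out = folded(torch.randn(2, 4))   # uses the pre-computed folded attribute
    print(folded.graph)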


@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
import torch.fx as fx
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
"""
Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
`gm` gets run.
Args:
gm: graph module to insert the breakpoint into. It is then recompiled for the
breakpoint to take effect.
Returns:
the `gm` with breakpoint inserted.
"""
def insert_pdb(body):
return ["import pdb; pdb.set_trace()\n", *body]
with gm.graph.on_generate_code(
make_transformer=lambda cur_transform: (
# new code transformer to register
lambda body: (
insert_pdb(
cur_transform(body) if cur_transform
else body
)
)
)
):
gm.recompile()
return gm
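
A short sketch of using set_trace above (the import path torch.fx.experimental.debug is assumed, since the diff does not show the filename)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.experimental.debug import set_trace  # assumed path

    gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
    set_trace(gm)             # regenerates gm's code with "import pdb; pdb.set_trace()" prepended
    print(gm.code)            # the generated forward now starts with the pdb call
    # gm(torch.randn(2, 4))   # running the module would now drop into pdb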


@ -0,0 +1,916 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from functools import reduce
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
from torch.fx.experimental.refinement_types import Equality
import itertools
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
import sympy
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}
def expand_to_tensor_dim(t, n):
"""
Expand a type to the desired tensor dimension if possible
Raise an error otherwise.
- t is the given type
- n is a number of dimensions to expand to
"""
if t == Dyn:
dims = [Dyn] * n
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
return t
else:
raise TypeError(f'Cannot match the type {t}')
def broadcast_types(t1, t2):
"""
Applies broadcasting to both given types such that they
become consistent with each other and returns two new
resulting types
"""
# if either type is Dyn, do nothing since the types are already consistent
if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
return t1, t2
if isinstance(t1, TensorType) and isinstance(t2, TensorType):
s1 = len(t1.__args__)
s2 = len(t2.__args__)
new_t1 = list(t1.__args__)
new_t2 = list(t2.__args__)
# We make the types the same length which is the first requirement
# for consistency
if s1 > s2:
for i in range(s1 - s2):
new_t2.insert(0, 1)
elif s2 > s1:
for i in range(s2 - s1):
new_t1.insert(0, 1)
# we replace occurrences of "1" in each tensor with
# the corresponding type from the other tensor
for i, (x, y) in enumerate(zip(new_t1, new_t2)):
if x == 1:
new_t1[i] = y
elif y == 1:
new_t2[i] = x
# at this point our tensors should be consistent
# and we can apply the element-wise operation and find the right dimension
# for the output of the operation
(t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
return (t1, t2)
else:
raise TypeError(f'Cannot broadcast types {t1} and {t2}')
def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
raise RuntimeError(f'Inference rule already registered for {call_target}!')
_INFERENCE_RULES[call_target] = fn
return fn
return register
def register_refinement_rule(call_target):
def register(fn):
if call_target in _REFINEMENT_RULES:
raise RuntimeError(f'Refinement rule already registered for {call_target}!')
_REFINEMENT_RULES[call_target] = fn
return fn
return register
def register_algebraic_expressions_inference_rule(call_target):
def register(fn):
if call_target in _RULES:
raise RuntimeError(f'Rule already registered for {call_target}!')
_RULES[call_target] = fn
return fn
return register
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
"""
Apply the addition inference rule. This includes:
- scalar addition
- broadcasting semantics
Note that we always return the least precise type between
the operands (after applying broadcasting) to be the final type of the operation.
Note that we do not modify the operand types themselves after applying broadcasting
to them. We only use them to calculate the final type.
"""
assert isinstance(n.args[0], Node)
assert isinstance(n.args[1], Node)
t1 = n.args[0].type
t2 = n.args[1].type
# handle scalar addition
if t1 == int and isinstance(t2, TensorType):
n.type = t2
return n.type
# handle scalar addition
elif t2 == int and isinstance(t1, TensorType):
n.type = t1
return n.type
# we bring the new types to the point where
# we can check for consistency
# any inconsistency would not have been caused
# by broadcasting at this point
(new_t1, new_t2) = broadcast_types(t1, t2)
if new_t1 != t1 or new_t2 != t2:
n.meta['broadcast'] = True
n.meta[str(n.args[0])] = new_t1
n.meta[str(n.args[1])] = new_t2
else:
n.meta['broadcast'] = False
new_t1 = t1 if not n.meta['broadcast'] else new_t1
new_t2 = t2 if not n.meta['broadcast'] else new_t2
# we check for consistency between the new types
if is_consistent(new_t1, new_t2):
# we return the less precise type because
# broadcasting may have happened
# for operands with shape [1,2,Dyn] and [1,2,1]
# we have to assign the node [1,2,Dyn]
if is_more_precise(new_t1, new_t2):
n.type = new_t2
else:
n.type = new_t1
return n.type
else:
raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
f' Types should match ')
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):
"""
The current getattr rule only handles the shape attribute.
It can be extended to other attributes.
The most representative type we have is "Dyn", but the system
can be extended with more types, such as a type to represent shapes.
"""
attr_node = n.args[0]
attr_name = n.args[1]
if attr_name == "shape":
n.type = Dyn
else:
raise TypeError("Not yet implemented")
# TODO. We leave it like this till we add a type to represent tensor sizes
return n.type
@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
"""
We check that dimensions for the transpose operations
are within range of the tensor type of the node
"""
if n.target == torch.transpose:
assert isinstance(n.args[0], Node)
t = n.args[0].type
assert isinstance(n.args[1], int)
assert isinstance(n.args[2], int)
dim1, dim2 = n.args[1], n.args[2]
if t == Dyn:
n.type = Dyn
return n.type
elif isinstance(t, TensorType):
if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
new_type = list(t.__args__)
new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
final = TensorType(new_type)
n.type = get_greatest_upper_bound(n.type, final)
return n.type
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node):
"""
Without dynamism, the rule checks that the
product of the elements of the argument tensor
type is equal to the product of the elements
of the required shape. We gradualize this rule
by adding a case to handle fully dynamic input
as well as input where some of the tensor dimensions
are unknown. In this case we check for divisibility
"""
assert isinstance(n.args[0], Node)
t1 = n.args[0].type
assert isinstance(n.args[1], list)
t2 = n.args[1]
t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])
# if we do not know the original tensor dimension,
# we return the required dimension
if t1 == Dyn:
n.type = t2_type
return t2_type
# if any of the dimensions are unknown,
# we check for divisibility
elif isinstance(t1, TensorType):
assert isinstance(t1, TensorType)
a = [e if e != Dyn else 1 for e in t1.__args__]
p1 = reduce(operator.mul, a)
p2 = reduce(operator.mul, t2)
if p1 % p2 == 0 or p2 % p1 == 0:
n.type = t2_type
return t2_type
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
@register_inference_rule(BatchNorm2d)
def bn2d_inference_rule(n: Node, module_instance):
"""
Given a BatchNorm2D instance and a node check the following conditions:
- the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4)
- the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
- t is consistent with t'
- x_2 is consistent with the module's num_features
- x_2' is consistent with the module's num_features
output type: the more precise type of t and t'
"""
assert isinstance(n.args[0], Node)
n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
arg_type = n.args[0].type
n.type = expand_to_tensor_dim(n.type, 4)
# we check the conditions on the incoming argument
# and any existing annotation
# we also check for consistency between both annotations
if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
is_consistent(n.type.__args__[1], module_instance.num_features) and \
is_consistent(arg_type, n.type):
# we choose the more precise type
# to be the node type
# so if an incoming argument has more type information
# we set this node's type to be the argument type
n.type = get_greatest_upper_bound(arg_type, n.type)
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
def calculate_out_dimension(d_in, module_instance, index):
"""
For calculating h_out and w_out according to the Conv2d documentation
"""
padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
stride = (module_instance.stride, module_instance.stride) \
if isinstance(module_instance.stride, int) else module_instance.stride
dilation = (module_instance.dilation, module_instance.dilation) \
if isinstance(module_instance.dilation, int) else module_instance.dilation
DIMENSION_TYPES = (int, sympy.Symbol)
if d_in == Dyn:
return Dyn
elif isinstance(d_in, DIMENSION_TYPES):
n = d_in + 2 * padding[index] - \
dilation[index] * \
(kernel_size[index] - 1) - 1
return (n // stride[index]) + 1
else:
raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
def get_greatest_upper_bound(type1, type2):
"""
Get the most precise type that's consistent with the given types
"""
if type1 == Dyn:
return type2
elif type2 == Dyn:
return type1
elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
if not is_consistent(type1, type2):
raise TypeError(f'Inconsistent types {type1}, {type2}')
gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
return TensorType(tuple(gub))
@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance):
"""
Given a Conv2D instance and a node check the following conditions:
- the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W)
- the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
- x_2 is consistent with the module's in_channels
- let o = (x_1, out_channels, H_out, W_out)
then the output is the greatest upper bound of o and the existing node type t'.
"""
assert isinstance(n.args[0], Node)
n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
arg_type = n.args[0].type
curr_node_type = expand_to_tensor_dim(n.type, 4)
if is_consistent(arg_type.__args__[1], module_instance.in_channels):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance):
"""
Input and output shapes should be equal.
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
n.type = get_greatest_upper_bound(n.args[0].type, n.type)
return n.type
def maxpool2d_check(typ, module_instance):
"""
Applies the maxpool2d shape information to the input;
this affects the last two dimensions.
"""
new_type_list = list(typ.__args__)
if len(new_type_list) == 4 or len(new_type_list) == 3:
w_in = new_type_list[-1]
h_in = new_type_list[-2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type_list[-1] = w_out
new_type_list[-2] = h_out
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Wrong size {typ} for {module_instance}')
@register_inference_rule(torch.nn.MaxPool2d)
def maxpool2d_inference_rule(n: Node, module_instance):
"""
Given a MaxPool2D instance and a node check the following conditions:
- Input size matches size 3 or 4
- Current node type is consistent with the output type we will calculate
- Input size matches output size and the last two dimensions of the output
are w_out and h_out. The remaining dimensions are the same as the input
- Our final result is the greatest upper bound of the output we calculate
and the current node type.
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output = maxpool2d_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(output, n.type)
return n.type
def linear_check(tensor_type, module_instance):
"""
Checks that an input tensor type satisfies the conditions for linear operation
and returns the output type based on in and out features given by module_instance
"""
if len(tensor_type.__args__) >= 2:
if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
new_type_args = list(tensor_type.__args__)
new_type_args[-1] = module_instance.out_features
return TensorType(tuple(new_type_args))
else:
raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
else:
raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance):
"""
Applies the shape information to the input then gets the greatest upper bound
of the resulting type and the existing type
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = linear_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(output_type, n.type)
return n.type
def adaptiveavgpool2d_check(tensor_type, module_instance):
output_size = module_instance.output_size
if isinstance(output_size, int):
output_size = [output_size, output_size]
elif isinstance(output_size, tuple):
output_size = list(output_size)
if output_size[0] is None:
output_size[0] = output_size[1]
if output_size[1] is None:
output_size[1] = output_size[0]
new_type_list = list(tensor_type.__args__)
if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
new_type_list[-1] = output_size[1]
new_type_list[-2] = output_size[0]
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
"""
The input and output sizes should be the same except for the last
two dimensions taken from the input, which represent width and height
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(n.type, output_type)
return n.type
def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)
start_dim = l if start_dim == -1 else abs(start_dim)
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
my_args = list(tensor_type.__args__)
lhs = my_args[0:start_dim]
rhs = my_args[end_dim:]
mid = my_args[start_dim:end_dim]
if Dyn in mid:
mid = [Dyn]
else:
mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
"""
Applies the flatten shape information to the input then gets the
greatest upper bound of the resulting type and the existing type
"""
assert isinstance(n.args[0], Node)
# set the default start and end dims
start_dim = 1
end_dim = -1
if len(n.args) > 1:
assert isinstance(n.args[1], int)
start_dim = n.args[1]
if len(n.args) > 2:
assert isinstance(n.args[2], int)
end_dim = n.args[2]
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = flatten_check(n.args[0].type, start_dim, end_dim)
n.type = get_greatest_upper_bound(output_type , n.type)
return n.type
class GraphTypeChecker:
def __init__(self, env, traced):
self.env = env
self.traced = traced
def type_check(self):
"""
A gradual type checker for graphs.
Effect: every node's type field will be
populated with a type after type-checking is done.
"""
graph = self.traced.graph
# type check every node with gradual type rules
# if any node does not type check return false
for n in graph.nodes:
self.type_check_node(n)
return True
def type_check_node(self, n: Node):
"""
Type check a given fx node.
Current operations:
- Reshape
- Transpose
- Add
- Relu
- conv2d
- batchnorm2d
- flatten
- maxpool2d
- adaptiveavgpool2d
- linear
"""
if n.type is None:
n.type = Dyn
if n.op == 'placeholder':
return n.type
elif n.op == 'get_attr':
t = get_parameter(self.traced, n.target) # type: ignore[arg-type]
if isinstance(t.data, torch.Tensor):
n.type = TensorType(t.data.shape)
return n.type
elif n.op == 'call_function':
if n.target == getattr:
assert getattr in _INFERENCE_RULES
return _INFERENCE_RULES[n.target](n, self.traced)
elif n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
elif n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n, module_instance)
else:
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
elif n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
raise NotImplementedError(f"Method {n.op} not yet implemented")
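# Example (sketch): the type checker above is typically driven as follows. The
# module `M` below is made up for illustration; TensorType/Dyn come from
# torch.fx.tensor_type, and annotated placeholder types are picked up during
# symbolic tracing:
#
#     from torch.fx import symbolic_trace
#     from torch.fx.tensor_type import Dyn, TensorType
#
#     class M(torch.nn.Module):
#         def forward(self, x: TensorType((1, 2, Dyn)), y: TensorType((1, 2, 3))):
#             return x + y
#
#     traced = symbolic_trace(M())
#     GraphTypeChecker({}, traced).type_check()
#     for node in traced.graph.nodes:
#         print(node.name, node.type)   # the add node gets TensorType[(1, 2, Dyn)]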
@register_refinement_rule(Conv2d)
def conv_refinement_rule(n: Node):
"""
The equality constraints are between the first dimension of
the input and output
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
@register_refinement_rule(torch.nn.Linear)
def linear_refinement_rule(n: Node):
"""
The equality constraints are between the first dimension of
the input and output
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):
"""
For operations where the input shape is equal to the output shape
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
args1 = arg_type.__args__
args2 = n.type.__args__
res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
return res
@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
@register_refinement_rule(torch.nn.MaxPool2d)
def first_two_eq(n: Node):
"""
For operations where the first two dimensions of the input and output shape
are equal
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
args1 = arg_type.__args__
args2 = n.type.__args__
res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
return res
@register_refinement_rule(torch.add)
@register_refinement_rule(operator.add)
def element_wise_eq(n: Node):
"""
For element-wise operations; handles broadcasting.
Note that after applying broadcasting to the arguments
we are able to determine if certain dimensions have not been broadcast
if they are symbolically equal.
In this case, we can establish equality between those dimensions and the
corresponding output dimensions.
Note that it takes two iterations for this result. One iteration to establish
equality between certain dimensions of the operands (requiring the whole solver
including unification) and another iteration to establish equality between the operands
and the resulting type, requiring another round of constraint generation and unification.
"""
res = []
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
arg_type1 = n.args[0].type
arg_type2 = n.args[1].type
if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
args1, args2 = broadcast_types(arg_type1, arg_type2)
# by this point, we know that args1 and args2 are the same size.
a1 = args1.__args__
a2 = args2.__args__
a3 = n.type.__args__
# we would be here in the second iteration where we establish equality
# between operand type dimensions and the resulting type dimensions
r = []
for x, y, z in zip(a1, a2, a3):
if x == y:
r.append(Equality(x, z))
res = r
return res
@register_refinement_rule(torch.flatten)
def flatten_refinement_rule(n: Node):
"""
Generates equality constraints between the dimensions of the input and output
that will not be involved in the flatten operation
"""
assert isinstance(n.args[0], Node)
eq_const = []
start_dim = 1
end_dim = -1
if len(n.args) > 1:
assert isinstance(n.args[1], int)
start_dim = n.args[1]
if len(n.args) > 2:
assert isinstance(n.args[2], int)
end_dim = n.args[2]
if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
l = len(n.type.__args__)
arg_type = n.args[0].type
start_dim = l if start_dim == -1 else start_dim
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]):
eq_const.append(Equality(t1, t2))
for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]):
eq_const.append(Equality(t1, t2))
return eq_const
@register_algebraic_expressions_inference_rule(Conv2d)
def conv_rule(n: Node, module_instance):
"""
Represents the output in terms of an algebraic expression w.r.t.
the input when possible.
"""
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
n.type = new_type
return new_type
class Refine:
"""
Symbolic shape inference.
Generates constraints over type variables.
Currently all constraints are equality constraints.
"""
def __init__(self, traced):
self.constraints = []
self.traced = traced
self.symbol_iter = itertools.count(start=0, step=1)
def refine(self):
"""
Generates constraints for
every node in the graph based on
the operation.
"""
graph = self.traced.graph
for n in graph.nodes:
self.refine_node(n)
return True
def symbolic_relations(self):
"""
Infers algebraic relations
"""
graph = self.traced.graph
for n in graph.nodes:
self.infer_symbolic_relations(n)
return True
def replace_dyn_with_fresh_var(self, typ):
"""
Replace all unknown types with fresh type variables.
"""
if typ == Dyn:
new_symbol = Var(next(self.symbol_iter))
return new_symbol
elif isinstance(typ, TensorType):
new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
return TensorType(tuple(new_args))
elif isinstance(typ, list):
return [self.replace_dyn_with_fresh_var(t) for t in typ]
elif isinstance(typ, tuple):
return tuple(self.replace_dyn_with_fresh_var(t) for t in typ)
else:
return typ
def convert_to_sympy_symbols(self, typ):
"""
Replace all type variables with sympy symbols.
"""
if isinstance(typ, Var):
return sympy.symbols(str(typ))
elif isinstance(typ, TensorType):
new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
return TensorType(tuple(new_args))
elif isinstance(typ, list):
return [self.convert_to_sympy_symbols(t) for t in typ]
elif isinstance(typ, tuple):
return tuple(self.convert_to_sympy_symbols(t) for t in typ)
else:
return typ
def refine_node(self, n: Node):
"""
Returns a list of equality constraints for
call_module and call_function nodes.
Models the relation between input and output dimensions
using constraints in case they are both tensors.
All operations used in resnet50 are defined.
"""
if n.type is None:
n.type = Dyn
n.type = self.replace_dyn_with_fresh_var(n.type)
if n.op == 'call_function':
if n.target in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[n.target](n)
else:
pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
pass
def infer_symbolic_relations(self, n: Node):
n.type = self.convert_to_sympy_symbols(n.type)
if n.op == 'call_function':
if n.target in _RULES:
return _RULES[n.target](n)
else:
pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
pass
def get_parameter(traced, target: str):
"""
Returns the parameter given by ``target`` if it exists,
otherwise throws an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the Parameter
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.nn.Parameter: The Parameter referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Parameter``
"""
module_path, _, param_name = target.rpartition(".")
mod: torch.nn.Module = traced.get_submodule(module_path)
if not hasattr(mod, param_name):
raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
param: torch.nn.Parameter = getattr(mod, param_name)
return param
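if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: drive the Refine pass
    # over a small traced module. The toy model below is illustrative only; with
    # no type annotations most node types simply become fresh type variables, so
    # the constraint list may well be empty here.
    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(2)

        def forward(self, x):
            return torch.flatten(self.pool(x), 1)

    refiner = Refine(symbolic_trace(Toy()))
    refiner.refine()
    print(refiner.constraints)  # list of Equality constraints over type variables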

View File

@ -0,0 +1,172 @@
# mypy: allow-untyped-defs
import torch
from torch.fx.node import Node
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph
import itertools
import operator
from typing import Dict, List, Tuple
def split_result_tensors(
result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
"""
A free function for use in the merge_matmul graph transformation below that
splits the output from a merged matmul into the individual results for each
input tensor.
Arguments:
result: The merged matmul result tensor.
inputs: The list of inputs that were merged into one for the matmul.
Returns:
List of matmul results for each input tensor.
"""
# When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
# need an int even when tracing
if isinstance(result, torch.fx.Proxy):
splits = [0] * len(inputs)
else:
splits = [x.shape[0] for x in inputs]
return torch.split(result, splits)
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
"""
Determine if one node depends on another in a torch.fx.Graph.
Arguments:
a: The node that may have a dependency on b.
b: The node that a may have a dependency on.
search_depth: In the case of an indirect dependency, this function
searches up to this many nodes away in search of a
data dependency. If none is found, the function
makes the conservative assumption that there is a
dependency.
Returns:
True if a may depend on b, False if it definitely does not.
"""
# Equivalence is defined as dependence.
if a == b:
return True
# If a has no inputs, it cannot depend on b.
if len(a.all_input_nodes) == 0:
return False
# If the search depth has been exhausted and no conclusion has been
# reached, assume that there is a data dependency.
if search_depth == 0:
return True
# Recursively check all inputs of a.
for inp in a.all_input_nodes:
if may_depend_on(inp, b, search_depth - 1):
return True
return False
def are_nodes_independent(nodes: List[Node]):
"""
Check if all of the given nodes are pairwise-data independent.
Arguments:
nodes: The nodes to check for data dependencies.
Returns:
True if no pair of nodes has a data dependency, False otherwise.
"""
# For each pair in nodes:
for i, j in itertools.combinations(nodes, 2):
if may_depend_on(i, j) or may_depend_on(j, i):
return False
return True
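# Illustrative note, not part of the original file: for a graph containing
#     a = x + 1 ; b = a * 2 ; c = y - 1
# may_depend_on(b, a) is True (b consumes a) and may_depend_on(c, a) is False,
# so are_nodes_independent([b, c]) holds while are_nodes_independent([a, b])
# does not.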
def merge_matmul(in_mod: torch.nn.Module):
"""
A graph transformation that merges matrix multiplication operations that share the same right-hand
side operand into one large matrix multiplication.
[ A ] (M x K)                          [ A @ C ] (M x R)
[ B ] (T x K)    @    C (K x R)    =   [ B @ C ] (T x R)
"""
gm = symbolic_trace(in_mod)
rhs_users: Dict[Node, List[Node]] = {}
lhs_users: Dict[Node, List[Node]] = {}
# Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
# the matmul of which they are the LHS/RHS.
for node in gm.graph.nodes:
if node.op != "call_function" or node.target is not torch.matmul:
continue
lhs, rhs = node.args
# TODO: Properly handle aliasing caused by get_attr. For now,
# use the attribute name as the operand if the node is a
# get_attr.
lhs = lhs.target if lhs.op == "get_attr" else lhs
rhs = rhs.target if rhs.op == "get_attr" else rhs
lhs_users.setdefault(lhs, []).append(node)
rhs_users.setdefault(rhs, []).append(node)
for rhs, mms in rhs_users.items():
# There must be at least two matmuls for a merge to make sense.
if len(mms) < 2:
continue
# All matmuls must not depend on each other directly or indirectly
# in order for the merge to be possible.
if not are_nodes_independent(mms):
continue
lhs_vals = [mm.args[0] for mm in mms]
# Merge the matmul.
# Collect a list of LHS operands and the single RHS operand.
lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
# Concatenate all the LHS operands.
merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
# Multiply the concatenated LHS operands with the one RHS. This will produce
# the same results as all the individual matmuls involving rhs in the original graph,
# but they will all be concatenated together.
merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
# Split the result of the merged matmul using the shapes of the LHS operands
# to ascertain how large each chunk should be.
merge_mm_split = gm.graph.call_function(
split_result_tensors, (merge_mm, lhs), {}
)
merge_mm_res = [
gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
for out in range(len(lhs))
]
# Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
for old, new in zip(mms, merge_mm_res):
old.replace_all_uses_with(new)
gm.graph.erase_node(old)
# All of the new nodes created above were inserted at the end, so we need to sort
# the nodes topologically to make sure all definitions precede uses.
legalize_graph(gm)
gm.recompile()
gm.graph.lint()
return gm
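if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: two matmuls sharing the
    # same right-hand side get merged into a single cat + matmul + split. The toy
    # module below is illustrative only.
    class TwoMatmuls(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rhs = torch.nn.Parameter(torch.randn(4, 5))

        def forward(self, x, y):
            return torch.matmul(x, self.rhs), torch.matmul(y, self.rhs)

    merged = merge_matmul(TwoMatmuls())
    print(merged.graph)  # a single torch.matmul over the concatenated LHS operands
    a, b = merged(torch.randn(2, 4), torch.randn(3, 4))
    print(a.shape, b.shape)  # torch.Size([2, 5]) torch.Size([3, 5])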

View File

@ -0,0 +1,269 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import warnings
import functools
import builtins
from typing import Any, Callable, Dict, Optional, Union
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
def nn_layernorm_override(self, input):
return input
def torch_relu_override(x):
return x
def torch_nn_relu_override(self, x):
return x
def functional_relu_override(x, inplace=False):
assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
return x
def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
def torch_abs_override(input, *, out=None):
assert out is None, 'Dont support in-place abs for MetaTensor analysis'
return input
manual_meta_overrides : Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: nn_layernorm_override,
torch.relu: torch_relu_override,
torch.nn.functional.relu: functional_relu_override,
torch.nn.ReLU: torch_nn_relu_override,
torch.where: torch_where_override,
torch.abs: torch_abs_override,
}
def gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy('call_function', target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
class MetaProxy(torch.fx.Proxy):
def install_tensor_meta(self, tensor_meta):
self._tensor_meta = tensor_meta
def size(self, dim=None):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.size(*[dim] if dim else [])
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
def dim(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
@property
def shape(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.shape
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
@property
def dtype(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.dtype
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, 'device')
def __getattr__(self, k):
if k == '_tensor_meta':
return self.__getattribute__(k)
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return MetaAttribute(self, k)
class MetaAttribute(MetaProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
class MetaDeviceAttribute(MetaAttribute):
pass
def proxys_to_metas(v):
if isinstance(v, MetaDeviceAttribute):
return 'meta'
if isinstance(v, torch.fx.Proxy):
assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
return v._tensor_meta
return v
class MetaTracer(torch.fx.Tracer):
allow_insert_stateless_mods : bool = True
_TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if kind == 'placeholder' and target in self.meta_args:
rv.install_tensor_meta(self.meta_args[target])
return rv
if target in self.orig_fns:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if 'device' in kwargs:
kwargs['device'] = 'meta'
try:
args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
if kind == 'call_function':
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == 'get_attr':
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split('.')
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
assert isinstance(attr_itr, torch.Tensor)
meta_out = attr_itr.to(device='meta')
finally:
self._disable_module_getattr = False
else:
return rv
# TODO
assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
rv.install_tensor_meta(meta_out)
except Exception as e:
warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
return rv
def getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, '_disable_module_getattr', False):
return attr_val
else:
return super().getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
"""
Helper method which tries to insert a module that was not declared as submodule.
"""
idx = 0
mod_name = mod.__class__.__name__.lower()
path = f"{mod_name}_{idx}"
while hasattr(self.root, path):
path = f"{mod_name}_{idx}"
idx += 1
self.root.add_module(path, mod)
return path
def path_of_module(self, mod: torch.nn.Module) -> str:
try:
return super().path_of_module(mod)
except NameError as e:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
self.prev_module = path
return path
raise
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
self.patched_torch_methods = {
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
try:
graph = super().trace(root, concrete_args)
graph._tracer_extras = {'meta_args': meta_args}
return graph
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm
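if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: trace a toy module with
    # meta tensors so shapes and dtypes are tracked without real computation.
    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(10, 8)
            self.norm = torch.nn.LayerNorm(8)

        def forward(self, ids):
            return torch.relu(self.norm(self.emb(ids)))

    ids_meta = torch.zeros(2, 4, dtype=torch.long, device='meta')
    gm = symbolic_trace(TinyNet(), meta_args={'ids': ids_meta})
    print(gm.graph)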

View File

@ -0,0 +1,558 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
op_mod, op_gt, op_lt, op_neq, op_eq
from torch.fx.tensor_type import TensorType, Dyn
class Constraint:
pass
class Conj(Constraint):
def __init__(self, conjuncts):
"""
:param conjuncts: Conjunction of constraints
"""
self.conjucts = conjuncts
def __eq__(self, other):
if isinstance(other, Conj):
return self.conjucts == other.conjucts
else:
return False
def __repr__(self):
return f'And({self.conjucts})'
class Disj(Constraint):
def __init__(self, disjuncts):
"""
:param disjuncts: Disjunction of constraints
"""
self.disjuncts = disjuncts
def __eq__(self, other):
if isinstance(other, Disj):
return self.disjuncts == other.disjuncts
else:
return False
def __repr__(self):
return f'Or({self.disjuncts})'
class Prod(Constraint):
def __init__(self, products):
"""
:param products: lists of dimensions to multiply
"""
self.products = products
def __eq__(self, other):
if isinstance(other, Prod):
return self.products == other.products
else:
return False
def __repr__(self):
return f'Product({self.products})'
class T(Constraint):
"""
True
"""
def __init__(self) -> None:
pass
def __eq__(self, other):
return isinstance(other, T)
def __repr__(self):
return 'True'
class F(Constraint):
"""
False
"""
def __init__(self) -> None:
pass
def __eq__(self, other):
return isinstance(other, F)
def __repr__(self):
return 'False'
class BinaryConstraint(Constraint):
"""
Represents all binary operations
"""
def __init__(self, lhs, rhs, op):
"""
:param lhs: lhs of the constraint
:param rhs: rhs of the constraint
:param op: string representing the operation
"""
self.lhs = lhs
self.rhs = rhs
self.op = op
def __eq__(self, other):
if isinstance(other, BinaryConstraint):
return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
else:
return False
def __repr__(self):
return f'({self.lhs} {self.op} {self.rhs})'
class BinConstraintT(BinaryConstraint):
"""
Binary constraints about tensors
"""
def __init__(self, lhs, rhs, op):
assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
(isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
super().__init__(lhs, rhs, op)
def __eq__(self, other):
return super().__eq__(other)
class BinConstraintD(BinaryConstraint):
"""
Binary constraints about dimensions
"""
def __init__(self, lhs, rhs, op):
assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
super().__init__(lhs, rhs, op)
def __eq__(self, other):
return super().__eq__(other)
class TGreatestUpperBound(Constraint):
"""
Greatest Upper bound for tensors with dynamic type
"""
def __init__(self, res, rhs1, rhs2):
"""
:param res: tensor variable that stores the result
:param rhs1: tensor or tensor variable
:param rhs2: tensor or tensor variable
"""
self.res = res
self.rhs1 = rhs1
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'
def __eq__(self, other):
if isinstance(other, TGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
else:
return False
class DGreatestUpperBound(Constraint):
"""
Greatest Upper bound for dimensions
"""
def __init__(self, res, rhs1, rhs2):
"""
:param res: Dimension variable to store the result
:param rhs1: dimension variable 1
:param rhs2: dimension variable 2
"""
assert is_dim(res)
assert is_dim(rhs1)
assert is_dim(rhs2)
self.res = res
self.rhs1 = rhs1
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'
def __eq__(self, other):
if isinstance(other, DGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
else:
return False
class CanReshape(Constraint):
"""
can_reshape constraint
"""
def __init__(self, src, target):
"""
:param src: tensor variable
:param target: tensor
"""
self.src = src
self.target = target
def __repr__(self):
return f'can-reshape({self.src}, {self.target})'
def __eq__(self, other):
if isinstance(other, CanReshape):
return self.src == other.src and self.target == other.target
else:
return False
class IndexSelect(Constraint):
def __init__(self, tensor_size, input_var, dim_replace, index, output):
"""
Args:
input_var: input to index_select
tensor_size: tensor size we are considering
dim_replace: the dimension of the output at "index"
index: location of the dimensions to replace in the input
output: variable to store the result
"""
assert isinstance(input_var, TVar)
assert isinstance(output, TVar)
assert isinstance(dim_replace, DVar) or dim_replace == Dyn
assert isinstance(index, int)
self.input_var = input_var
self.tensor_size = tensor_size
self.dim_replace = dim_replace
self.index = index
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'IndexSelect({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.dim_replace}, ' \
f'{self.index})'
def __eq__(self, other):
if isinstance(other, IndexSelect):
return self.tensor_size == other.tensor_size and \
self.dim_replace == other.dim_replace and \
self.index == other.index and \
self.output == other.output and \
self.input_var == other.input_var
else:
return False
class Transpose(Constraint):
def __init__(self, tensor_size, input_var, index1, index2, output):
"""
Args:
tensor_size: current tensor size
input_var: variable to hold input
index1: dimension 1
index2: dimension 2
output: output that stores result
"""
assert isinstance(input_var, TVar)
assert isinstance(output, TVar)
assert isinstance(index1, int)
assert isinstance(index2, int)
self.input_var = input_var
self.tensor_size = tensor_size
self.index1 = index1
self.index2 = index2
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'Transpose({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.index1}, ' \
f'{self.index2})'
def __eq__(self, other):
if isinstance(other, Transpose):
return self.tensor_size == other.tensor_size and \
self.index1 == other.index1 and \
self.index2 == other.index2 and \
self.output == other.output and \
self.input_var == other.input_var
else:
return False
class GetItem(Constraint):
def __init__(self, tensor_size, index, res, input_var):
"""
Constraint for getting item given a tensor size
:param tensor_size: actual number
:param index: actual number representing the index
:param res: dimension variable to carry the item we get
:param input_var: a tensor variable from which we will get item
"""
assert isinstance(res, DVar)
self.res = res
self.tensor_size = tensor_size
self.index = index
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
def __eq__(self, other):
if isinstance(other, GetItem):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index == other.index and \
self.input_var == other.input_var
else:
return False
class GetItemTensor(Constraint):
def __init__(self, tensor_size, index_tuple, res, input_var):
"""
Constraint for getting an item given a tensor size.
However, when the index argument is a tuple, the
result is expected to be a tensor rather than a dimension.
:param tensor_size: actual number representing the rank
:param index_tuple: tuple for indexing
:param res: tensor variable to carry the item we get
:param input_var: a tensor variable from which we will get item
"""
assert isinstance(res, TVar)
self.res = res
self.tensor_size = tensor_size
self.index_tuple = index_tuple
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
def __eq__(self, other):
if isinstance(other, GetItemTensor):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index_tuple == other.index_tuple and \
self.input_var == other.input_var
else:
return False
class CalcConv(Constraint):
def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
"""
:param conv_result: the convolution result
:param input_var: input to convolution
:param c_out: output channel type
:param kernel: kernel tuple
"""
self.conv_result = conv_result
self.input_var = input_var
self.c_out = c_out
self.kernel = kernel
self.padding = padding
self.stride = stride
self.dilation = dilation
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.conv_result} =' \
f' calc-conv({self.input_var},' \
f' {self.c_out}, {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
def __eq__(self, other):
if isinstance(other, CalcConv):
return self.conv_result == other.conv_result and self.input_var == other.input_var and \
self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
and self.matching_constraint == other.matching_constraint
else:
return False
class CalcMaxPool(Constraint):
def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
"""
:param maxpool_result: the result of maxpool
:param input_var: input to convolution
:param kernel: kernel tuple
"""
self.maxpool_result = maxpool_result
self.input_var = input_var
self.kernel = kernel
self.padding = padding
self.stride = stride
self.dilation = dilation
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.maxpool_result} =' \
f' calc-maxpool({self.input_var},' \
f' {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
def __eq__(self, other):
if isinstance(other, CalcMaxPool):
return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
and self.matching_constraint == other.matching_constraint
else:
return False
class ApplyBroadcasting(Constraint):
def __init__(self, res1, res2, input1, input2):
"""
:param res1: resulting tensor 1
:param res2: resulting tensor 2
:param input1: tensor variable 1
:param input2: tensor variable 2
"""
self.res1 = res1
self.res2 = res2
self.input1 = input1
self.input2 = input2
def __eq__(self, other):
if isinstance(other, ApplyBroadcasting):
return self.res1 == other.res1 \
and self.res2 == other.res2 \
and self.input1 == other.input1 \
and self.input2 == other.input2
else:
return False
def __repr__(self):
return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})'
class CalcProduct(Constraint):
"""
Given correct dimensions, calculate the product for flatten accounting for Dyn
"""
def __init__(self, start, end, flattened, dims_to_flatten):
"""
:param start: start index
:param end: end index
:param flattened: variable to store the product
:param dims_to_flatten: the type which we will flatten
"""
assert isinstance(dims_to_flatten, list)
assert isinstance(flattened, TVar)
assert isinstance(start, int)
assert isinstance(end, int)
self.start = start
self.end = end
self.dims_to_flatten = dims_to_flatten
self.flattened = flattened
def __eq__(self, other):
if isinstance(other, CalcProduct):
return self.start == other.start and self.end == other.end and \
self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
else:
return False
def __repr__(self):
return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
class TVar:
"""
Tensor variable with no tensor constructor
"""
def __init__(self, tvar):
"""
:param tvar: tensor variable
"""
self.tvar = tvar
def __repr__(self):
return f'TV({self.tvar})'
def __eq__(self, other):
if isinstance(other, TVar):
return self.tvar == other.tvar
else:
return False
class DVar:
"""
Dimension variable
"""
def __init__(self, c):
"""
:param c: character or number
"""
self.c = c
def __repr__(self):
return f'DV({self.c})'
def __eq__(self, other):
if isinstance(other, DVar):
return self.c == other.c
else:
return False
class BVar:
"""
Boolean variable
"""
def __init__(self, c):
"""
:param c: character or number
"""
self.c = c
def __repr__(self):
return f'BV({self.c})'
def __eq__(self, other):
if isinstance(other, BVar):
return self.c == other.c
else:
return False
def is_algebraic_expression(constraint):
if isinstance(constraint, BinConstraintD):
return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
else:
return isinstance(constraint, Prod)
def is_bool_expr(constraint):
if isinstance(constraint, BinConstraintD):
return constraint.op in [op_gt, op_lt, op_neq, op_eq]
else:
return isinstance(constraint, (BVar, Conj, Disj))
def is_dim(d):
return isinstance(d, (DVar, int)) or d == Dyn
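if __name__ == "__main__":
    # Hedged illustration, not part of the original file: build a small
    # conjunction of constraints by hand using the classes defined above.
    d1, d2 = DVar(1), DVar(2)
    t = TVar(1)
    example = Conj([BinConstraintD(d1, d2, op_eq), BinConstraintT(t, Dyn, op_eq)])
    print(example)  # And([(DV(1) = DV(2)), (TV(1) = Dyn)])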

View File

@ -0,0 +1,14 @@
op_add = '+'
op_sub = '-'
op_mul = '*'
op_div = '/'
op_eq = '='
op_neq = '!='
op_imp = '=>'
op_matching = '\u22b3' # (contains)
op_consistency = '~'
op_precision = '\u2291' # (square image of or equal to)
op_leq = '\u2264' # less-than or equal to
op_lt = '<'
op_gt = '>'
op_mod = '%'

View File

@ -0,0 +1,349 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
from torch.fx.tensor_type import TensorType, Dyn
try:
import z3 # type: ignore[import]
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
HAS_Z3 = True
def transform_to_z3(constraint, counter, dimension_dict):
if isinstance(constraint, Conj):
conjuncts = []
for c in constraint.conjucts:
new_c, counter = transform_to_z3(c, counter, dimension_dict)
conjuncts.append(new_c)
return z3.And(conjuncts), counter
elif isinstance(constraint, Disj):
disjuncts = []
for c in constraint.disjuncts:
new_c, counter = transform_to_z3(c, counter, dimension_dict)
disjuncts.append(new_c)
return z3.Or(disjuncts), counter
elif isinstance(constraint, T):
return True, counter
elif isinstance(constraint, F):
return False, counter
elif isinstance(constraint, BinConstraintT):
if constraint.op == op_eq:
lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
return (lhs == rhs), counter
else:
raise NotImplementedError('Method not yet implemented')
elif isinstance(constraint, BinConstraintD):
if constraint.op == op_eq:
if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
transformed_lhs = z3.Bool(constraint.lhs.c)
return transformed_lhs == transformed_rhs, counter
elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
# with dimension transformations we consider the encoding
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
return lhs == rhs, counter
else:
# then we have an algebraic expression which means that we disregard the
# first element of the encoding
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs == rhs, counter
# The assumption here is that the LHS and RHS must be dimensions
elif constraint.op == op_neq:
assert is_dim(constraint.lhs)
assert is_dim(constraint.rhs)
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
if constraint.rhs == Dyn or constraint.lhs == Dyn:
if constraint.rhs == Dyn:
return lhs.arg(0) == 1, counter
elif constraint.lhs == Dyn:
return rhs.arg(0) == 1, counter
# if one of the instances is a number
elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
if isinstance(constraint.lhs, int):
return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
elif isinstance(constraint.rhs, int):
return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
else:
return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
elif constraint.op == op_leq:
# if the dimensions are not dyn, this will come into effect
# there would have been another constraint specifying if a given dimension
# is dyn or not
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs <= rhs, counter
elif constraint.op == op_gt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs > rhs, counter
elif constraint.op == op_lt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs < rhs, counter
else:
raise NotImplementedError('operation not yet implemented')
else:
raise NotImplementedError('Operation not yet implemented')
def transform_var(tensor, counter, dimension_dict):
"""
Transforms tensor variables to a format understood by z3
Args:
tensor: Tensor variable or a tensor type potentially with variable dimensions
Returns: Transformed variable to a z3 format
"""
if isinstance(tensor, TensorType):
res = []
for t in tensor.__args__:
transformed, counter = transform_dimension(t, counter, dimension_dict)
res.append(transformed)
assert len(res) <= 4
if len(tensor.__args__) == 1:
return tensor_type.tensor1(res[0]), counter
elif len(tensor.__args__) == 2:
return tensor_type.tensor2(res[0], res[1]), counter
elif len(tensor.__args__) == 3:
return tensor_type.tensor3(res[0], res[1], res[2]), counter
elif len(tensor.__args__) == 4:
return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
elif tensor == Dyn:
return z3_dyn, counter
elif isinstance(tensor, TVar):
return z3.Const(tensor.tvar, tensor_type), counter
def transform_dimension(dimension, counter, dimension_dict):
"""
Takes a dimension variable or a number and transforms it to a tuple
according to our scheme
Args:
dimension: The dimension to be transformed
counter: variable tracking
Returns: tuple and the current counter
"""
if dimension == Dyn:
counter += 1
return D(0, z3.Int(counter)), counter
elif isinstance(dimension, int):
return D(1, dimension), counter
elif isinstance(dimension, DVar):
if dimension.c in dimension_dict:
return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
else:
counter += 1
dimension_dict[dimension.c] = counter
return D(z3.Int(counter), z3.Int(dimension.c)), counter
def transform_algebraic_expression(expr, counter, dimension_dict):
"""
Transforms an algebraic expression to z3 format
Args:
expr: An expression is either a dimension variable or an algebraic-expression
Returns: the transformed expression
"""
assert is_algebraic_expression(expr) or is_dim(expr)
if is_dim(expr):
transformed, counter = transform_dimension(expr, counter, dimension_dict)
return transformed.arg(1), counter
elif isinstance(expr, Prod):
dims = []
for dim in expr.products:
assert is_dim(dim)
d, counter = transform_dimension(dim, counter, dimension_dict)
dims.append(d.arg(1))
return z3.Product(dims), counter
elif is_algebraic_expression(expr):
lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
if expr.op == op_sub:
c = lhs - rhs
elif expr.op == op_add:
c = lhs + rhs
elif expr.op == op_div:
c = lhs / rhs
elif expr.op == op_mul:
c = lhs * rhs
elif expr.op == op_mod:
c = lhs % rhs
else:
raise NotImplementedError('operation not yet implemented')
return c, counter
else:
raise RuntimeError
def transform_all_constraints(traced, counter=0):
"""
Given a trace, generates constraints and transforms them to z3 format
"""
dimension_dict = {} # type: ignore[var-annotated]
generator = ConstraintGenerator(traced)
new_constraints, counter = generator.generate_constraints(counter)
# print(new_constraints.conjucts[0])
# print(*new_constraints.conjucts, sep='\n')
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
# print(new_constraints)
# print(new_constraints.conjucts)
# new_constraints.conjucts = new_constraints.conjucts[:-1]
# print(*new_constraints.conjucts, sep='\n')
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
# print(transformed)
return transformed
def iterate_till_fixed_point(constraints, counter):
"""
Transform constraints till reaching a fixed point
"""
old_c = None
while old_c != constraints:
old_c = constraints
constraints, counter = transform_constraint(constraints, counter)
return constraints, counter
def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
"""
Takes a node and a graph and generates two sets of constraints.
One set conjoins the graph's constraints with the node's conditional
constraint, and the other conjoins them with its negation.
Args:
tracer_root: the root for getting the module instances
graph: the graph so far in the tracing process
node: node that represents a conditional
counter: variable tracking
Returns: Two sets of constraints. One with a conjunction with the
conditional constraint, and the other with a conjunction with
its negation.
"""
dimension_dict = {} # type: ignore[var-annotated]
generator = ConstraintGenerator(tracer_root, graph)
new_constraints, counter = generator.generate_constraints(counter)
condition_constraint = new_constraints.conjucts[-1]
# we know the constraint is a conjunction where the last constraint is about the conditional
# so remove the last constraint
new_constraints.conjucts = new_constraints.conjucts[:-1]
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
# since the function returns a list of one element, we get the first element
# we are only interested in the RHS in this case because the LHS just stores
# the result
# we make sure the constraint is of the form:
# c = b where b is a boolean expression
# and we consider b (constraint.rhs) for transformation
assert isinstance(condition_constraint.lhs, BVar)
assert is_bool_expr(condition_constraint.rhs)
condition_constraint_rhs = condition_constraint.rhs
# transform the condition constraint
condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
return z3.And([transformed, transformed_condition_constraint]), \
z3.And([transformed, negation_transformed_condition_constraint])
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
"""
Given an IR and a node representing a conditional, evaluate the conditional
and its negation
Args:
tracer_root: Tracer root for module instances
node: The node to be evaluated
Returns: the results of evaluating the condition and the negation with
the rest of the constraints
"""
transformed_positive, transformed_negative = \
transform_all_constraints_trace_time(tracer_root, graph, node, counter)
s = z3.Solver()
s.add(transformed_positive)
if user_constraints is not None:
s.add(user_constraints)
condition = s.check()
s = z3.Solver()
s.add(transformed_negative)
if user_constraints is not None:
s.add(user_constraints)
negation = s.check()
return condition, negation
except ImportError:
HAS_Z3 = False
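if __name__ == "__main__" and HAS_Z3:
    # Hedged usage sketch, not part of the original file: generate constraints for
    # a small annotated module and check their satisfiability with z3. This assumes
    # ConstraintGenerator has a rule for the traced operation (torch.nn.ReLU); the
    # toy module and its annotation are illustrative only.
    import torch
    from torch.fx import symbolic_trace
    from torch.fx.tensor_type import TensorType

    class Simple(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x: TensorType((2, 4))):
            return self.relu(x)

    solver = z3.Solver()
    solver.add(transform_all_constraints(symbolic_trace(Simple())))
    print(solver.check())  # expected: sat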

View File

@ -0,0 +1,53 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
BVar
from torch.fx.experimental.migrate_gradual_types.operation import op_leq
def gen_tvar(curr):
"""
Generate a tensor variable
:param curr: The current counter
:return: a tensor variable and the updated counter
"""
curr += 1
return TVar(curr), curr
def gen_dvar(curr):
"""
Generate a dimension variable
:param curr: the current counter
:return: a dimension variable and an updated counter
"""
curr += 1
return DVar(curr), curr
def gen_bvar(curr):
"""
Generate a boolean variable
:param curr: the current counter
:return: a boolean variable and an updated counter
"""
curr += 1
return BVar(curr), curr
def gen_tensor_dims(n, curr):
"""
Generate a list of tensor dimensions
:param n: the number of dimensions
:param curr: the current counter
:return: a list of dimension variables and an updated counter
"""
dims = []
for _ in range(n):
dvar, curr = gen_dvar(curr)
dims.append(dvar)
return dims, curr
def gen_nat_constraints(list_of_dims):
"""
Generate natural number constraints for dimensions
"""
return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
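if __name__ == "__main__":
    # Hedged illustration, not part of the original file: create three fresh
    # dimension variables and the natural-number constraints over them.
    dims, counter = gen_tensor_dims(3, 0)
    print(dims)                       # [DV(1), DV(2), DV(3)]
    print(gen_nat_constraints(dims))  # three constraints of the form (0 <= DV(i)),
                                      # rendered with the op_leq symbol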

View File

@ -0,0 +1,29 @@
try:
import z3 # type: ignore[import]
HAS_Z3 = True
# dynamic type
dyn = z3.DeclareSort('Dyn')
dyn_type = z3.Const('dyn', dyn)
# dimension
dim = z3.Datatype('dim')
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
dim = dim.create()
# tensors
tensor_type = z3.Datatype('TensorType')
tensor_type.declare('Dyn', ('dyn', dyn))
tensor_type.declare('tensor1', ('0', dim))
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
tensor_type = tensor_type.create()
# create dimension
D = dim.dim
z3_dyn = tensor_type.Dyn(dyn_type)
except ImportError:
HAS_Z3 = False
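if __name__ == "__main__" and HAS_Z3:
    # Hedged illustration, not part of the original file: encode the static type
    # Tensor[2, 4] and the dynamic type using the z3 datatypes declared above.
    static_2x4 = tensor_type.tensor2(D(1, 2), D(1, 4))
    print(static_2x4)  # roughly: tensor2(dim(1, 2), dim(1, 4))
    print(z3_dyn)      # roughly: Dyn(dyn)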

View File

@ -0,0 +1,163 @@
# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional
import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
normalize_module,
normalize_function,
create_type_hint,
)
from .schema_type_annotation import AnnotateTypesWithSchema
class NormalizeArgs(Transformer):
"""
Normalize arguments to Python targets. This means that
`args/kwargs` will be matched up to the module/functional's
signature and rewritten to exclusively kwargs in positional order
if `normalize_to_only_use_kwargs` is true. Also populates default
values. Does not support positional-only parameters or varargs
parameters (*args, **kwargs).
If the nodes have 'type' metadata, it will use it to disambiguate
overloads. Otherwise, it will throw an error.
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeArgs(traced).transform()
"""
def __init__(
self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
):
super().__init__(module)
self.node_map: Dict[Proxy, Node] = {}
self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
def run_node(self, n: Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
def get_type(arg):
if isinstance(arg, fx.Node):
return n.meta["type"] if "type" in n.meta else None
return type(arg)
arg_types = map_aggregate(n.args, get_type)
assert isinstance(arg_types, tuple)
arg_types = tuple([create_type_hint(i) for i in arg_types])
kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
if n.op == "call_function":
out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
else:
out = super().run_node(n)
if n.op != "output":
self.node_map[out] = n
out.node.meta = n.meta
out.node.type = n.type
return out
def call_function(
self,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
arg_types: Optional[Tuple[Any, ...]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
):
assert callable(target)
new_args_and_kwargs = normalize_function(
target,
args, # type: ignore[arg-type]
kwargs,
arg_types, # type: ignore[arg-type]
kwarg_types,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return self.tracer.create_proxy(
"call_function", target, new_args, new_kwargs
)
else:
return super().call_function(target, args, kwargs)
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
assert isinstance(target, str)
new_args_and_kwargs = normalize_module(
self.module,
target,
args, # type: ignore[arg-type]
kwargs,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return super().call_module(target, new_args, new_kwargs)
else:
return super().call_module(target, args, kwargs)
class NormalizeOperators(AnnotateTypesWithSchema):
"""
Normalize callsites that are different ways of "spelling" the same
invocation into a single, canonical call. Currently supports:
1. Normalize `torch` operators (e.g. torch.add) to the equivalent
magic-method operators (e.g. operator.add) when the callsite has
exactly two positional arguments.
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeOperators(traced).transform()
"""
binary_magic_method_remap: Dict[
Callable[[Any, Any], Any], Callable[[Any, Any], Any]
] = {
torch.add: operator.add,
torch.mul: operator.mul,
torch.sub: operator.sub,
torch.div: operator.truediv,
torch.floor_divide: operator.floordiv,
torch.remainder: operator.mod,
torch.eq: operator.eq,
torch.ne: operator.ne,
torch.lt: operator.lt,
torch.le: operator.le,
torch.gt: operator.gt,
torch.ge: operator.ge,
}
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
# Normalize operators according to the magic methods implemented on tensors here:
# https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
assert callable(target)
if target in self.binary_magic_method_remap:
if len(args) != 2:
return super().call_function(target, args, kwargs)
lhs, rhs = args
return super().call_function(
target=self.binary_magic_method_remap[target],
args=(lhs, rhs),
kwargs={},
)
return super().call_function(target, args, kwargs)
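if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: after the pass, the
    # torch.add callsite is spelled with operator.add in the rewritten graph.
    class AddModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    traced = fx.symbolic_trace(AddModule())
    normalized = NormalizeOperators(traced).transform()
    print(normalized.graph)  # the call_function target is now operator.add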

View File

@ -0,0 +1,409 @@
# mypy: allow-untyped-defs
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum
def _parent_name(target : str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
if len(node.args) == 0:
return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node):
return False
if current_node.op != 'call_module':
return False
if not isinstance(current_node.target, str):
return False
if current_node.target not in modules:
return False
if type(modules[current_node.target]) is not expected_type:
return False
return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
modules[node.target] = new_module
setattr(modules[parent_name], name, new_module)
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
"""
Fuses convolution/BN layers for inference purposes. Will deepcopy your
model by default, but can modify the model inplace as well.
"""
patterns = [(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d)]
if not inplace:
model = copy.deepcopy(model)
if not no_trace or not isinstance(model, torch.fx.GraphModule):
fx_model = fx.symbolic_trace(model)
else:
fx_model = model
modules = dict(fx_model.named_modules())
new_graph = copy.deepcopy(fx_model.graph)
for pattern in patterns:
for node in new_graph.nodes:
if matches_module_pattern(pattern, node, modules):
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
if not bn.track_running_stats:
continue
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])
new_graph.erase_node(node)
return fx.GraphModule(fx_model, new_graph)
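# Hedged usage sketch, not part of the original file: fuse a Conv2d/BatchNorm2d
# pair in a toy model and compare against the unfused model in eval mode.
#
#     class ConvBN(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv = nn.Conv2d(3, 8, 3)
#             self.bn = nn.BatchNorm2d(8)
#         def forward(self, x):
#             return self.bn(self.conv(x))
#
#     m = ConvBN().eval()
#     fused = fuse(m)
#     x = torch.randn(1, 3, 16, 16)
#     print(torch.allclose(fused(x), m(x), atol=1e-6))  # expected: True
#     print(fused.graph)  # the batch-norm call_module node has been folded away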
def remove_dropout(model: nn.Module) -> nn.Module:
"""
Removes all dropout layers from the module.
"""
fx_model = fx.symbolic_trace(model)
class DropoutRemover(torch.fx.Transformer):
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1
return args[0]
else:
return super().call_module(target, args, kwargs)
return DropoutRemover(fx_model).transform()
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {}
for input in inputs:
new_node = new_graph.placeholder(input.name)
env[input] = new_node
for node in nodes:
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
new_graph.output([env[output] for output in outputs])
new_graph.lint()
return fx.GraphModule(orig_module, new_graph)
mkldnn_supported = [
nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
torch.relu, torch.transpose, torch.sigmoid,
F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]
mkldnn_map = {
nn.Conv2d: th_mkldnn.MkldnnConv2d,
nn.Linear: th_mkldnn.MkldnnLinear,
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
}
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
"""
For each node, if it's a module that can be preconverted into MKLDNN,
then we do so and create a mapping to allow us to convert from the MKLDNN
version of the module to the original.
"""
old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == 'call_module':
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
replace_node_module(node, modules, new_module)
return old_modules
def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
"""
Maps each module that's been changed with `modules_to_mkldnn` back to its
original.
"""
for node in nodes:
if node.op == 'call_module':
assert (isinstance(node.target, str))
cur_module = modules[node.target]
if cur_module in old_modules:
replace_node_module(node, modules, old_modules[cur_module])
class MklSubgraph:
def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph
self.nodes: List[fx.Node] = []
self.start_nodes: List[fx.Node] = []
self.end_nodes: List[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
"""
This generates a heuristic that can be passed into `optimize_for_inference` that
determines whether a subgraph should be run in MKL by running it with the example_inputs.
Example usage:
heuristic = gen_mkl_autotuner(example_inputs, iters=10)
fast_model = optimization.optimize_for_inference(model, heuristic)
"""
fx_model = None
old_modules = None
def use_mkl_heuristic(graph: MklSubgraph) -> bool:
nonlocal fx_model, old_modules
input_nodes = graph.start_nodes
if fx_model is None:
fx_model = graph.fx_graph.owning_module
old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
ShapeProp(fx_model).propagate(example_inputs)
sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
def benchmark(f):
for _ in range(warmup):
f()
begin = time.time()
for _ in range(iters):
out = f()
return time.time() - begin
mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time
return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
"""
This is a heuristic that can be passed into `optimize_for_inference` that
determines whether a subgraph should be run in MKL by checking if there
are more than 2 nodes in it
"""
return len(graph.nodes) > 2
class UnionFind:
def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n
self.size: List[int] = [0] * n
def make_set(self, v: int):
self.parent[v] = v
self.size[v] = 1
def find(self, v: int) -> int:
par = self.parent[v]
if v == par:
return v
assert par is not None
self.parent[v] = self.find(par)
return cast(int, self.parent[v])
def join(self, a: int, b: int):
a, b = self.find(a), self.find(b)
if a == b:
return a
if self.size[a] < self.size[b]:
a, b = b, a
self.parent[b] = a
self.size[a] += self.size[b]
def optimize_for_inference(
model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
"""
Performs a set of optimization passes to optimize a model for the
purposes of inference. Specifically, the passes that are run are:
1. Conv/BN fusion
2. Dropout removal
3. MKL layout optimizations
The third optimization takes a function `use_mkl_heuristic` that's used
to determine whether a subgraph should be explicitly run in MKL layout.
Note: As FX does not currently handle aliasing, this pass currently
assumes nothing aliases. If that isn't true, use at your own risk.
"""
default_pass_config = {
"conv_bn_fuse": True,
"remove_dropout": True,
"mkldnn_layout_optimize": {'heuristic': use_mkl_length},
}
if pass_config is None:
pass_config = {}
default_pass_config.update(pass_config)
if default_pass_config["conv_bn_fuse"]:
model = fuse(model)
if default_pass_config["remove_dropout"]:
model = remove_dropout(model)
if default_pass_config["mkldnn_layout_optimize"] is False:
return model
if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
raise RuntimeError("mkldnn_layout_optimize config is not a dict")
if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model))
fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum):
NO = 1
YES = 2
UNKNOWN = 3
# Inserts to_mkldnn and to_dense around every node we want to be an MKLDNN node.
# If the op is in `mkldnn_supported` then we always treat it as an MKLDNN node.
# However, if it's in `mkldnn_supported_unknown`, then we only treat it as
# an MKLDNN node if its inputs are MKLDNN nodes.
for node in list(fx_graph.nodes):
supports_mkldnn = MklSupport.NO
if node.op == 'call_module':
cur_module = modules[node.target]
if type(cur_module) in mkldnn_supported:
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function':
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES
elif node.target in mkldnn_supported_unknown:
supports_mkldnn = MklSupport.UNKNOWN
if supports_mkldnn != MklSupport.NO:
if supports_mkldnn == MklSupport.UNKNOWN:
if not any(arg.target == 'to_dense' for arg in node.args):
continue
with fx_graph.inserting_before(node):
mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
node.replace_all_uses_with(dense_x)
dense_x.args = (node,)
# Does pre-conversion of all modules into MKLDNN (when possible)
old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
fx_graph.old_modules = old_modules # type: ignore[attr-defined]
# optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
for node in fx_graph.nodes:
if node.op == 'call_method' and node.target == 'to_dense':
prv_node = node.args[0]
users = list(node.users)
for user in users:
if user.op == 'call_method' and user.target == 'to_mkldnn':
user.replace_all_uses_with(prv_node)
fx_graph.erase_node(user)
if len(node.users) == 0:
fx_graph.erase_node(node)
num_nodes = len(fx_graph.nodes)
uf = UnionFind(num_nodes)
def get_color(n):
if hasattr(n, 'color'): # Current node is part of a MKL subgraph
return uf.find(n.color)
if hasattr(n, 'start_color'): # Current node is input to MKL subgraph
return uf.find(n.start_color)
return None
# This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
# of input nodes (which are only `to_mkldnn` calls), output nodes
# (`to_dense` calls), and intermediate nodes, which are run entirely on
# MKLDNN layout tensors.
#
# Specifically, this code does a flood fill on a directed acyclic graph
# (DAG), starting from each possible "start node" (i.e. `to_mkldnn` nodes).
# If every node only had one input, this would be sufficient. However, in
# the case that a node has multiple inputs coming from different start
# nodes (i.e. colors), we need to join these 2 colors into 1. That's done
# using a Disjoint Set Union.
for cur_idx, node in enumerate(fx_graph.nodes):
if node.op == 'call_method' and node.target == 'to_mkldnn':
node.start_color = cur_idx
uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense':
assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0])
else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
if len(cur_colors) == 0:
continue
assert not any(i is None for i in cur_colors)
cur_colors = sorted(cur_colors)
node.color = cur_colors[0]
for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color)
mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes:
if hasattr(node, 'color'):
mkldnn_graphs[uf.find(node.color)].nodes.append(node)
if hasattr(node, 'start_color'):
mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
if hasattr(node, 'end_color'):
mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
# Now that we have all the subgraphs, we need to decide which MKLDNN
# subgraphs we actually want to keep in MKLDNN.
for graph in mkldnn_graphs.values():
if not use_mkl_heuristic(graph):
for node in graph.start_nodes + graph.end_nodes:
prv = node.args[0]
node.replace_all_uses_with(prv)
fx_graph.erase_node(node)
reset_modules(graph.nodes, modules, old_modules)
mkldnn_conversions = 0
for node in fx_graph.nodes:
if node.target == 'to_mkldnn' or node.target == 'to_dense':
mkldnn_conversions += 1
logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
fx_graph.lint()
result = fx.GraphModule(model, fx_graph)
return result
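# Hedged usage sketch (not part of the original file): two common ways of calling the
# pass above. `model` is a hypothetical float CPU module supplied by the caller.
def _demo_optimize_for_inference(model: torch.nn.Module):
    # Default config: Conv/BN fusion, dropout removal, and MKLDNN layout optimization
    # using the node-count heuristic `use_mkl_length`.
    fast_model = optimize_for_inference(model)
    # Same fusion passes, but with the MKLDNN layout rewrite disabled entirely.
    fused_only = optimize_for_inference(model, {"mkldnn_layout_optimize": False})
    return fast_model, fused_only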

View File

@ -0,0 +1,318 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import NamedTuple, Dict, List, Set
from torch.fx.node import Node, map_arg
class Partition:
"""Partition class contains all the information about an individual partition.
It also provides the necessary methods for manipulating the partition.
"""
def __init__(self, partition_id: int) -> None:
self.nodes: Set[Node] = set()
self.partition_id = partition_id
self.parents: Set[Partition] = set()
self.children: Set[Partition] = set()
self.bfs_level: int = -1
self.used_mem_bytes: int = 0
self.logical_device_ids: List[int] = []
def __str__(self):
return str(self.partition_id)
def recalculate_mem_size(self):
self.used_mem_bytes = 0
for node in self.nodes:
self.used_mem_bytes += get_extra_size_of(node, self.nodes)
def add_node(self, node):
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Add the current node's input nodes if they are placeholders or constants
for n in input_nodes:
if n.op in {"placeholder", "get_attr"}:
self.nodes.add(n)
self.nodes.add(node)
self.recalculate_mem_size()
def remove_node(self, node):
# Remove a node only if the node is in the partition
if node in self.nodes:
self.nodes.remove(node)
# Collect the node's input nodes
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# If an input node is a placeholder or get_attr,
# and it is not used by any other node in this partition,
# then remove this input node as well
for input_node in input_nodes:
if all(
n not in self.nodes for n in input_node.users
) and input_node.op in {"placeholder", "get_attr"}:
self.nodes.remove(input_node)
self.recalculate_mem_size()
class Device(NamedTuple):
name: str
available_mem_bytes: int
logical_id: int
class NodeLatency(NamedTuple):
# Latency due to the memory bandwidth
mem_latency_sec: float
# Latency due to the computation
computer_latency_sec: float
class PartitionLatency(NamedTuple):
# Sum of all nodes' memory latency on the critical path
mem_latency_sec: float
# Sum of all nodes' compute latency on the critical path
computer_latency_sec: float
# Latency of the critical path
overall_latency_sec: float
class PartitionMode(Enum):
size_based = 0
sparse_nn = 1
cost_aware = 2
kl_based = 3
aot_based = 4
class PartitionerConfig(NamedTuple):
devices: List[Device]
mode: PartitionMode = PartitionMode.size_based
transfer_rate_bytes_per_sec: float = 0.0
node_to_latency_mapping: Dict[Node, NodeLatency] = {}
node_to_partition_mapping: Dict[Node, int] = {}
partition_to_logical_device_mapping: Dict[int, List[int]] = {}
# Saturate host by replicating partitions to the remaining idle devices.
saturate_host: bool = False
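# Hedged illustration (not part of the original file): constructing a PartitionerConfig
# for two hypothetical 1 GiB devices using the default size-based partitioning mode.
def _demo_partitioner_config() -> PartitionerConfig:
    devices = [
        Device(name="dev_0", available_mem_bytes=1024 ** 3, logical_id=0),
        Device(name="dev_1", available_mem_bytes=1024 ** 3, logical_id=1),
    ]
    return PartitionerConfig(devices=devices, mode=PartitionMode.size_based)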
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
"""Given a node and a set of nodes,
this function returns the extra size needed
if this node is included in this set.
"""
# Find all its input nodes
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Calculate total size of related nodes
total_size_of_input_nodes = 0
for n in input_nodes:
# Only count input nodes that are not already in the set
if n not in nodes:
size_bytes = getattr(n, "size_bytes", None)
if size_bytes:
total_size_of_input_nodes += size_bytes.output_size
else:
raise RuntimeError("node has no size_bytes attr")
# Don't forget the op node itself
size_bytes = getattr(node, "size_bytes", None)
if size_bytes:
total_size_of_input_nodes += size_bytes.total_size
else:
raise RuntimeError("node has no size_bytes attr")
return total_size_of_input_nodes
def get_latency_of_one_partition(
partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
"""Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
def get_top_nodes(partition: Partition) -> List[Node]:
"""Given a partition, return a list of nodes on the top bfs level"""
top_nodes: List[Node] = []
for node in partition.nodes:
# Skip placeholder and get_attr nodes
if node.op in {"placeholder", "get_attr"}:
continue
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# If a node has no input nodes in this partition,
# or all of its input nodes in this partition are placeholders and get_attrs,
# then this node is on the top bfs level of this partition
if not any(
n in partition.nodes and n.op not in {"placeholder", "get_attr"}
for n in input_nodes
):
top_nodes.append(node)
return top_nodes
def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
"""Given a top node of a partition, this function returns
the latency of the critical path in the partition
"""
node_latency = node_to_latency_mapping[node]
# Calculate the current overall latency of the partition
overall_latency_sec = partition_latency.overall_latency_sec + max(
node_latency.computer_latency_sec, node_latency.mem_latency_sec
)
# Update the mem latency of this path
mem_latency_sec = (
partition_latency.mem_latency_sec + node_latency.mem_latency_sec
)
# Update the compute latency of this path
computer_latency_sec = (
partition_latency.computer_latency_sec + node_latency.computer_latency_sec
)
# Get all users of this node that are in this partition
users = set(node.users).intersection(partition.nodes)
if users:
max_latency = PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
)
for n in users:
# Get new partition latency recursively
new_partition_latency = dfs_helper(
n,
PartitionLatency(
mem_latency_sec, computer_latency_sec, overall_latency_sec
),
)
if (
new_partition_latency.overall_latency_sec
> max_latency.overall_latency_sec
):
max_latency = new_partition_latency
return max_latency
# If there are no users, the node is at the bottom of the partition
return PartitionLatency(
mem_latency_sec, computer_latency_sec, overall_latency_sec
)
# Main part starts
# Get all top level nodes of this partition
top_nodes = get_top_nodes(partition)
critical_path_latency = PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
)
# Go through all top nodes and find the largest latency (critical path latency)
for node in top_nodes:
partition_latency = dfs_helper(
node,
PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
),
)
if (
partition_latency.overall_latency_sec
> critical_path_latency.overall_latency_sec
):
critical_path_latency = partition_latency
return critical_path_latency
def get_partition_to_latency_mapping(
partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
"""Given all the partitions and node_to_latency_mapping dictionary,
return a mapping dictionary of each partition to its overall latency
"""
partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
# Go through each partition and get its latency
for partition in partitions:
partition_latency = get_latency_of_one_partition(
partition, node_to_latency_mapping
)
partition_to_latency_mapping[partition] = partition_latency
return partition_to_latency_mapping
def get_comm_latency_between(
parent_partition: Partition,
child_partition: Partition,
transfer_rate_bytes_per_sec: float,
):
"""Given two partitions (parent and child),
calculate the communication latency between the two.
"""
# If two partitions are on the same device, the comm latency is 0.
if (
parent_partition.logical_device_ids != []
and child_partition.logical_device_ids != []
and parent_partition.logical_device_ids == child_partition.logical_device_ids
):
return 0.0
# Keep track of the communication size between parent and child
comm_size = 0
# Keep track of all nodes that have already been counted
visited_nodes = set()
# Go through all nodes in the child partition
# If a node has input nodes from the parent partition,
# the output size of those input nodes will be counted
# and added to comm_size
for node in child_partition.nodes:
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
for n in input_nodes:
if n in parent_partition.nodes and n not in visited_nodes:
size_bytes = getattr(n, "size_bytes", None)
if size_bytes is not None:
comm_size += size_bytes.output_size
visited_nodes.add(n)
return comm_size / transfer_rate_bytes_per_sec
def get_latency_of_partitioned_graph(
partitions: List[Partition],
partition_to_latency_mapping: Dict[Partition, PartitionLatency],
transfer_rate_bytes_per_sec: float,
):
"""Given all partitions in a graph, find the critical path among all partitions
and return its latency as the latency of the whole graph
"""
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
latency_so_far_sec += partition_to_latency_mapping[
partition
].overall_latency_sec
if partition.children:
max_latency_sec = 0.0
for child in partition.children:
# Calculate the communication latency between parent and child
comm_latency_sec = get_comm_latency_between(
partition, child, transfer_rate_bytes_per_sec
)
new_latency_sec = dfs_helper(
child, latency_so_far_sec + comm_latency_sec
)
if new_latency_sec > max_latency_sec:
max_latency_sec = new_latency_sec
return max_latency_sec
return latency_so_far_sec
def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
"""This function is to return all the partitions without parents
as the starting points of all the paths
"""
top_partitions = []
for partition in partitions:
# If a partition has no parents, then it is a top partition
if len(partition.parents) == 0:
top_partitions.append(partition)
return top_partitions
top_partitions = get_top_partitions(partitions)
critical_path_latency_sec = 0.0
for partition in top_partitions:
latency_sec = dfs_helper(partition, 0.0)
if latency_sec > critical_path_latency_sec:
critical_path_latency_sec = latency_sec
return critical_path_latency_sec
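# Hedged illustration (not part of the original file): how a single node's NodeLatency
# feeds the running PartitionLatency inside dfs_helper above. The numbers are made up.
def _demo_latency_accumulation() -> PartitionLatency:
    so_far = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )
    node = NodeLatency(mem_latency_sec=2.0, computer_latency_sec=5.0)
    # A node contributes max(compute, memory) latency to the critical path,
    # while the per-resource latencies are accumulated separately.
    return PartitionLatency(
        mem_latency_sec=so_far.mem_latency_sec + node.mem_latency_sec,
        computer_latency_sec=so_far.computer_latency_sec + node.computer_latency_sec,
        overall_latency_sec=so_far.overall_latency_sec
        + max(node.computer_latency_sec, node.mem_latency_sec),
    )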

File diff suppressed because it is too large

View File

@ -0,0 +1,512 @@
# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
log = logging.getLogger(__name__)
trace_shape_events_log = torch._logging.getArtifactLogger(
__name__, "trace_shape_events"
)
__all__ = [
"ShapeEnvEvent",
"record_shapeenv_event",
"replay_shape_env_events",
"FakeTensorMeta",
"shape_env_check_state_equal",
"NotEqualError",
]
# [Note: Recording ShapeEnv Events]
# =================================
#
# What is a ShapeEnv event?
# -------------------------
# A ShapeEnv event is any function call (ShapeEnv method or
# independent function) that modifies the state of the ShapeEnv instance.
# Such calls are recorded alongside their positional and keyword arguments,
# so that they may be replayed over a different ShapeEnv instance.
#
# See [Note: ShapeEnv State Equality] for what is considered the state
# of a ShapeEnv instance.
#
# What is it for?
# ---------------
# ShapeEnv event recording is used for reconstructing the ShapeEnv at an
# arbitrary point in time.
#
# Being able to arbitrarily replay events like so is useful, mainly for
# translation validation bisection. i.e. if a ValidationException has been
# raised, find the earliest point in time where the translation validation
# fails.
#
# Besides that, it also allows us to inspect the given instance and,
# for example, check the guards that would actually be issued at that point.
#
# What kind of arguments can be stored in an event?
# -------------------------------------------------
# There's no specific rule for what cannot be used as an argument.
# That said, pay special attention to the following cases:
#
# 1. Tensor inputs: there are some tests that check whether the inputs
# were garbage collected after execution. These will fail if there's
# an event that is holding a reference to those inputs.
#
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
# will be automatically replaced by the new given ShapeEnv instance.
#
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
# whenever we see them, we create a new instance, replacing the
# ShapeEnv reference.
#
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
# shapes. That argument must be replaced when replaying the event at
# ShapeEnvEvent.run, since it has to reference a node from the given
# instance, and not from the recorded instance.
# Event class for reconstructing ShapeEnv at an arbitrary point in time.
#
# Represents a method call that mutates ShapeEnv in a way that affects the
# issued guards, when ShapeEnv.produce_guards is called.
@dataclass
class ShapeEnvEvent:
# ShapeEnv method.
f: Callable
# Arguments and keyword arguments called with.
args: Optional[List[Any]] = None
kwargs: Optional[Dict[str, Any]] = None
# List of tracked_fakes at the time the method was called.
tracked_fakes: Optional[List[Any]] = None
# Name of the captured event.
# Used for special handling of particular methods.
name: Optional[str] = None
# Replay itself, but using shape_env as self.
def run(self, shape_env=None) -> Any:
from torch.fx.experimental.symbolic_shapes import (
is_symbolic,
ShapeEnv,
SymTypes,
)
# Special handling for the constructor event.
if self.f is ShapeEnv:
assert shape_env is None and self.args is None and self.kwargs is not None
return ShapeEnv(**self.kwargs)
assert shape_env is not None
args = list(self.args or [])
kwargs = dict(self.kwargs or {})
# Replace any argument of type ShapeEnv by the given one.
args, kwargs = pytree.tree_map_only(
ShapeEnv, lambda _: shape_env, (args, kwargs)
)
# Replace any argument of type SymTypes by a new instance,
# replacing its ShapeEnv reference.
args, kwargs = pytree.tree_map_only(
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
lambda a: type(a)(a.node.with_shape_env(shape_env)),
(args, kwargs),
)
# Converts FX nodes using the mapping argument.
def maybe_convert_node(x: Any) -> Any:
if not isinstance(x, torch.fx.Node):
# Don't do anything to x if it's not an FX node.
return x
# If, at some point, we created an FX node, it means that translation validation is on.
# It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
# we are tracking node names at shape_env.name_to_node.
assert hasattr(shape_env, "name_to_node")
name_to_node = shape_env.name_to_node # type: ignore[attr-defined]
assert x.name in name_to_node
return name_to_node[x.name]
# Replaces the value of a specific argument by the result of fn.
def replacearg(index: int, key: str, fn: Callable):
if index < len(args):
args[index] = fn(args[index])
if key in kwargs:
kwargs[key] = fn(kwargs[key])
if self.is_create_fx_call_function():
# ShapeEnv.create_fx_call_function:
# "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
# They must be replaced, since a "call_function" FX node with this tuple as argument
# will be added to the FX graph of the new shape_env.
replacearg(
index=2,
key="args",
fn=lambda args: tuple(maybe_convert_node(a) for a in args),
)
if self.is_evaluate_expr() or self.is_defer_runtime_assert():
# ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
# "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
# It must be replaced, since it will be part of a "call_function" FX node for
# torch._assert, which will be added to the FX graph of the new shape_env.
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
# Actually call the method with the converted arguments.
return self.f(*args, **kwargs)
def __str__(self) -> str:
name = self.name if self.name is not None else self.f.__name__
return f"event: {name} ({self.args}, {self.kwargs})"
def is_create_fx_call_function(self) -> bool:
return self.name == "_create_fx_call_function"
def is_evaluate_expr(self) -> bool:
return self.name == "evaluate_expr"
def is_defer_runtime_assert(self) -> bool:
return self.name == "defer_runtime_assert"
NEST = 0
# Extracts a ShapeEnv instance inside args and kwargs.
# Specifically, it looks for:
# 1. ShapeEnv arguments
# 2. SymInt, SymFloat, or SymBool arguments
# If we find more than one object of any of the above types, we
# also check that the ShapeEnv instance is the same for all of them.
def _extract_shape_env_and_assert_equal(args, kwargs):
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
if old is not None:
assert old is new, "call with different ShapeEnv"
return new
shape_env = None
for val in itertools.chain(args, kwargs.values()):
if isinstance(val, ShapeEnv):
shape_env = assert_equal(shape_env, val)
if isinstance(val, SymTypes) and is_symbolic(val):
shape_env = assert_equal(shape_env, val.node.shape_env)
return shape_env
# Decorator for recording the given function as a replayable event.
#
# This decorator should be used on every function that mutates the state of
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
# ShapeEnv.produce_guards is called).
#
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
#
# When to save the list of TrackedFake?
# =====================================
# We should save the list of TrackedFake whenever the translation validation
# bisection may actually stop and call the produce_guards method at the moment
# right after the recorded function was played. In other words, since the
# bisection bisects through torch._assert calls, we should save in all methods
# that add a torch._assert call to the symbolic shapes FX graph.
#
# At the moment, there are 2 methods that save the list:
# - ShapeEnv.evaluate_expr
# - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
def decorator(fn: Callable) -> Callable:
assert callable(fn)
args = inspect.getfullargspec(fn).args
assert args and args[0] == "self", (
"record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
"code so that it calls into a method on ShapeEnv"
)
name = fn.__name__
@functools.wraps(fn)
def wrapper(*args, **kwargs):
from torch.fx.experimental.symbolic_shapes import ShapeEnv
assert isinstance(args[0], ShapeEnv)
global NEST
trace_shape_events_log.debug(
"%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
)
NEST += 1
def retlog(r):
trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
return r
try:
if args[0].is_recording: # type: ignore[has-type]
# If ShapeEnv is already recording an event, call the wrapped
# function directly.
#
# NB: here, we skip the check of whether all ShapeEnv instances
# are equal, in favor of a faster dispatch.
return retlog(fn(*args, **kwargs))
# Retrieve an instance of ShapeEnv.
# Assumption: the collection of args and kwargs may not reference
# different ShapeEnv instances.
self = _extract_shape_env_and_assert_equal(args, kwargs)
# If we are calling this function without any ShapeEnv instance
# alive in its arguments, we don't record and call the original.
if self is None:
return retlog(fn(*args, **kwargs))
# Otherwise, start recording and call the function.
with self._recording():
# Take a snapshot of the current tracked_fakes.
tracked_fakes = (
self._snapshot_tracked_fakes() if save_tracked_fakes else None
)
# Record the event for 'fn'.
event = ShapeEnvEvent(
fn, list(args), kwargs, tracked_fakes, name=fn.__name__
)
# Play the event on this ShapeEnv.
# NB: It's important to put the event first, because running
# the event can trigger internal events that must be ordered
# after this event. However, if an exception happens, we do
# NOT want to have the event in the list, so pop it off from
# the record if an error happened
self.events.append(event)
try:
return retlog(event.run(self))
except Exception:
self.events.pop()
raise
except Exception:
log.error( # noqa: G201
"failed while running %s(*%s, **%s)",
name,
args[1:],
kwargs,
exc_info=log.isEnabledFor(logging.INFO),
)
raise
finally:
NEST -= 1
return wrapper
return decorator
# Replays the ShapeEnvEvents list.
# It assumes the first event is the constructor call.
#
# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv.
def replay_shape_env_events(events):
from torch.fx.experimental.symbolic_shapes import ShapeEnv
constructor_event = events[0]
assert constructor_event.f == ShapeEnv
# Constructs the new ShapeEnv.
shape_env = constructor_event.run()
for event in events[1:]:
try:
# Actually replays each event.
# We need to call create_mapping_fn every time, since the node list might
# change after each event is replayed.
event.run(shape_env)
except Exception:
log.error("failed when running event: %s", event)
raise
return shape_env
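# Hedged illustration (not part of the original file): the smallest possible event list
# is a single constructor event, and replaying it simply builds a fresh ShapeEnv.
# Real event lists come from a ShapeEnv created with event recording enabled.
def _demo_replay_constructor_only():
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    events = [ShapeEnvEvent(ShapeEnv, kwargs={})]
    return replay_shape_env_events(events)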
# FakeTensor metadata.
# This is to be used in place of FakeTensor placeholders when calling
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
tensor_size: Tuple[Union[int, torch.SymInt], ...]
tensor_stride: Tuple[Union[int, torch.SymInt], ...]
tensor_storage_offset: Union[int, torch.SymInt]
is_nested: bool
def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
return self.tensor_size
def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
return self.tensor_stride
def storage_offset(self) -> Union[int, torch.SymInt]:
return self.tensor_storage_offset
def dim(self) -> int:
return len(self.tensor_size)
@staticmethod
def from_fake(fake) -> "FakeTensorMeta":
return FakeTensorMeta(
fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
)
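# Hedged illustration (not part of the original file): FakeTensorMeta can also be built
# by hand with plain ints when no FakeTensor is available.
def _demo_fake_tensor_meta() -> FakeTensorMeta:
    meta = FakeTensorMeta(
        tensor_size=(2, 3),
        tensor_stride=(3, 1),
        tensor_storage_offset=0,
        is_nested=False,
    )
    assert meta.dim() == 2 and meta.size() == (2, 3)
    return meta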
# [Note: ShapeEnv State Equality]
# ===============================
#
# What is considered ShapeEnv state?
# ----------------------------------
# We consider the state of a ShapeEnv instance to be everything that
# is not in the inline tuple inside the remove_nonstate_variables function.
# That is: the fields within ShapeEnv that modify the flow of execution
# of the program.
#
# So, for example: the replacements field might influence how an
# expression is simplified. That, in turn, may result in a guard being
# statically known (i.e. not added).
#
# On the other hand, var_to_stack only changes what is printed
# to the screen, i.e. it is used only for debugging purposes. Therefore, we
# should not consider it when comparing states.
#
# What to do on NotEqualError?
# ----------------------------
# Here are a few possible causes for getting a NotEqualError raised:
#
# 1. New field that does not belong in the ShapeEnv state.
# For example: log field of type ShapeEnvLoggerAdapter. Different
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
# instances, i.e. equality comparison would fail.
# Solution: add it to the inlined tuple inside remove_nonstate_variables
# function inside check_equal method.
#
# 2. New field that is not directly comparable across instances.
# For example: guards field of type List[ShapeGuard]. More specifically,
# the ShapeGuard type holds an expression and stack information
# for debugging purposes. When replaying the event on a new ShapeEnv
# instance, the stack would be different, which would trigger this error.
# Solution: add a special case to the map_value function inside
# check_equal function.
#
# 3. Mutation of ShapeEnv on some not recorded function.
# If a mutation of the state of ShapeEnv happens inside a function
# that is not recorded (or that no caller in the stack is recorded),
# then, the replayed ShapeEnv won't catch that.
# Solution: decorate the function with record_shapeenv_event.
# Checks whether the states of two ShapeEnvs are equal w.r.t. the guards
# returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
# Collect and remove variables that don't necessarily represent the state
# of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
# instance itself.
env1_vars = vars(env1).copy()
env2_vars = vars(env2).copy()
for v in non_state_variable_names:
if v in env1_vars:
env1_vars.pop(v)
if v in env2_vars:
env2_vars.pop(v)
# Function for transforming the mismatched values into strings.
# Needed, since the order of dict and set entries might not be the same every time.
def value_to_str(value: Any) -> str:
if isinstance(value, dict):
return (
"{"
+ ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
+ "}"
)
if isinstance(value, set):
return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
return str(value)
# Compares env1_vars with env2_vars.
# Here, we allow the value of each field to be mapped, so that we appropriately
# compare the two values.
def compare_vars(
map_value: Callable[[str, Any], Any]
) -> List[Tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars)
# First, compare the set of keys in each vars dictionary.
if env1_set != env2_set:
raise NotEqualError(
"field set mismatch:",
[
(
"found unique fields:",
str(sorted(env1_set - env2_set)),
str(sorted(env2_set - env1_set)),
),
],
)
# Then, sort the keys, and compare the mapped values of each key.
sorted_keys = list(env1_set)
sorted_keys.sort()
mapped_dict = [
(k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
for k in sorted_keys
]
# Return a list of tuples representing the fields that did not match
# alongside their respective mapped values.
return [
(f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
for k, val1, val2 in mapped_dict
if val1 != val2
]
# Accumulate the mismatching fields.
errors = compare_vars(map_value)
if len(errors) > 0:
raise NotEqualError("field values don't match:", errors)
class NotEqualError(Exception):
def __init__(
self,
msg: str,
mismatched: List[Tuple[str, str, str]],
) -> None:
details = "\n".join(
[
"\n".join(
[
f"==> {inner_msg}",
f" > Left: {str1}",
f" > Right: {str2}",
]
)
for inner_msg, str1, str2 in mismatched
]
)
super().__init__(
f"""\
ShapeEnv not equal: {msg}
{details}
"""
)

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
class Equality:
def __init__(self, lhs, rhs):
self.lhs = lhs
self.rhs = rhs
def __str__(self):
return f'{self.lhs} = {self.rhs}'
def __repr__(self):
return f'{self.lhs} = {self.rhs}'
def __eq__(self, other):
if isinstance(other, Equality):
return self.lhs == other.lhs and self.rhs == other.rhs
else:
return False

View File

@ -0,0 +1,128 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import copy
import functools
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):
"""
Take a FunctionType object representing a `forward` method, then
perform an AST rewrite to swap out nodes that are not symbolically
traceable with a callsite to the FX alternative.
To support swapping out an AST node, define a new `visit` method on
that node. For more details, see:
https://docs.python.org/3/library/ast.html#ast.NodeTransformer
"""
# This function checks for new keys added in the globals dict. TorchDynamo
# can insert new keys in the global dict and upset the check. Therefore, put
# a disable here. This function is an optimization pass and not really
# suitable for dynamo tracing anyway.
@torch._dynamo.disable
def rewrite(self, fn: FunctionType):
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
normalized_str = textwrap.dedent(source)
# Rewrite the original AST
source_ast = ast.parse(normalized_str)
dest_ast = ast.fix_missing_locations(self.visit(source_ast))
# Pull out the compiled function from the newly-created Module
code = compile(dest_ast, "", "exec")
globals_dict = copy.copy(fn.__globals__)
keys_before = set(globals_dict.keys())
exec(code, globals_dict)
new_keys = list(set(globals_dict.keys()) - keys_before)
assert len(new_keys) == 1
fn_compiled = globals_dict[new_keys[0]]
# return the compiled function with the original globals
def change_func_globals(f, globals):
"""Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
# __globals__ is a private member of the function class
# so we have to copy the function f and all of its members, except f.__globals__
g = FunctionType(
f.__code__,
globals,
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__,
)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
return g
# Return the correct FunctionType object
return change_func_globals(fn_compiled, globals=fn.__globals__)
def visit_Assert(self, node):
"""
Swap out the Assert node (Python's `assert`) with a callsite to the
symbolically-traceable torch._assert function
"""
# Create the Call node
n = ast.parse('torch._assert()', mode='eval')
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
msg = node.msg if node.msg else ast.Constant(value="", kind=None)
call_node.args = [node.test, msg]
# Ensure that the new node conforms to the Python AST grammar
expr_wrapper = ast.Expr(value=call_node)
# Return the new Call node to signify that we want to use it as
# a replacement for the original Assert node
return ast.copy_location(expr_wrapper, node)
def visit_AnnAssign(self, node):
"""
Swap out Python's AnnAssign with an Assign node where the annotation function is called.
Example:
Original:
y: Tensor_Type(1,2,3, Dyn) = f2(x)
Output:
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
"""
return ast.Assign(targets=[node.target], value=ast.Call(
func=ast.Name(id='annotate', ctx=ast.Load()),
args=[node.value, node.annotation], keywords=[]))
class RewritingTracer(Tracer):
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
return super().trace(_rewrite(root), concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
if isinstance(fn, torch.nn.Module):
# Rewrite this module's `forward` as well as the `forward`s of
# all of this module's recursive descendants. Return the new,
# rewritten module hierarchy.
def rewrite_module(m : torch.nn.Module):
class RewrittenModule(torch.nn.Module):
def __init__(self, orig):
super().__init__()
for k, v in orig.__dict__.items():
if isinstance(v, torch.nn.Module):
self.__dict__[k] = copy.copy(rewrite_module(v))
else:
self.__dict__[k] = copy.copy(v)
RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
return RewrittenModule(m)
return rewrite_module(fn)
else:
# Rewrite this single free function
return AST_Rewriter().rewrite(cast(FunctionType, fn))
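# Hedged usage sketch (not part of the original file): tracing a module whose forward
# contains a plain Python `assert`; the AST rewriter above turns it into a traceable
# torch._assert call before symbolic tracing. The module name is hypothetical.
def _demo_rewriting_tracer() -> Graph:
    class WithAssert(torch.nn.Module):
        def forward(self, x):
            assert x.shape[0] > 0, "batch must be non-empty"
            return x + 1

    return RewritingTracer().trace(WithAssert())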

View File

@ -0,0 +1,112 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target
from torch._jit_internal import boolean_dispatched
from torch.fx.operator_schemas import _torchscript_type_to_python_type
from torch.fx import Transformer
class AnnotateTypesWithSchema(Transformer):
"""
Use Python function signatures to annotate types for `Nodes` within an FX graph.
This pulls out Python function signatures for:
1. Standard `torch.nn` Module calls
2. `torch.nn.functional` calls
3. Attribute fetches via `get_attr`
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = AnnotateTypesWithSchema(traced).transform()
"""
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
annotate_modules : bool = True, annotate_get_attrs : bool = True):
super().__init__(module)
self.annotate_functionals = annotate_functionals
self.annotate_modules = annotate_modules
self.annotate_get_attrs = annotate_get_attrs
def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
# a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
# branches of the dispatch have exactly the same signature. If they do, use the `true`
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
# TODO: can we emit the union of these? What are the implications on TorchScript
# compilation?
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
return super().call_function(target, args, kwargs)
target_for_analysis = if_true
python_ret_type = self._extract_python_return_type(target_for_analysis)
return_proxy = super().call_function(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
assert isinstance(target, str)
submod = self.fetch_attr(target)
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
python_ret_type = self._extract_python_return_type(submod.forward)
return_proxy = super().call_module(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
attr_proxy = super().get_attr(target, args, kwargs)
if self.annotate_get_attrs:
module_itr = self.module
assert isinstance(target, str)
atoms = target.split('.')
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
raise RuntimeError(f'Node referenced nonexistent target {".".join(atoms[:i])}!')
module_itr = getattr(module_itr, atom)
maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
if maybe_inferred_ts_type.success():
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
return attr_proxy
def _extract_python_return_type(self, target : Target) -> Optional[Any]:
"""
Given a Python call target, try to extract the Python return annotation
if it is available, otherwise return None
Args:
target (Callable): Python callable to get return annotation for
Returns:
Optional[Any]: Return annotation from the `target`, or None if it was
not available.
"""
assert callable(target)
try:
sig = inspect.signature(target)
except (ValueError, TypeError):
return None
return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
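# Hedged usage sketch (not part of the original file): annotating a small hand-written
# module instead of the torchvision example from the docstring above.
def _demo_annotate_types():
    import torch.nn.functional as F

    class Small(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return F.relu(self.linear(x))

    traced = torch.fx.symbolic_trace(Small())
    # Nodes for the nn.Linear call and the F.relu call pick up return-type
    # annotations wherever a Python signature provides one.
    return AnnotateTypesWithSchema(traced).transform()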

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,4 @@
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403

View File

@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
from collections.abc import Iterator # type: ignore[import]
from functools import partial
from .unification_tools import assoc # type: ignore[import]
from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch
__all__ = ["reify", "unify"]
###############
# Reification #
###############
@dispatch(Iterator, dict)
def _reify(t, s):
return map(partial(reify, s=s), t)
# return (reify(arg, s) for arg in t)
_reify
@dispatch(tuple, dict) # type: ignore[no-redef]
def _reify(t, s):
return tuple(reify(iter(t), s))
_reify
@dispatch(list, dict) # type: ignore[no-redef]
def _reify(t, s):
return list(reify(iter(t), s))
_reify
@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return {k: reify(v, s) for k, v in d.items()}
_reify
@dispatch(object, dict) # type: ignore[no-redef]
def _reify(o, s):
return o # catch all, just return the object
def reify(e, s):
""" Replace variables of expression with substitution
>>> # xdoctest: +SKIP
>>> x, y = var(), var()
>>> e = (1, x, (3, y))
>>> s = {x: 2, y: 4}
>>> reify(e, s)
(1, 2, (3, 4))
>>> e = {1: x, 3: (y, 5)}
>>> reify(e, s)
{1: 2, 3: (4, 5)}
"""
if isvar(e):
return reify(s[e], s) if e in s else e
return _reify(e, s)
###############
# Unification #
###############
seq = tuple, list, Iterator
@dispatch(seq, seq, dict)
def _unify(u, v, s):
if len(u) != len(v):
return False
for uu, vv in zip(u, v): # avoiding recursion
s = unify(uu, vv, s)
if s is False:
return False
return s
#
# @dispatch((set, frozenset), (set, frozenset), dict)
# def _unify(u, v, s):
# i = u & v
# u = u - i
# v = v - i
# return _unify(sorted(u), sorted(v), s)
#
#
# @dispatch(dict, dict, dict)
# def _unify(u, v, s):
# if len(u) != len(v):
# return False
# for key, uval in iteritems(u):
# if key not in v:
# return False
# s = unify(uval, v[key], s)
# if s is False:
# return False
# return s
#
#
# @dispatch(object, object, dict)
# def _unify(u, v, s):
# return False # catch all
@dispatch(object, object, dict)
def unify(u, v, s): # no check at the moment
""" Find substitution so that u == v while satisfying s
>>> x = var('x')
>>> unify((1, x), (1, 2), {})
{~x: 2}
"""
u = walk(u, s)
v = walk(v, s)
if u == v:
return s
if isvar(u):
return assoc(s, u, v)
if isvar(v):
return assoc(s, v, u)
return _unify(u, v, s)
unify
@dispatch(object, object) # type: ignore[no-redef]
def unify(u, v):
return unify(u, v, {})
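# Hedged illustration (not part of the original file): the docstring examples above can
# be run as ordinary code, assuming this package is importable as
# torch.fx.experimental.unification (which re-exports `var`, `unify`, and `reify`).
def _demo_unify_and_reify():
    from torch.fx.experimental.unification import var

    x, y = var("x"), var("y")
    s = unify((1, x, (3, y)), (1, 2, (3, 4)), {})  # -> {~x: 2, ~y: 4}
    return reify((x, y), s)                        # -> (2, 4)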

View File

@ -0,0 +1,6 @@
from functools import partial
from .multipledispatch import dispatch # type: ignore[import]
namespace = {} # type: ignore[var-annotated]
dispatch = partial(dispatch, namespace=namespace)

View File

@ -0,0 +1,122 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
class Dispatcher:
def __init__(self, name):
self.name = name
self.funcs = {}
self.ordering = []
def add(self, signature, func):
self.funcs[freeze(signature)] = func
self.ordering = ordering(self.funcs)
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
return func(*args, **kwargs)
def resolve(self, args):
n = len(args)
for signature in self.ordering:
if len(signature) != n:
continue
s = unify(freeze(args), signature)
if s is not False:
result = self.funcs[signature]
return result, s
raise NotImplementedError("No match found. \nKnown matches: "
+ str(self.ordering) + "\nInput: " + str(args))
def register(self, *signature):
def _(func):
self.add(signature, func)
return self
return _
class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names
>>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x')
>>> @d.register('inc', x)
... def f(x):
... return x + 1
>>> @d.register('double', x)
... def f(x):
... return x * 2
>>> d('inc', 10)
11
>>> d('double', 10)
20
"""
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = {k.token: v for k, v in s.items()}
return func(**d)
global_namespace = {} # type: ignore[var-annotated]
def match(*signature, **kwargs):
namespace = kwargs.get('namespace', global_namespace)
dispatcher = kwargs.get('Dispatcher', Dispatcher)
def _(func):
name = func.__name__
if name not in namespace:
namespace[name] = dispatcher(name)
d = namespace[name]
d.add(signature, func)
return d
return _
def supercedes(a, b):
""" ``a`` is a more specific match than ``b`` """
if isvar(b) and not isvar(a):
return True
s = unify(a, b)
if s is False:
return False
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
if reify(a, s) == a:
return True
if reify(b, s) == b:
return False
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
if supercedes(b, a):
return tie_breaker(a) > tie_breaker(b)
else:
return True
return False
# Taken from multipledispatch
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(first, edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
return _toposort(edges)
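# Hedged illustration (not part of the original file): `supercedes` is what `edge` and
# `ordering` above use to try more specific signatures first. `var` is assumed to be
# re-exported by torch.fx.experimental.unification.
def _demo_supercedes():
    from torch.fx.experimental.unification import var

    x = var("x")
    assert supercedes((1, 2), (1, x))      # a concrete signature beats a variable one
    assert not supercedes((1, x), (1, 2))  # and not the other way around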

View File

@ -0,0 +1,118 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
>>> # xdoctest: +SKIP
>>> class A(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var('x')
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
{~x: 2}
"""
_unify.add((cls, cls, dict), unify_object)
_reify.add((cls, dict), reify_object)
return cls
#########
# Reify #
#########
def reify_object(o, s):
""" Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> print(f)
Foo(1, ~x)
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
return obj
def _reify_object_slots(o, s):
attrs = [getattr(o, attr) for attr in o.__slots__]
new_attrs = reify(attrs, s)
if attrs == new_attrs:
return o
else:
newobj = object.__new__(type(o))
for slot, attr in zip(o.__slots__, new_attrs):
setattr(newobj, slot, attr)
return newobj
@dispatch(slice, dict)
def _reify(o, s):
""" Reify a Python ``slice`` object """
return slice(*reify((o.start, o.stop, o.step), s))
#########
# Unify #
#########
def unify_object(u, v, s):
""" Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
{~x: 2}
"""
if type(u) != type(v):
return False
if hasattr(u, '__slots__'):
return unify([getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s)
else:
return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
""" Unify a Python ``slice`` object """
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)
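# Hedged illustration (not part of the original file): registering a small user class
# with `unifiable`. The class uses __slots__ so that unification and reification go
# through the slot-based helpers above; `var` is assumed to be re-exported by
# torch.fx.experimental.unification.
def _demo_unifiable():
    from torch.fx.experimental.unification import var

    @unifiable
    class Pair:
        __slots__ = ("a", "b")

        def __init__(self, a, b):
            self.a = a
            self.b = b

    x = var("x")
    s = unify(Pair(1, x), Pair(1, 2), {})  # -> {~x: 2}
    return reify(Pair(x, x), s)            # Pair with both slots reified to 2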

View File

@ -0,0 +1,3 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)

Some files were not shown because too many files have changed in this diff