I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,27 @@
import torch.fx
class BackwardState:
"""
BackwardState is used to pass Python hooks from the forwards pass
into the backwards pass in Dynamo+Compiled Autograd.
It is created by TorchDynamo and has special handling there.
Dynamo will pass an empty BackwardState to the forwards, then populate
members on it (via setattr) only after the forwards graph is finished.
Later on, in Compiled Autograd we will inline and add the needed guards
on the BackwardState.
BackwardState is identified and has special handling in AOTAutograd.
During AOTAutograd:
1) BackwardState is an input to the forwards graph
2) It must only be used in the backwards
3) It will be empty in the forwards
4) In the forwards we add a wrapper to save it
5) In the backwards it becomes an input
6) There can only be one per graph
BackwardState requires CompiledAutograd.
"""
proxy: torch.fx.Proxy
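
A minimal sketch of the lifecycle the docstring describes, outside of any compiler; the hook name below is hypothetical, and in practice TorchDynamo performs these steps itself:

bw_state = BackwardState()            # empty during the forwards pass

def scale_grad_hook(grad):            # a Python hook destined for the backwards pass
    return grad * 2

# Dynamo populates the object via setattr only after the forwards graph is finished.
setattr(bw_state, "hook_0", scale_grad_hook)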


@@ -0,0 +1,88 @@
import os
import sys
from typing import Optional
# [@compile_ignored: debug] Uses z3 for validating the guard optimization transformations.
translation_validation = (
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
)
# Timeout (in milliseconds) for z3 finding a solution.
# [@compile_ignored: debug]
translation_validation_timeout = int(
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
)
# Disables bisection for translation validation.
#
# Translation validation bisection is enabled by default if translation validation
# is also enabled. It should help find guard simplification issues. However,
# since validation uses Z3 for bisecting, it might take a lot of time.
#
# Set this configuration option to avoid bisecting.
# [@compile_ignored: debug]
translation_validation_no_bisect = (
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
)
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
# a ShapeEnv with the same state. This should be used only in testing.
check_shape_env_recorded_events = False
# TODO: Perhaps consider allowing unions for the configs below (so you can hit
# multiple reps at the same time)
# Give extended debug information if the string representation of a guard
# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue
# this guard, we will generate full Python and C++ backtrace
# [@compile_ignored: debug]
extended_debug_guard_added = os.environ.get(
"TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
)
# Give extended debug information when a particular symbol is allocated. For
# example, set this to "u2" and whenever we create this symbol, we will
# generate full Python and C++ backtrace
# [@compile_ignored: debug]
extended_debug_create_symbol = os.environ.get(
"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
)
# Give extended debug information (C++ backtrace) for all extended debug
# settings as well as errors. The C++ backtrace is slow and very spammy so we
# don't include it by default even when you're requesting extended debug.
# [@compile_ignored: debug]
extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
# Give extended debug information (line of code) when a torch function
# is called during export. This is useful for showing progress and detecting
# where export might be stuck. Currently only works for strict=False.
# [@compile_ignored: debug]
extended_debug_current_loc = (
os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
)
# [@compile_ignored: debug] Show a warning for every specialization
print_specializations = False
# Testing only: wraps (un)equalities with the 'Not' class after recording the correct
# expression in the FX graph. This deliberately corrupts the divisible and replacement
# lists and causes incorrect guards to be issued.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
validate_shape_env_version_key = False
# If we produce more than this many guards on a single symbol, force that
# symbol to get specialized and bail out. This check may be slightly more
# aggressive than the true number of guards issued, since we test whether
# we have hit the limit on-the-fly, whereas further simplifications at final
# guard issuance time may make some guards irrelevant.
symbol_guard_limit_before_specialize: Optional[int] = None
# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
use_duck_shape = True
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
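
A brief, hedged sketch of how these knobs are typically toggled: each value is read from its environment variable when this config module is first imported, so the variables are set before importing torch (or exported in the shell ahead of time). The guard string and timeout below are only example values.

import os

os.environ["TORCHDYNAMO_TRANSLATION_VALIDATION"] = "1"              # enable z3 validation
os.environ["TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT"] = "60000"  # 60s z3 timeout
os.environ["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = "Ne(s0, 10)" # backtrace for this guard

import torch  # the values above are picked up when the config module is imported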

File diff suppressed because it is too large


@@ -0,0 +1,292 @@
# mypy: allow-untyped-defs
import re
from typing import Callable, Dict, Optional, Set, Union
import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module
__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
class FoldedGraphModule(torch.fx.GraphModule):
"""
FoldedGraphModule is a GraphModule which also contains another
`const_subgraph_module` representing a subgraph which has all const attr
inputs and which can be run once before running the main standard
`graph`. After folding, the outputs of that constant subgraph are stored on
the module under the attribute named by `fx_const_folded_attrs_name`.
"""
def __init__(
self,
root: torch.nn.Module,
graph: torch.fx.Graph,
const_subgraph: Optional[torch.fx.Graph] = None,
fx_const_folded_attrs_name: Optional[str] = None,
device_for_folded_attrs: str = "cuda",
):
super().__init__(root, graph)
self.const_subgraph_module = (
None
if const_subgraph is None
else torch.fx.GraphModule(root, const_subgraph)
)
self.has_folding_been_run = False
self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
self.device_for_folded_attrs = device_for_folded_attrs
def __call__(self, *args, **kwargs):
if not self.has_folding_been_run:
self.run_folding()
return super().__call__(*args)
def run_folding(self):
# If there's no const subgraph module or attr output names to use, return
# early as there is no const folding to perform.
if (
self.const_subgraph_module is None
or self.fx_const_folded_attrs_name is None
):
return
assert not self.has_folding_been_run
self.has_folding_been_run = True
# Actually run const folding subgraph. Note that single attr const fold
# subgraphs output a single Tensor while multiple outputs are returned as
# Tuple[Tensor,].
folded_attrs = self.const_subgraph_module()
def _create_param(i):
return torch.nn.Parameter(
i.detach().clone()
if not isinstance(i, int)
else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
)
params = (
torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
if isinstance(folded_attrs, tuple)
else _create_param(folded_attrs)
)
setattr(self, self.fx_const_folded_attrs_name, params)
def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
"""
Given `gm` and some graph module which is called with target name `inline_mod_name`,
this helper will inline all of the nodes from that called graph module into `gm`.
"""
# Fetch the inner graph module that we want to inline inside `gm`.
inline_mod = dict(gm.named_modules())[inline_mod_name]
assert isinstance(inline_mod, torch.fx.GraphModule)
call_mod_node_to_replace = None
for node in gm.graph.nodes:
if node.op == "call_module" and node.target == inline_mod_name:
call_mod_node_to_replace = node
break
assert call_mod_node_to_replace is not None
# Now actually do the swap. Note that we have to keep track of new nodes that are
# copied into `gm` -- we do this via replacement_mapping.
call_mod_args = call_mod_node_to_replace.args
replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
ph_count = 0
def replacement_fn(node):
new_node = replacement_mapping[node]
new_node.meta = node.meta.copy()
return new_node
for inline_node in inline_mod.graph.nodes:
if inline_node.op == "placeholder":
replacement_mapping[inline_node] = call_mod_args[ph_count]
ph_count += 1
continue
if inline_node.op == "output":
outputs = inline_node.args[0]
output_replacements = map_arg(outputs, replacement_fn)
call_mod_node_to_replace.replace_all_uses_with(output_replacements)
continue
with gm.graph.inserting_before(call_mod_node_to_replace):
new_node = gm.graph.node_copy(inline_node, replacement_fn)
replacement_mapping[inline_node] = new_node
gm.graph.eliminate_dead_code()
def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
"""
Make sure the name is unique (in a module) and can represent an attr.
"""
# Delete all characters that are illegal in a Python identifier.
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = f"_{name}"
# Now make sure it is in fact unique to the module by incrementing suffix value.
while hasattr(mod_traced, name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
else:
base, num = match.group(1, 2)
name = f"{base}_{int(num) + 1}"
return name
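# Illustrative sketch of how the helper above sanitizes and uniquifies attribute
# names; the module and names are assumptions:
#
#   gm = torch.fx.symbolic_trace(torch.nn.ReLU())
#   get_unique_attr_name_in_module(gm, "my.attr-1")    # -> "my_attr_1"
#   setattr(gm, "my_attr_1", torch.nn.Parameter(torch.zeros(1)))
#   get_unique_attr_name_in_module(gm, "my.attr-1")    # -> "my_attr_2"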
def split_const_subgraphs(
module: Union[torch.nn.Module, torch.fx.GraphModule],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
"""
Looks through `module` for any nodes that have all constant attribute inputs
and separates them out into their own constant subgraph, and returns a
FoldedGraphModule which runs that constant subgraph on the first run to set
attributes on the module prior to running the non-constant portion of the
graph.
"""
if not isinstance(module, torch.fx.GraphModule):
mod_traced = torch.fx.symbolic_trace(module)
else:
mod_traced = module
# Build up a list of const_nodes, defined as nodes that are themselves
# get_attrs, or have all get_attr or other constant node inputs.
const_nodes: Set[torch.fx.Node] = set()
found_const_folding = False
for node in mod_traced.graph.nodes:
# Skip over placeholders/outputs because they can't be const folded and
# we don't want to add tags to them.
if node.op in {"placeholder", "output"}:
continue
# If the node itself is constant, or all of its inputs are constant,
# then tag it as constant.
if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
const_nodes
):
continue
# If provided skip folding function says to skip, then skip.
if skip_folding_node_fn and skip_folding_node_fn(node):
continue
# Skip folding side-effectful functions
if node.is_impure():
continue
# Must be a constant foldable node at this point.
const_nodes.add(node)
if node.op != "get_attr":
found_const_folding = True
# If we did not find any const folding then return early without a const fold subgraph.
if not found_const_folding:
return FoldedGraphModule(mod_traced, mod_traced.graph)
# Partition the module into two: submod_0 for constant folding subgraph, and
# submod_1 for the rest.
def mod_partition(node: torch.fx.Node):
return 0 if node in const_nodes else 1
split = split_module(mod_traced, module, mod_partition)
const_mod_name, non_const_mod_name = "submod_0", "submod_1"
# Safely get submod_1 in case there are no non-const nodes
const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)
# The module that a call_module node refers to gets copied to submodules during split.
# The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
# attach inlined modules to `split` as it's the owning module now.
for node in non_const_gm.graph.nodes if non_const_gm else []:
if node.op == "call_module":
setattr(split, node.target, getattr(non_const_gm, node.target))
for node in const_gm.graph.nodes:
if node.op == "call_module":
setattr(split, node.target, getattr(const_gm, node.target))
# split_module currently does not use get_attrs for attrs. Instead it passes
# them in as args from the parent module, which used get_attrs. Here we set
# them as get_attrs inside const_gm, allowing folding to run without needing
# to know a priori which attrs should be passed as args. We can
# unconditionally do this for all placeholders because we know all
# placeholders to const_gm must be constants accessible via get_attr.
call_const_gm_args = None
for node in split.graph.nodes:
if node.op == "call_module":
if node.target == const_mod_name:
call_const_gm_args = node.args
break
assert call_const_gm_args is not None
# Here we do the actual replacement of placeholders to get_attrs. Note that here we
# set the const_gm.graph into a new root_const_gm with split as the root module,
# because we are fetching attributes directly from the root module, instead of
# fetching them from const_gm. Example: The const_gm must have some format like:
# graph():
# %inp : [num_users=1] = placeholder[target=const_inp]
# %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
# return add
# We replace that with the following, which does not have any placeholders:
# graph():
# %inp_1 : [num_users=1] = get_attr[target=const_inp]
# %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
# return add
root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
for node in root_const_gm.graph.nodes:
if node.op == "output":
multiple_outputs = isinstance(node.args[0], tuple)
continue
if node.op != "placeholder":
continue
in_node = next(n for n in call_const_gm_args if n.name == node.target)
assert in_node.op == "get_attr"
with root_const_gm.graph.inserting_before(node):
new_node = root_const_gm.graph.get_attr(in_node.target)
new_node.meta = node.meta.copy()
node.replace_all_uses_with(new_node)
root_const_gm.graph.erase_node(node)
assert "multiple_outputs" in locals()
# Now find the call to const_gm inside split, and replace it with a getattr to the
# folded tensor(s) that result from constant folding. Note that we don't need to
# worry about whether this is one or more tensors because the original graph
# correctly uses getitem to extract individual tensors if there are multiple folded.
fx_const_folded_attrs_name = get_unique_attr_name_in_module(
mod_traced, "_FX_CONST_FOLDED_ATTRS"
)
setattr(
split,
fx_const_folded_attrs_name,
torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined]
)
for node in split.graph.nodes:
if node.op == "call_module" and node.target == const_mod_name:
with node.graph.inserting_before(node):
folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
folded_attrs.meta = node.meta.copy()
node.replace_all_uses_with(folded_attrs)
break
# Finally, inline the non-constant submod (if it exists) into the split submod.
# This is so that the original caller who may have passed in a graph module will
# get back out a graph module whose graph is traced to the same granularity.
if hasattr(split, non_const_mod_name):
_inline_module(split, non_const_mod_name)
split.graph.eliminate_dead_code()
return FoldedGraphModule(
split,
split.graph,
root_const_gm.graph,
fx_const_folded_attrs_name,
device_for_folded_attrs,
)
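
A minimal usage sketch of split_const_subgraphs; the module, shapes, and parameter name below are illustrative, not part of the API:

import torch
from torch.fx.experimental.const_fold import split_const_subgraphs

class MyMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # `self.weight + 1` has only constant (get_attr) inputs, so it gets folded.
        return x @ (self.weight + 1)

folded = split_const_subgraphs(MyMod())
out = folded(torch.randn(2, 4))  # the first call runs the const subgraph once

The folded result is stored on the module under the generated `_FX_CONST_FOLDED_ATTRS` attribute, and subsequent calls reuse it.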


@@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
import torch.fx as fx
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
"""
Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
`gm` gets run.
Args:
gm: graph module to insert breakpoint. It is then recompiled for it to
take effect.
Returns:
the `gm` with breakpoint inserted.
"""
def insert_pdb(body):
return ["import pdb; pdb.set_trace()\n", *body]
with gm.graph.on_generate_code(
make_transformer=lambda cur_transform: (
# new code transformer to register
lambda body: (
insert_pdb(
cur_transform(body) if cur_transform
else body
)
)
)
):
gm.recompile()
return gm
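
A short usage sketch; the traced module is illustrative:

import torch
import torch.fx as fx
from torch.fx.experimental.debug import set_trace

gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
gm = set_trace(gm)        # recompiles `gm` with a pdb breakpoint at the top of forward
# gm(torch.randn(2, 4))   # uncomment to drop into pdb when the module runs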


@@ -0,0 +1,916 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from functools import reduce
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
from torch.fx.experimental.refinement_types import Equality
import itertools
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
import sympy
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}
def expand_to_tensor_dim(t, n):
"""
Expand a type to the desired tensor dimension if possible
Raise an error otherwise.
- t is the given type
- n is a number of dimensions to expand to
"""
if t == Dyn:
dims = [Dyn] * n
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
return t
else:
raise TypeError(f'Cannot match the type {t}')
def broadcast_types(t1, t2):
"""
Applies broadcasting to both given types such that they
become consistent with each other and returns two new
resulting types
"""
# if either type is Dyn, do nothing since the types are already consistent
if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
return t1, t2
if isinstance(t1, TensorType) and isinstance(t2, TensorType):
s1 = len(t1.__args__)
s2 = len(t2.__args__)
new_t1 = list(t1.__args__)
new_t2 = list(t2.__args__)
# We make the types the same length which is the first requirement
# for consistency
if s1 > s2:
for i in range(s1 - s2):
new_t2.insert(0, 1)
elif s2 > s1:
for i in range(s2 - s1):
new_t1.insert(0, 1)
# we replace occurrences of "1" with each tensor with
# the corresponding type from the other tensor
for i, (x, y) in enumerate(zip(new_t1, new_t2)):
if x == 1:
new_t1[i] = y
elif y == 1:
new_t2[i] = x
# at this point our tensors should be consistent
# and we can apply the element-wise operation and find the right dimension
# for the output of the operation
(t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
return (t1, t2)
else:
raise TypeError(f'Cannot broadcast types {t1} and {t2}')
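# Illustrative sketch of the broadcasting above (the shapes are assumed values):
#
#   t1, t2 = broadcast_types(TensorType((1, 2, Dyn)), TensorType((4, 2, 1)))
#   # both results are now TensorType((4, 2, Dyn)), i.e. consistent with each other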
def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
raise RuntimeError(f'Inference rule already registered for {call_target}!')
_INFERENCE_RULES[call_target] = fn
return fn
return register
def register_refinement_rule(call_target):
def register(fn):
if call_target in _REFINEMENT_RULES:
raise RuntimeError(f'Refinement rule already registered for {call_target}!')
_REFINEMENT_RULES[call_target] = fn
return fn
return register
def register_algebraic_expressions_inference_rule(call_target):
def register(fn):
if call_target in _RULES:
raise RuntimeError(f'Rule already registered for {call_target}!')
_RULES[call_target] = fn
return fn
return register
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
"""
Apply the addition inference rule. This includes:
- scalar addition
- broadcasting semantics
Note that we always return the least precise type between
the operands (after applying broadcasting) to be the final type of the operation
Note that we do not modify the operand types themselves after applying broadcasting
to them. We only use them to calculate the final type
"""
assert isinstance(n.args[0], Node)
assert isinstance(n.args[1], Node)
t1 = n.args[0].type
t2 = n.args[1].type
# handle scalar addition
if t1 == int and isinstance(t2, TensorType):
n.type = t2
return n.type
# handle scalar addition
elif t2 == int and isinstance(t1, TensorType):
n.type = t1
return n.type
# we bring the new types to the point where
# we can check for consistency
# any inconsistency would not have been caused
# by broadcasting at this point
(new_t1, new_t2) = broadcast_types(t1, t2)
if new_t1 != t1 or new_t2 != t2:
n.meta['broadcast'] = True
n.meta[str(n.args[0])] = new_t1
n.meta[str(n.args[1])] = new_t2
else:
n.meta['broadcast'] = False
new_t1 = t1 if not n.meta['broadcast'] else new_t1
new_t2 = t2 if not n.meta['broadcast'] else new_t2
# we check for consistency between the new types
if is_consistent(new_t1, new_t2):
# we return the less precise type because
# broadcasting may have happened
# for operands with shape [1,2,Dyn] and [1,2,1]
# we have to assign the node [1,2,Dyn]
if is_more_precise(new_t1, new_t2):
n.type = new_t2
else:
n.type = new_t1
return n.type
else:
raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
f' Types should match ')
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):
"""
The current getattr rule only handles the shape attribute.
It can be extended to other attributes.
The most representative type we have is "Dyn", but the system
can be extended with more types, such as a type to represent shapes
"""
attr_node = n.args[0]
attr_name = n.args[1]
if attr_name == "shape":
n.type = Dyn
else:
raise TypeError("Not yet implemented")
# TODO. We leave it like this till we add a type to represent tensor sizes
return n.type
@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
"""
We check that dimensions for the transpose operations
are within range of the tensor type of the node
"""
if n.target == torch.transpose:
assert isinstance(n.args[0], Node)
t = n.args[0].type
assert isinstance(n.args[1], int)
assert isinstance(n.args[2], int)
dim1, dim2 = n.args[1], n.args[2]
if t == Dyn:
n.type = Dyn
return n.type
elif isinstance(t, TensorType):
if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
new_type = list(t.__args__)
new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
final = TensorType(new_type)
n.type = get_greatest_upper_bound(n.type, final)
return n.type
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node):
"""
Without dynamism, the rule checks that the
product of the elements of the argument tensor
type is equal to the product of the elements
of the required shape. We gradualize this rule
by adding a case to handle fully dynamic input
as well as input where some of the tensor dimensions
are unknown. In this case we check for divisibility
"""
assert isinstance(n.args[0], Node)
t1 = n.args[0].type
assert isinstance(n.args[1], list)
t2 = n.args[1]
t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])
# if we do not know the original tensor dimension,
# we return the required dimension
if t1 == Dyn:
n.type = t2_type
return t2_type
# if any of the dimensions are unknown,
# we check for divisibility
elif isinstance(t1, TensorType):
assert isinstance(t1, TensorType)
a = [e if e != Dyn else 1 for e in t1.__args__]
p1 = reduce(operator.mul, a)
p2 = reduce(operator.mul, t2)
if p1 % p2 == 0 or p2 % p1 == 0:
n.type = t2_type
return t2_type
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
@register_inference_rule(BatchNorm2d)
def bn2d_inference_rule(n: Node, module_instance):
"""
Given a BatchNorm2D instance and a node check the following conditions:
- the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4)
- the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
- t is consistent with t'
- x_2 is consistent with the module's num_features
- x_2' is consistent with the module's num_features
output type: the more precise type of t and t'
"""
assert isinstance(n.args[0], Node)
n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
arg_type = n.args[0].type
n.type = expand_to_tensor_dim(n.type, 4)
# we check the conditions on the incoming argument
# and any existing annotation
# we also check for consistency between both annotations
if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
is_consistent(n.type.__args__[1], module_instance.num_features) and \
is_consistent(arg_type, n.type):
# we choose the more precise type
# to be the node type
# so if an incoming argument has more type information
# we set this node's type to be the argument type
n.type = get_greatest_upper_bound(arg_type, n.type)
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
def calculate_out_dimension(d_in, module_instance, index):
"""
For calculating h_out and w_out according to the Conv2d documentation
"""
padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
stride = (module_instance.stride, module_instance.stride) \
if isinstance(module_instance.stride, int) else module_instance.stride
dilation = (module_instance.dilation, module_instance.dilation) \
if isinstance(module_instance.dilation, int) else module_instance.dilation
DIMENSION_TYPES = (int, sympy.Symbol)
if d_in == Dyn:
return Dyn
elif isinstance(d_in, DIMENSION_TYPES):
n = d_in + 2 * padding[index] - \
dilation[index] * \
(kernel_size[index] - 1) - 1
return (n // stride[0]) + 1
else:
raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
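# Illustrative sketch of the formula above with an assumed Conv2d configuration:
#
#   conv = Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
#   calculate_out_dimension(32, conv, 0)  # (32 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 == 16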
def get_greatest_upper_bound(type1, type2):
"""
Get the most precise type that's consistent with the given types
"""
if type1 == Dyn:
return type2
elif type2 == Dyn:
return type1
elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
if not is_consistent(type1, type2):
raise TypeError(f'Inconsistent types {type1}, {type2}')
gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
return TensorType(tuple(gub))
@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance):
"""
Given a Conv2D instance and a node check the following conditions:
- the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W)
- the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
- x_2 is consistent with the module's in_channels
- let o = (x_1, out_channels, H_out, W_out)
then the output is the greatest upper bound of o and the existing node type t'.
"""
assert isinstance(n.args[0], Node)
n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
arg_type = n.args[0].type
curr_node_type = expand_to_tensor_dim(n.type, 4)
if is_consistent(arg_type.__args__[1], module_instance.in_channels):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance):
"""
Input and output shapes should be equal.
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
n.type = get_greatest_upper_bound(n.args[0].type, n.type)
return n.type
def maxpool2d_check(typ, module_instance):
"""
Applies the maxpool2d shape information to the input
this affects the last two dimensions
"""
new_type_list = list(typ.__args__)
if len(new_type_list) == 4 or len(new_type_list) == 3:
w_in = new_type_list[-1]
h_in = new_type_list[-2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type_list[-1] = w_out
new_type_list[-2] = h_out
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Wrong size {typ} for {module_instance}')
@register_inference_rule(torch.nn.MaxPool2d)
def maxpool2d_inference_rule(n: Node, module_instance):
"""
Given a MaxPool2D instance and a node check the following conditions:
- Input size matches size 3 or 4
- Current node type is consistent with the output type we will calculate
- Input size matches output size and the last two dimensions of the output
are w_out and h_out. The remaining dimensions are the same as the input
- Our final result is the greatest upper bound of the output we calculate
and the current node type.
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output = maxpool2d_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(output, n.type)
return n.type
def linear_check(tensor_type, module_instance):
"""
Checks that an input tensor type satisfies the conditions for linear operation
and returns the output type based on in and out features given by module_instance
"""
if len(tensor_type.__args__) >= 2:
if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
new_type_args = list(tensor_type.__args__)
new_type_args[-1] = module_instance.out_features
return TensorType(tuple(new_type_args))
else:
raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
else:
raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance):
"""
Applies the shape information to the input then gets the greatest upper bound
of the resulting type and the existing type
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = linear_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(output_type, n.type)
return n.type
def adaptiveavgpool2d_check(tensor_type, module_instance):
output_size = module_instance.output_size
if isinstance(output_size, int):
output_size = [output_size, output_size]
elif isinstance(output_size, tuple):
output_size = list(output_size)
if output_size[0] is None:
output_size[0] = output_size[1]
if output_size[1] is None:
output_size[1] = output_size[0]
new_type_list = list(tensor_type.__args__)
if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
new_type_list[-1] = output_size[1]
new_type_list[-2] = output_size[0]
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
"""
The input and output sizes should be the same except for the last
two dimensions taken from the input, which represent width and height
"""
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
n.type = get_greatest_upper_bound(n.type, output_type)
return n.type
def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)
start_dim = l if start_dim == -1 else abs(start_dim)
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
my_args = list(tensor_type.__args__)
lhs = my_args[0:start_dim]
rhs = my_args[end_dim:]
mid = my_args[start_dim:end_dim]
if Dyn in mid:
mid = [Dyn]
else:
mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
"""
Applies the flatten shape information to the input then gets the
greatest upper bound of the resulting type and the existing type
"""
assert isinstance(n.args[0], Node)
# set the default start and end dims
start_dim = 1
end_dim = -1
if len(n.args) > 1:
assert isinstance(n.args[1], int)
start_dim = n.args[1]
if len(n.args) > 2:
assert isinstance(n.args[2], int)
end_dim = n.args[2]
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
if isinstance(n.args[0].type, TensorType):
output_type = flatten_check(n.args[0].type, start_dim, end_dim)
n.type = get_greatest_upper_bound(output_type , n.type)
return n.type
class GraphTypeChecker:
def __init__(self, env, traced):
self.env = env
self.traced = traced
def type_check(self):
"""
A gradual type checker for graphs
Effect: every node's field type will be
populated with a type after type-checking is done
"""
graph = self.traced.graph
# type check every node with gradual type rules
# if any node does not type check return false
for n in graph.nodes:
self.type_check_node(n)
return True
def type_check_node(self, n: Node):
"""
Type check a given fx node.
Current operations:
- Reshape
- Transpose
- Add
- Relu
- conv2d
- batchnorm2d
- flatten
- maxpool2d
- adaptiveavgpool2d
- linear
"""
if n.type is None:
n.type = Dyn
if n.op == 'placeholder':
return n.type
elif n.op == 'get_attr':
t = get_parameter(self.traced, n.target) # type: ignore[arg-type]
if isinstance(t.data, torch.Tensor):
n.type = TensorType(t.data.shape)
return n.type
elif n.op == 'call_function':
if n.target == getattr:
assert getattr in _INFERENCE_RULES
return _INFERENCE_RULES[n.target](n, self.traced)
elif n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
elif n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n, module_instance)
else:
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
elif n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
raise NotImplementedError(f"Method {n.op} not yet implemented")
@register_refinement_rule(Conv2d)
def conv_refinement_rule(n: Node):
"""
The equality constraints are between the first dimension of
the input and output
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
@register_refinement_rule(torch.nn.Linear)
def linear_refinement_rule(n: Node):
"""
The equality constraints are between the first dimension of
the input and output
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):
"""
For operations where the input shape is equal to the output shape
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
args1 = arg_type.__args__
args2 = n.type.__args__
res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
return res
@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
@register_refinement_rule(torch.nn.MaxPool2d)
def first_two_eq(n: Node):
"""
For operations where the first two dimensions of the input and output shape
are equal
"""
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
args1 = arg_type.__args__
args2 = n.type.__args__
res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
return res
@register_refinement_rule(torch.add)
@register_refinement_rule(operator.add)
def element_wise_eq(n: Node):
"""
For element-wise operations; handles broadcasting.
Note that after applying broadcasting to the arguments
we are able to determine whether certain dimensions have not been broadcast
if they are symbolically equal.
In this case, we can establish equality between those dimensions and the
corresponding output dimensions.
Note that it takes two iterations for this result. One iteration to establish
equality between certain dimensions of the operands (requiring the whole solver
including unification) and another iteration to establish equality between the operands
and the resulting type, requiring another round of constraint generation and unification.
"""
res = []
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
arg_type1 = n.args[0].type
arg_type2 = n.args[1].type
if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
args1, args2 = broadcast_types(arg_type1, arg_type2)
# by this point, we know that args1 and args2 are the same size.
a1 = args1.__args__
a2 = args2.__args__
a3 = n.type.__args__
# we would be here in the second iteration where we establish equality
# between operand type dimensions and the resulting type dimensions
r = []
for x, y, z in zip(a1, a2, a3):
if x == y:
r.append(Equality(x, z))
res = r
return res
@register_refinement_rule(torch.flatten)
def flatten_refinement_rule(n: Node):
"""
Generates equality constraints between the dimensions of the input and output
that will not be involved in the flatten operation
"""
assert isinstance(n.args[0], Node)
eq_const = []
start_dim = 1
end_dim = -1
if len(n.args) > 1:
assert isinstance(n.args[1], int)
start_dim = n.args[1]
if len(n.args) > 2:
assert isinstance(n.args[2], int)
end_dim = n.args[2]
if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
l = len(n.type.__args__)
arg_type = n.args[0].type
start_dim = l if start_dim == -1 else start_dim
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]):
eq_const.append(Equality(t1, t2))
for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]):
eq_const.append(Equality(t1, t2))
return eq_const
@register_algebraic_expressions_inference_rule(Conv2d)
def conv_rule(n: Node, module_instance):
"""
Represents the output in terms of an algebraic expression w.r.t.
the input when possible
"""
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
n.type = new_type
return new_type
class Refine:
"""
Symbolic shape inference.
Generates constraints over type variables.
Currently all constraints are equality constraints.
"""
def __init__(self, traced):
self.constraints = []
self.traced = traced
self.symbol_iter = itertools.count(start=0, step=1)
def refine(self):
"""
Generates constraints for
every node in the graph based on
the operation.
"""
graph = self.traced.graph
for n in graph.nodes:
self.refine_node(n)
return True
def symbolic_relations(self):
"""
Infers algebraic relations
"""
graph = self.traced.graph
for n in graph.nodes:
self.infer_symbolic_relations(n)
return True
def replace_dyn_with_fresh_var(self, typ):
"""
Replace all unknown types with fresh type variables.
"""
if typ == Dyn:
new_symbol = Var(next(self.symbol_iter))
return new_symbol
elif isinstance(typ, TensorType):
new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
return TensorType(tuple(new_args))
elif isinstance(typ, list):
return [self.replace_dyn_with_fresh_var(t) for t in typ]
elif isinstance(typ, tuple):
return (self.replace_dyn_with_fresh_var(t) for t in typ)
else:
return typ
def convert_to_sympy_symbols(self, typ):
"""
Replace all unknown types with fresh type variables.
"""
if isinstance(typ, Var):
return sympy.symbols(str(typ))
elif isinstance(typ, TensorType):
new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
return TensorType(tuple(new_args))
elif isinstance(typ, list):
return [self.convert_to_sympy_symbols(t) for t in typ]
elif isinstance(typ, tuple):
return (self.convert_to_sympy_symbols(t) for t in typ)
else:
return typ
def refine_node(self, n: Node):
"""
Returns a list of equality constraints for
call_module and call_function nodes.
Models the relation between input and output dimensions
using constraints in case they are both tensors.
All operations used in resnet50 are defined.
"""
if n.type is None:
n.type = Dyn
n.type = self.replace_dyn_with_fresh_var(n.type)
if n.op == 'call_function':
if n.target in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[n.target](n)
else:
pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
pass
def infer_symbolic_relations(self, n: Node):
n.type = self.convert_to_sympy_symbols(n.type)
if n.op == 'call_function':
if n.target in _RULES:
return _RULES[n.target](n)
else:
pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
pass
def get_parameter(traced, target: str):
"""
Returns the parameter given by ``target`` if it exists,
otherwise throws an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the Parameter
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.nn.Parameter: The Parameter referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Parameter``
"""
module_path, _, param_name = target.rpartition(".")
mod: torch.nn.Module = traced.get_submodule(module_path)
if not hasattr(mod, param_name):
raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
param: torch.nn.Parameter = getattr(mod, param_name)
return param
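
A minimal, hedged sketch of running the gradual type checker on a traced module; the module and shapes are illustrative:

import torch
from torch.fx import symbolic_trace
from torch.fx.tensor_type import TensorType, Dyn
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

traced = symbolic_trace(TwoLayer())
# Annotate the placeholder with a partially known shape; Dyn marks an unknown dimension.
for node in traced.graph.nodes:
    if node.op == 'placeholder':
        node.type = TensorType((Dyn, 8))

GraphTypeChecker({}, traced).type_check()  # populates node.type on every node
# The linear and relu nodes now carry TensorType((Dyn, 4)).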


@@ -0,0 +1,172 @@
# mypy: allow-untyped-defs
import torch
from torch.fx.node import Node
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph
import itertools
import operator
from typing import Dict, List, Tuple
def split_result_tensors(
result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
"""
A free function for use in the merge_matmul graph transformation below that
splits the output from a merged matmul into the individual results for each
input tensor.
Arguments:
result: The merged matmul result tensor.
inputs: The list of inputs that were merged into one for the matmul.
Returns:
List of matmul results for each input tensor.
"""
# When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
# need an int even when tracing
if isinstance(result, torch.fx.Proxy):
splits = [0] * len(inputs)
else:
splits = [x.shape[0] for x in inputs]
return torch.split(result, splits)
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
"""
Determine if one node depends on another in a torch.fx.Graph.
Arguments:
a: The node that may have a dependency on b.
b: The node that a may have a dependency on.
search_depth: In the case of an indirect dependency, this function
searches up to this many nodes away in search of a
data dependency. If none is found, the function
makes the conservative assumption that there is a
dependency.
Returns:
True if a may depend on b, False if it definitely does not.
"""
# Equivalence is defined as dependence.
if a == b:
return True
# If a has no inputs, it cannot depend on b.
if len(a.all_input_nodes) == 0:
return False
# If the search depth has been exhausted and no conclusion has been
# reached, assume that there is a data dependency.
if search_depth == 0:
return True
# Recursively check all inputs of a.
for inp in a.all_input_nodes:
if may_depend_on(inp, b, search_depth - 1):
return True
return False
def are_nodes_independent(nodes: List[Node]):
"""
Check if all of the given nodes are pairwise-data independent.
Arguments:
nodes: The nodes to check for data dependencies.
Returns:
True if no pair of nodes has a data dependency, False otherwise.
"""
# For each pair in nodes:
for i, j in itertools.combinations(nodes, 2):
if may_depend_on(i, j) or may_depend_on(j, i):
return False
return True
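# Illustrative sketch with assumed node names: for a graph computing
# `z = torch.matmul(x, w); y = z + 1`, the add node consumes the matmul node, so:
#
#   may_depend_on(add_node, matmul_node)             # True (direct data dependency)
#   may_depend_on(matmul_node, add_node)             # False
#   are_nodes_independent([add_node, matmul_node])   # False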
def merge_matmul(in_mod: torch.nn.Module):
"""
A graph transformation that merges matrix multiplication operations that share the same right-hand
side operand into one large matrix multiplication.
____ _________ _________
---- | | | | M| A * C |
M| A | T| B | * K| C | = |---------|
---- , | | | | T| B * C |
K ---- --------- ---------
K R R
"""
gm = symbolic_trace(in_mod)
rhs_users: Dict[Node, List[Node]] = {}
lhs_users: Dict[Node, List[Node]] = {}
# Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
# the matmul of which they are the LHS/RHS.
for node in gm.graph.nodes:
if node.op != "call_function" or node.target is not torch.matmul:
continue
lhs, rhs = node.args
# TODO: Properly handle aliasing caused by get_attr. For now,
# use the attribute name as the operand if the node is a
# get_attr.
lhs = lhs.target if lhs.op == "get_attr" else lhs
rhs = rhs.target if rhs.op == "get_attr" else rhs
lhs_users.setdefault(lhs, []).append(node)
rhs_users.setdefault(rhs, []).append(node)
for rhs, mms in rhs_users.items():
# There must be at least two matmuls for a merge to make sense.
if len(mms) < 2:
continue
# All matmuls must not depend on each other directly or indirectly
# in order for the merge to be possible.
if not are_nodes_independent(mms):
continue
lhs_vals = [mm.args[0] for mm in mms]
# Merge the matmul.
# Collect a list of LHS operands and the single RHS operand.
lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
# Concatenate all the LHS operands.
merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
# Multiply the concatenated LHS operands with the one RHS. This will produce
# the same results as all the individual matmuls involving rhs in the original graph,
# but they will all be concatenated together.
merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
# Split the result of the merged matmul using the shapes of the LHS operands
# to ascertain how large each chunk should be.
merge_mm_split = gm.graph.call_function(
split_result_tensors, (merge_mm, lhs), {}
)
merge_mm_res = [
gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
for out in range(len(lhs))
]
# Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
for old, new in zip(mms, merge_mm_res):
old.replace_all_uses_with(new)
gm.graph.erase_node(old)
# All of the new nodes created above were inserted at the end, so we need to sort
# the nodes topologically to make sure all definitions precede uses.
legalize_graph(gm)
gm.recompile()
gm.graph.lint()
return gm
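
A minimal usage sketch of merge_matmul; the module and shapes are illustrative:

import torch
from torch.fx.experimental.merge_matmul import merge_matmul

class TwoMatmuls(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rhs = torch.nn.Parameter(torch.randn(8, 4))

    def forward(self, a, b):
        # Both matmuls share the same right-hand side, so they get merged into one.
        return torch.matmul(a, self.rhs), torch.matmul(b, self.rhs)

merged = merge_matmul(TwoMatmuls())
out_a, out_b = merged(torch.randn(3, 8), torch.randn(5, 8))  # same results as before

After the transformation the graph contains a single torch.cat, one torch.matmul against `rhs`, and a split_result_tensors call that recovers the per-input results.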


@@ -0,0 +1,269 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import warnings
import functools
import builtins
from typing import Any, Callable, Dict, Optional, Union
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
def nn_layernorm_override(self, input):
return input
def torch_relu_override(x):
return x
def torch_nn_relu_override(self, x):
return x
def functional_relu_override(x, inplace=False):
assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
return x
def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
def torch_abs_override(input, *, out=None):
assert out is None, 'Dont support in-place abs for MetaTensor analysis'
return input
manual_meta_overrides : Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: nn_layernorm_override,
torch.relu: torch_relu_override,
torch.nn.functional.relu: functional_relu_override,
torch.nn.ReLU: torch_nn_relu_override,
torch.where: torch_where_override,
torch.abs: torch_abs_override,
}
def gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy('call_function', target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
class MetaProxy(torch.fx.Proxy):
def install_tensor_meta(self, tensor_meta):
self._tensor_meta = tensor_meta
def size(self, dim=None):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.size(*[dim] if dim else [])
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
def dim(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
@property
def shape(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.shape
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
@property
def dtype(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
return self._tensor_meta.dtype
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, 'device')
def __getattr__(self, k):
if k == '_tensor_meta':
return self.__getattribute__(k)
# note: not added to the graph yet; if this is a method call
# we peephole optimize to the method invocation
return MetaAttribute(self, k)
class MetaAttribute(MetaProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
class MetaDeviceAttribute(MetaAttribute):
pass
def proxys_to_metas(v):
if isinstance(v, MetaDeviceAttribute):
return 'meta'
if isinstance(v, torch.fx.Proxy):
assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
return v._tensor_meta
return v
class MetaTracer(torch.fx.Tracer):
allow_insert_stateless_mods : bool = True
_TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if kind == 'placeholder' and target in self.meta_args:
rv.install_tensor_meta(self.meta_args[target])
return rv
if target in self.orig_fns:
# NOTE: tensor constructors in PyTorch define the `device` argument as
# *kwargs-only*. That is why this works. If you add methods to
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if 'device' in kwargs:
kwargs['device'] = 'meta'
try:
args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
if kind == 'call_function':
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == 'get_attr':
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split('.')
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
assert isinstance(attr_itr, torch.Tensor)
meta_out = attr_itr.to(device='meta')
finally:
self._disable_module_getattr = False
else:
return rv
# TODO
assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
rv.install_tensor_meta(meta_out)
except Exception as e:
warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
return rv
def getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, '_disable_module_getattr', False):
return attr_val
else:
return super().getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
"""
Helper method which tries to insert a module that was not declared as a submodule.
"""
idx = 0
mod_name = mod.__class__.__name__.lower()
path = f"{mod_name}_{idx}"
while hasattr(self.root, path):
path = f"{mod_name}_{idx}"
idx += 1
self.root.add_module(path, mod)
return path
def path_of_module(self, mod: torch.nn.Module) -> str:
try:
return super().path_of_module(mod)
except NameError as e:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
path = self._insert_module_as_submodule(mod)
self.prev_module = path
return path
raise
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
self.patched_torch_methods = {
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
try:
graph = super().trace(root, concrete_args)
graph._tracer_extras = {'meta_args': meta_args}
return graph
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm
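
A minimal, hedged sketch of tracing with meta tensors using the symbolic_trace defined above; the module and shapes are illustrative:

import torch
from torch.fx.experimental.meta_tracer import symbolic_trace

class Embed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)

    def forward(self, ids):
        return self.emb(ids) + 1.0

# Shapes are propagated through meta tensors, so no real data is required.
meta_ids = torch.zeros(2, 8, dtype=torch.long, device='meta')
gm = symbolic_trace(Embed(), meta_args={'ids': meta_ids})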


@@ -0,0 +1,558 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
op_mod, op_gt, op_lt, op_neq, op_eq
from torch.fx.tensor_type import TensorType, Dyn
class Constraint:
pass
class Conj(Constraint):
def __init__(self, conjuncts):
"""
:param conjuncts: Conjunction of constraints
"""
self.conjucts = conjuncts
def __eq__(self, other):
if isinstance(other, Conj):
            return self.conjucts == other.conjucts
else:
return False
def __repr__(self):
return f'And({self.conjucts})'
class Disj(Constraint):
def __init__(self, disjuncts):
"""
:param disjuncts: Disjunction of constraints
"""
self.disjuncts = disjuncts
def __eq__(self, other):
if isinstance(other, Disj):
            return self.disjuncts == other.disjuncts
else:
return False
def __repr__(self):
return f'Or({self.disjuncts})'
class Prod(Constraint):
def __init__(self, products):
"""
:param products: lists of dimensions to multiply
"""
self.products = products
def __eq__(self, other):
if isinstance(other, Prod):
            return self.products == other.products
else:
return False
def __repr__(self):
return f'Product({self.products})'
class T(Constraint):
"""
True
"""
def __init__(self) -> None:
pass
def __eq__(self, other):
return isinstance(other, T)
def __repr__(self):
return 'True'
class F(Constraint):
"""
False
"""
def __init__(self) -> None:
pass
def __eq__(self, other):
return isinstance(other, F)
def __repr__(self):
return 'False'
class BinaryConstraint(Constraint):
"""
Represents all binary operations
"""
def __init__(self, lhs, rhs, op):
"""
:param lhs: lhs of the constraint
:param rhs: rhs of the constraint
:param op: string representing the operation
"""
self.lhs = lhs
self.rhs = rhs
self.op = op
def __eq__(self, other):
if isinstance(other, BinaryConstraint):
return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
else:
return False
def __repr__(self):
return f'({self.lhs} {self.op} {self.rhs})'
class BinConstraintT(BinaryConstraint):
"""
Binary constraints about tensors
"""
def __init__(self, lhs, rhs, op):
assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
(isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
super().__init__(lhs, rhs, op)
def __eq__(self, other):
return super().__eq__(other)
class BinConstraintD(BinaryConstraint):
"""
Binary constraints about dimensions
"""
def __init__(self, lhs, rhs, op):
assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
super().__init__(lhs, rhs, op)
def __eq__(self, other):
return super().__eq__(other)
class TGreatestUpperBound(Constraint):
"""
Greatest Upper bound for tensors with dynamic type
"""
def __init__(self, res, rhs1, rhs2):
"""
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
"""
self.res = res
self.rhs1 = rhs1
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'
def __eq__(self, other):
if isinstance(other, TGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
else:
return False
class DGreatestUpperBound(Constraint):
"""
Greatest Upper bound for dimensions
"""
def __init__(self, res, rhs1, rhs2):
"""
:param res: Dimension variable to store the result
:param rhs1: dimension variable 1
:param rhs2: dimension variable 2
"""
assert is_dim(res)
assert is_dim(rhs1)
assert is_dim(rhs2)
self.res = res
self.rhs1 = rhs1
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'
def __eq__(self, other):
if isinstance(other, DGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
else:
return False
class CanReshape(Constraint):
"""
can_reshape constraint
"""
def __init__(self, src, target):
"""
:param src: tensor variable
:param target: tensor
"""
self.src = src
self.target = target
def __repr__(self):
return f'can-reshape({self.src}, {self.target})'
def __eq__(self, other):
if isinstance(other, CanReshape):
return self.src == other.src and self.target == other.target
else:
return False
class IndexSelect(Constraint):
def __init__(self, tensor_size, input_var, dim_replace, index, output):
"""
Args:
input_var: input to index_select
tensor_size: tensor size we are considering
dim_replace: the dimension of the output at "index"
index: location of the dimensions to replace in the input
output: variable to store the result
"""
assert isinstance(input_var, TVar)
assert isinstance(output, TVar)
assert isinstance(dim_replace, DVar) or dim_replace == Dyn
assert isinstance(index, int)
self.input_var = input_var
self.tensor_size = tensor_size
self.dim_replace = dim_replace
self.index = index
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'IndexSelect({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.dim_replace}, ' \
f'{self.index})'
def __eq__(self, other):
if isinstance(other, IndexSelect):
return self.tensor_size == other.tensor_size and \
self.dim_replace == other.dim_replace and \
self.index == other.index and \
self.output == other.output and \
self.input_var == other.input_var
else:
return False
class Transpose(Constraint):
def __init__(self, tensor_size, input_var, index1, index2, output):
"""
Args:
tensor_size: current tensor size
input_var: variable to hold input
index1: dimension 1
index2: dimension 2
output: output that stores result
"""
assert isinstance(input_var, TVar)
assert isinstance(output, TVar)
assert isinstance(index1, int)
assert isinstance(index2, int)
self.input_var = input_var
self.tensor_size = tensor_size
self.index1 = index1
self.index2 = index2
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'Transpose({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.index1}, ' \
f'{self.index2})'
def __eq__(self, other):
if isinstance(other, Transpose):
return self.tensor_size == other.tensor_size and \
self.index1 == other.index1 and \
self.index2 == other.index2 and \
self.output == other.output and \
self.input_var == other.input_var
else:
return False
class GetItem(Constraint):
def __init__(self, tensor_size, index, res, input_var):
"""
Constraint for getting item given a tensor size
        :param tensor_size: actual number representing the rank of the input tensor
:param index: actual number representing the index
:param res: dimension variable to carry the item we get
:param input_var: a tensor variable from which we will get item
"""
assert isinstance(res, DVar)
self.res = res
self.tensor_size = tensor_size
self.index = index
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
def __eq__(self, other):
if isinstance(other, GetItem):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index == other.index and \
self.input_var == other.input_var
else:
return False
class GetItemTensor(Constraint):
def __init__(self, tensor_size, index_tuple, res, input_var):
"""
Constraint for getting item given a tensor size
However, when the argument is a tuple, we will
expect a tensor
:param tensor_size: actual number representing the rank
:param index_tuple: tuple for indexing
:param res: tensor variable to carry the item we get
:param input_var: a tensor variable from which we will get item
"""
assert isinstance(res, TVar)
self.res = res
self.tensor_size = tensor_size
self.index_tuple = index_tuple
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
def __eq__(self, other):
if isinstance(other, GetItemTensor):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index_tuple == other.index_tuple and \
self.input_var == other.input_var
else:
return False
class CalcConv(Constraint):
def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
"""
:param conv_result: the convolution result
:param input_var: input to convolution
        :param c_out: output channel type
:param kernel: kernel tuple
"""
self.conv_result = conv_result
self.input_var = input_var
self.c_out = c_out
self.kernel = kernel
self.padding = padding
self.stride = stride
self.dilation = dilation
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.conv_result} =' \
f' calc-conv({self.input_var},' \
f' {self.c_out}, {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
def __eq__(self, other):
if isinstance(other, CalcConv):
return self.conv_result == other.conv_result and self.input_var == other.input_var and \
self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
and self.matching_constraint == other.matching_constraint
else:
return False
class CalcMaxPool(Constraint):
def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
"""
:param maxpool_result: the result of maxpool
        :param input_var: input to maxpool
:param kernel: kernel tuple
"""
self.maxpool_result = maxpool_result
self.input_var = input_var
self.kernel = kernel
self.padding = padding
self.stride = stride
self.dilation = dilation
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.maxpool_result} =' \
f' calc-maxpool({self.input_var},' \
f' {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
def __eq__(self, other):
if isinstance(other, CalcMaxPool):
return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
and self.matching_constraint == other.matching_constraint
else:
return False
class ApplyBroadcasting(Constraint):
def __init__(self, res1, res2, input1, input2):
"""
:param res1: resulting tensor 1
:param res2: resulting tensor 2
:param input1: tensor variable 1
:param input2: tensor variable 2
"""
self.res1 = res1
self.res2 = res2
self.input1 = input1
self.input2 = input2
def __eq__(self, other):
if isinstance(other, ApplyBroadcasting):
return self.res1 == other.res1 \
and self.res2 == other.res2 \
and self.input1 == other.input1 \
and self.input2 == other.input2
else:
return False
def __repr__(self):
        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'
class CalcProduct(Constraint):
"""
Given correct dimensions, calculate the product for flatten accounting for Dyn
"""
def __init__(self, start, end, flattened, dims_to_flatten):
"""
:param start: start index
:param end: end index
:param flattened: variable to store the product
        :param dims_to_flatten: the list of dimensions to flatten
"""
assert isinstance(dims_to_flatten, list)
assert isinstance(flattened, TVar)
assert isinstance(start, int)
assert isinstance(end, int)
self.start = start
self.end = end
self.dims_to_flatten = dims_to_flatten
self.flattened = flattened
def __eq__(self, other):
if isinstance(other, CalcProduct):
return self.start == other.start and self.end == other.end and \
self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
else:
return False
def __repr__(self):
return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
class TVar:
"""
Tensor variable with no tensor constructor
"""
def __init__(self, tvar):
"""
:param tvar: tensor variable
"""
self.tvar = tvar
def __repr__(self):
return f'TV({self.tvar})'
def __eq__(self, other):
if isinstance(other, TVar):
return self.tvar == other.tvar
else:
return False
class DVar:
"""
Dimension variable
"""
def __init__(self, c):
"""
:param c: character or number
"""
self.c = c
def __repr__(self):
return f'DV({self.c})'
def __eq__(self, other):
if isinstance(other, DVar):
return self.c == other.c
else:
return False
class BVar:
"""
Boolean variable
"""
def __init__(self, c):
"""
:param c: character or number
"""
self.c = c
def __repr__(self):
return f'BV({self.c})'
def __eq__(self, other):
if isinstance(other, BVar):
return self.c == other.c
else:
return False
def is_algebraic_expression(constraint):
if isinstance(constraint, BinConstraintD):
return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
else:
return isinstance(constraint, Prod)
def is_bool_expr(constraint):
if isinstance(constraint, BinConstraintD):
return constraint.op in [op_gt, op_lt, op_neq, op_eq]
else:
return isinstance(constraint, (BVar, Conj, Disj))
def is_dim(d):
return isinstance(d, (DVar, int)) or d == Dyn
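

# --- Minimal usage sketch, not part of the original file ---
# Shows how the classes above compose; the variable ids (1, 2, ...) are arbitrary.
if __name__ == "__main__":
    d1, d2 = DVar(1), DVar(2)
    t1 = TVar(1)
    c = Conj([
        BinConstraintT(t1, TensorType((d1, d2)), op_eq),  # t1 is a rank-2 tensor [d1, d2]
        BinConstraintD(d1, 4, op_eq),                     # d1 == 4
        BinConstraintD(d2, 1, op_gt),                     # d2 > 1
    ])
    print(c)
    print(is_dim(d1), is_algebraic_expression(Prod([d1, d2])))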

View File

@ -0,0 +1,14 @@
op_add = '+'
op_sub = '-'
op_mul = '*'
op_div = '/'
op_eq = '='
op_neq = '!='
op_imp = '=>'
op_matching = '\u22b3' # (contains)
op_consistency = '~'
op_precision = '\u2291' # (square image of or equal to)
op_leq = '\u2264' # less-than or equal to
op_lt = '<'
op_gt = '>'
op_mod = '%'

View File

@ -0,0 +1,349 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
from torch.fx.tensor_type import TensorType, Dyn
try:
import z3 # type: ignore[import]
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
HAS_Z3 = True
def transform_to_z3(constraint, counter, dimension_dict):
if isinstance(constraint, Conj):
conjuncts = []
for c in constraint.conjucts:
new_c, counter = transform_to_z3(c, counter, dimension_dict)
conjuncts.append(new_c)
return z3.And(conjuncts), counter
elif isinstance(constraint, Disj):
disjuncts = []
for c in constraint.disjuncts:
new_c, counter = transform_to_z3(c, counter, dimension_dict)
disjuncts.append(new_c)
return z3.Or(disjuncts), counter
elif isinstance(constraint, T):
return True, counter
elif isinstance(constraint, F):
return False, counter
elif isinstance(constraint, BinConstraintT):
if constraint.op == op_eq:
lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
return (lhs == rhs), counter
else:
raise NotImplementedError('Method not yet implemented')
elif isinstance(constraint, BinConstraintD):
if constraint.op == op_eq:
if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
transformed_lhs = z3.Bool(constraint.lhs.c)
return transformed_lhs == transformed_rhs, counter
elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
# with dimension transformations we consider the encoding
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
return lhs == rhs, counter
else:
# then we have an algebraic expression which means that we disregard the
# first element of the encoding
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs == rhs, counter
# The assumption here is that the LHS and RHS must be dimensions
elif constraint.op == op_neq:
assert is_dim(constraint.lhs)
assert is_dim(constraint.rhs)
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
if constraint.rhs == Dyn or constraint.lhs == Dyn:
if constraint.rhs == Dyn:
return lhs.arg(0) == 1, counter
elif constraint.lhs == Dyn:
return rhs.arg(0) == 1, counter
# if one of the instances is a number
elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
if isinstance(constraint.lhs, int):
return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
elif isinstance(constraint.rhs, int):
return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
else:
return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
elif constraint.op == op_leq:
# if the dimensions are not dyn, this will come into effect
# there would have been another constraint specifying if a given dimension
# is dyn or not
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs <= rhs, counter
elif constraint.op == op_gt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs > rhs, counter
elif constraint.op == op_lt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
return lhs < rhs, counter
else:
raise NotImplementedError('operation not yet implemented')
else:
raise NotImplementedError('Operation not yet implemented')
def transform_var(tensor, counter, dimension_dict):
"""
Transforms tensor variables to a format understood by z3
Args:
tensor: Tensor variable or a tensor type potentially with variable dimensions
Returns: Transformed variable to a z3 format
"""
if isinstance(tensor, TensorType):
res = []
for t in tensor.__args__:
transformed, counter = transform_dimension(t, counter, dimension_dict)
res.append(transformed)
assert len(res) <= 4
if len(tensor.__args__) == 1:
return tensor_type.tensor1(res[0]), counter
elif len(tensor.__args__) == 2:
return tensor_type.tensor2(res[0], res[1]), counter
elif len(tensor.__args__) == 3:
return tensor_type.tensor3(res[0], res[1], res[2]), counter
elif len(tensor.__args__) == 4:
return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
elif tensor == Dyn:
return z3_dyn, counter
elif isinstance(tensor, TVar):
return z3.Const(tensor.tvar, tensor_type), counter
def transform_dimension(dimension, counter, dimension_dict):
"""
Takes a dimension variable or a number and transforms it to a tuple
according to our scheme
Args:
dimension: The dimension to be transformed
counter: variable tracking
Returns: tuple and the current counter
"""
if dimension == Dyn:
counter += 1
return D(0, z3.Int(counter)), counter
elif isinstance(dimension, int):
return D(1, dimension), counter
elif isinstance(dimension, DVar):
if dimension.c in dimension_dict:
return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
else:
counter += 1
dimension_dict[dimension.c] = counter
return D(z3.Int(counter), z3.Int(dimension.c)), counter
def transform_algebraic_expression(expr, counter, dimension_dict):
"""
Transforms an algebraic expression to z3 format
Args:
expr: An expression is either a dimension variable or an algebraic-expression
Returns: the transformed expression
"""
assert is_algebraic_expression(expr) or is_dim(expr)
if is_dim(expr):
transformed, counter = transform_dimension(expr, counter, dimension_dict)
return transformed.arg(1), counter
elif isinstance(expr, Prod):
dims = []
for dim in expr.products:
assert is_dim(dim)
d, counter = transform_dimension(dim, counter, dimension_dict)
dims.append(d.arg(1))
return z3.Product(dims), counter
elif is_algebraic_expression(expr):
lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
if expr.op == op_sub:
c = lhs - rhs
elif expr.op == op_add:
c = lhs + rhs
elif expr.op == op_div:
c = lhs / rhs
elif expr.op == op_mul:
c = lhs * rhs
elif expr.op == op_mod:
c = lhs % rhs
else:
raise NotImplementedError('operation not yet implemented')
return c, counter
else:
raise RuntimeError
def transform_all_constraints(traced, counter=0):
"""
Given a trace, generates constraints and transforms them to z3 format
"""
dimension_dict = {} # type: ignore[var-annotated]
generator = ConstraintGenerator(traced)
new_constraints, counter = generator.generate_constraints(counter)
# print(new_constraints.conjucts[0])
# print(*new_constraints.conjucts, sep='\n')
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
# print(new_constraints)
# print(new_constraints.conjucts)
# new_constraints.conjucts = new_constraints.conjucts[:-1]
# print(*new_constraints.conjucts, sep='\n')
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
# print(transformed)
return transformed
def iterate_till_fixed_point(constraints, counter):
"""
Transform constraints till reaching a fixed point
"""
old_c = None
while old_c != constraints:
old_c = constraints
constraints, counter = transform_constraint(constraints, counter)
return constraints, counter
def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
"""
Takes a node and a graph and generates two sets of constraints.
        One set includes the node's constraint and another set
        includes the negation of that constraint
Args:
tracer_root: the root for getting the module instances
graph: the graph so far in the tracing process
node: node that represents a conditional
counter: variable tracking
        Returns: Two sets of constraints. One with a conjunction with the
        conditional constraint and the other with a conjunction with
        its negation.
"""
dimension_dict = {} # type: ignore[var-annotated]
generator = ConstraintGenerator(tracer_root, graph)
new_constraints, counter = generator.generate_constraints(counter)
condition_constraint = new_constraints.conjucts[-1]
# we know the constraint is a conjunction where the last constraint is about the conditional
# so remove the last constraint
new_constraints.conjucts = new_constraints.conjucts[:-1]
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
# since the function returns a list of one element, we get the first element
# we are only interested in the RHS in this case because the LHS just stores
# the result
# we make sure the constraint is of the form:
# c = b where b is a boolean expression
# and we consider b (constraint.rhs) for transformation
assert isinstance(condition_constraint.lhs, BVar)
assert is_bool_expr(condition_constraint.rhs)
condition_constraint_rhs = condition_constraint.rhs
# transform the condition constraint
condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
return z3.And([transformed, transformed_condition_constraint]), \
z3.And([transformed, negation_transformed_condition_constraint])
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
"""
Given an IR and a node representing a conditional, evaluate the conditional
and its negation
Args:
tracer_root: Tracer root for module instances
node: The node to be evaluated
Returns: the results of evaluating the condition and the negation with
the rest of the constraints
"""
transformed_positive, transformed_negative = \
transform_all_constraints_trace_time(tracer_root, graph, node, counter)
s = z3.Solver()
s.add(transformed_positive)
if user_constraints is not None:
s.add(user_constraints)
condition = s.check()
s = z3.Solver()
s.add(transformed_negative)
if user_constraints is not None:
s.add(user_constraints)
negation = s.check()
return condition, negation
except ImportError:
HAS_Z3 = False
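

# --- Minimal usage sketch, not part of the original file ---
# Assumes z3 is installed; the TensorType / Dyn annotations on `forward` are
# what the constraint generator reads off the placeholder nodes of the graph.
if __name__ == "__main__" and HAS_Z3:
    import torch.fx

    class _Add(torch.nn.Module):
        def forward(self, x: TensorType((1, 4)), y: TensorType((1, 4))):
            return torch.add(x, y)

    traced = torch.fx.symbolic_trace(_Add())
    constraints = transform_all_constraints(traced, counter=0)
    solver = z3.Solver()
    solver.add(constraints)
    print(solver.check())  # expected: sat, the annotated shapes are compatible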

View File

@ -0,0 +1,53 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
BVar
from torch.fx.experimental.migrate_gradual_types.operation import op_leq
def gen_tvar(curr):
"""
Generate a tensor variable
:param curr: The current counter
:return: a tensor variable and the updated counter
"""
curr += 1
return TVar(curr), curr
def gen_dvar(curr):
"""
Generate a dimension variable
:param curr: the current counter
:return: a dimension variable and an updated counter
"""
curr += 1
return DVar(curr), curr
def gen_bvar(curr):
"""
Generate a boolean variable
:param curr: the current counter
:return: a boolean variable and an updated counter
"""
curr += 1
return BVar(curr), curr
def gen_tensor_dims(n, curr):
"""
Generate a list of tensor dimensions
:param n: the number of dimensions
:param curr: the current counter
:return: a list of dimension variables and an updated counter
"""
dims = []
for _ in range(n):
dvar, curr = gen_dvar(curr)
dims.append(dvar)
return dims, curr
def gen_nat_constraints(list_of_dims):
"""
Generate natural number constraints for dimensions
"""
return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
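

# --- Minimal usage sketch, not part of the original file ---
# The counter is threaded through every generator so variable ids stay unique.
if __name__ == "__main__":
    counter = 0
    t, counter = gen_tvar(counter)               # TV(1), counter == 1
    dims, counter = gen_tensor_dims(2, counter)  # [DV(2), DV(3)], counter == 3
    nat = gen_nat_constraints(dims)              # natural-number constraints on each dim
    print(t, dims, nat)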

View File

@ -0,0 +1,29 @@
try:
import z3 # type: ignore[import]
HAS_Z3 = True
# dynamic type
dyn = z3.DeclareSort('Dyn')
dyn_type = z3.Const('dyn', dyn)
# dimension
dim = z3.Datatype('dim')
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
dim = dim.create()
# tensors
tensor_type = z3.Datatype('TensorType')
tensor_type.declare('Dyn', ('dyn', dyn))
tensor_type.declare('tensor1', ('0', dim))
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
tensor_type = tensor_type.create()
# create dimension
D = dim.dim
z3_dyn = tensor_type.Dyn(dyn_type)
except ImportError:
HAS_Z3 = False
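

# --- Minimal usage sketch, not part of the original file ---
# Each dimension is encoded as D(is_static, value): D(1, 3) is a concrete
# dimension of size 3, while D(0, n) ties a dynamic dimension to a fresh z3
# integer, matching how transform_dimension encodes dimensions elsewhere in
# this package.
if __name__ == "__main__" and HAS_Z3:
    static_dim = D(1, 3)
    dyn_dim = D(0, z3.Int('d0'))
    rank2 = tensor_type.tensor2(static_dim, dyn_dim)
    print(rank2, z3_dyn)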

View File

@ -0,0 +1,163 @@
# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional
import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
normalize_module,
normalize_function,
create_type_hint,
)
from .schema_type_annotation import AnnotateTypesWithSchema
class NormalizeArgs(Transformer):
"""
Normalize arguments to Python targets. This means that
`args/kwargs` will be matched up to the module/functional's
signature and rewritten to exclusively kwargs in positional order
if `normalize_to_only_use_kwargs` is true. Also populates default
values. Does not support positional-only parameters or varargs
parameters (*args, **kwargs).
If the nodes have 'type' metadata, it will use it to disambiguate
overloads. Otherwise, it will throw an error.
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeArgs(traced).transform()
"""
def __init__(
self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
):
super().__init__(module)
self.node_map: Dict[Proxy, Node] = {}
self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
def run_node(self, n: Node) -> Any:
args, kwargs = self.fetch_args_kwargs_from_env(n)
def get_type(arg):
if isinstance(arg, fx.Node):
return n.meta["type"] if "type" in n.meta else None
return type(arg)
arg_types = map_aggregate(n.args, get_type)
assert isinstance(arg_types, tuple)
arg_types = tuple([create_type_hint(i) for i in arg_types])
kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
if n.op == "call_function":
out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
else:
out = super().run_node(n)
if n.op != "output":
self.node_map[out] = n
out.node.meta = n.meta
out.node.type = n.type
return out
def call_function(
self,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
arg_types: Optional[Tuple[Any, ...]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
):
assert callable(target)
new_args_and_kwargs = normalize_function(
target,
args, # type: ignore[arg-type]
kwargs,
arg_types, # type: ignore[arg-type]
kwarg_types,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return self.tracer.create_proxy(
"call_function", target, new_args, new_kwargs
)
else:
return super().call_function(target, args, kwargs)
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
assert isinstance(target, str)
new_args_and_kwargs = normalize_module(
self.module,
target,
args, # type: ignore[arg-type]
kwargs,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
new_args, new_kwargs = new_args_and_kwargs
return super().call_module(target, new_args, new_kwargs)
else:
return super().call_module(target, args, kwargs)
class NormalizeOperators(AnnotateTypesWithSchema):
"""
Normalize callsites that are different ways of "spelling" the same
invocation into a single, canonical call. Currently supports:
    1. Remapping binary `torch` functions (e.g. torch.add) to the
       equivalent operator calls they correspond to (e.g. operator.add)
       when they are invoked with exactly two arguments.
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = NormalizeOperators(traced).transform()
"""
binary_magic_method_remap: Dict[
Callable[[Any, Any], Any], Callable[[Any, Any], Any]
] = {
torch.add: operator.add,
torch.mul: operator.mul,
torch.sub: operator.sub,
torch.div: operator.truediv,
torch.floor_divide: operator.floordiv,
torch.remainder: operator.mod,
torch.eq: operator.eq,
torch.ne: operator.ne,
torch.lt: operator.lt,
torch.le: operator.le,
torch.gt: operator.gt,
torch.ge: operator.ge,
}
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
# Normalize operators according to the magic methods implemented on tensors here:
# https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
assert callable(target)
if target in self.binary_magic_method_remap:
if len(args) != 2:
return super().call_function(target, args, kwargs)
lhs, rhs = args
return super().call_function(
target=self.binary_magic_method_remap[target],
args=(lhs, rhs),
kwargs={},
)
return super().call_function(target, args, kwargs)
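

# --- Minimal usage sketch, not part of the original file ---
# Any traced module works; after NormalizeOperators, the torch.mul call site is
# rewritten to operator.mul (per binary_magic_method_remap above). Because of
# the relative import at the top, run this as a module (python -m), not a script.
if __name__ == "__main__":
    class _M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x, y):
            return self.lin(x) + torch.mul(x, y)

    traced = torch.fx.symbolic_trace(_M())
    traced = NormalizeOperators(traced).transform()
    print(traced.graph)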

View File

@ -0,0 +1,409 @@
# mypy: allow-untyped-defs
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum
def _parent_name(target : str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
if len(node.args) == 0:
return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node):
return False
if current_node.op != 'call_module':
return False
if not isinstance(current_node.target, str):
return False
if current_node.target not in modules:
return False
if type(modules[current_node.target]) is not expected_type:
return False
return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
modules[node.target] = new_module
setattr(modules[parent_name], name, new_module)
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
"""
Fuses convolution/BN layers for inference purposes. Will deepcopy your
model by default, but can modify the model inplace as well.
"""
patterns = [(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d)]
if not inplace:
model = copy.deepcopy(model)
if not no_trace or not isinstance(model, torch.fx.GraphModule):
fx_model = fx.symbolic_trace(model)
else:
fx_model = model
modules = dict(fx_model.named_modules())
new_graph = copy.deepcopy(fx_model.graph)
for pattern in patterns:
for node in new_graph.nodes:
if matches_module_pattern(pattern, node, modules):
if len(node.args[0].users) > 1: # Output of conv is used by other nodes
continue
conv = modules[node.args[0].target]
bn = modules[node.target]
if not bn.track_running_stats:
continue
fused_conv = fuse_conv_bn_eval(conv, bn)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])
new_graph.erase_node(node)
return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
"""
Removes all dropout layers from the module.
"""
fx_model = fx.symbolic_trace(model)
class DropoutRemover(torch.fx.Transformer):
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1
return args[0]
else:
return super().call_module(target, args, kwargs)
return DropoutRemover(fx_model).transform()
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {}
for input in inputs:
new_node = new_graph.placeholder(input.name)
env[input] = new_node
for node in nodes:
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
new_graph.output([env[output] for output in outputs])
new_graph.lint()
return fx.GraphModule(orig_module, new_graph)
mkldnn_supported = [
nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
torch.relu, torch.transpose, torch.sigmoid,
F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]
mkldnn_map = {
nn.Conv2d: th_mkldnn.MkldnnConv2d,
nn.Linear: th_mkldnn.MkldnnLinear,
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
}
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
"""
For each node, if it's a module that can be preconverted into MKLDNN,
then we do so and create a mapping to allow us to convert from the MKLDNN
version of the module to the original.
"""
old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == 'call_module':
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
replace_node_module(node, modules, new_module)
return old_modules
def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
"""
Maps each module that's been changed with `modules_to_mkldnn` back to its
original.
"""
for node in nodes:
if node.op == 'call_module':
assert (isinstance(node.target, str))
cur_module = modules[node.target]
if cur_module in old_modules:
replace_node_module(node, modules, old_modules[cur_module])
class MklSubgraph:
def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph
self.nodes: List[fx.Node] = []
self.start_nodes: List[fx.Node] = []
self.end_nodes: List[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
"""
This generates a heuristic that can be passed into `optimize_for_inference` that
determines whether a subgraph should be run in MKL by running it with the example_inputs.
Example usage:
heuristic = gen_mkl_autotuner(example_inputs, iters=10)
fast_model = optimization.optimize_for_inference(model, heuristic)
"""
fx_model = None
old_modules = None
def use_mkl_heuristic(graph: MklSubgraph) -> bool:
nonlocal fx_model, old_modules
input_nodes = graph.start_nodes
if fx_model is None:
fx_model = graph.fx_graph.owning_module
old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
ShapeProp(fx_model).propagate(example_inputs)
sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
def benchmark(f):
for _ in range(warmup):
f()
begin = time.time()
for _ in range(iters):
out = f()
return time.time() - begin
mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time
return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
"""
This is a heuristic that can be passed into `optimize_for_inference` that
determines whether a subgraph should be run in MKL by checking if there
are more than 2 nodes in it
"""
return len(graph.nodes) > 2
class UnionFind:
def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n
self.size: List[int] = [0] * n
def make_set(self, v: int):
self.parent[v] = v
self.size[v] = 1
def find(self, v: int) -> int:
par = self.parent[v]
if v == par:
return v
assert par is not None
self.parent[v] = self.find(par)
return cast(int, self.parent[v])
def join(self, a: int, b: int):
a, b = self.find(a), self.find(b)
if a == b:
return a
if self.size[a] < self.size[b]:
a, b = b, a
self.parent[b] = a
self.size[a] += self.size[b]
def optimize_for_inference(
model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
"""
Performs a set of optimization passes to optimize a model for the
purposes of inference. Specifically, the passes that are run are:
1. Conv/BN fusion
2. Dropout removal
3. MKL layout optimizations
The third optimization takes a function `use_mkl_heuristic` that's used
to determine whether a subgraph should be explicitly run in MKL layout.
Note: As FX does not currently handle aliasing, this pass currently
assumes nothing aliases. If that isn't true, use at your own risk.
"""
default_pass_config = {
"conv_bn_fuse": True,
"remove_dropout": True,
"mkldnn_layout_optimize": {'heuristic': use_mkl_length},
}
if pass_config is None:
pass_config = {}
default_pass_config.update(pass_config)
if default_pass_config["conv_bn_fuse"]:
model = fuse(model)
if default_pass_config["remove_dropout"]:
model = remove_dropout(model)
if default_pass_config["mkldnn_layout_optimize"] is False:
return model
if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
raise RuntimeError("mkldnn_layout_optimize config is not a dict")
if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model))
fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum):
NO = 1
YES = 2
UNKNOWN = 3
# Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
# If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
# However, if it's in `mkldnn_supported_unknown`, then we only treat it as
# a MKLDNN node if its inputs are MKLDNN nodes.
for node in list(fx_graph.nodes):
supports_mkldnn = MklSupport.NO
if node.op == 'call_module':
cur_module = modules[node.target]
if type(cur_module) in mkldnn_supported:
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function':
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES
elif node.target in mkldnn_supported_unknown:
supports_mkldnn = MklSupport.UNKNOWN
if supports_mkldnn != MklSupport.NO:
if supports_mkldnn == MklSupport.UNKNOWN:
if not any(arg.target == 'to_dense' for arg in node.args):
continue
with fx_graph.inserting_before(node):
mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
node.replace_all_uses_with(dense_x)
dense_x.args = (node,)
# Does pre-conversion of all modules into MKLDNN (when possible)
old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
fx_graph.old_modules = old_modules # type: ignore[attr-defined]
# optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
for node in fx_graph.nodes:
if node.op == 'call_method' and node.target == 'to_dense':
prv_node = node.args[0]
users = list(node.users)
for user in users:
if user.op == 'call_method' and user.target == 'to_mkldnn':
user.replace_all_uses_with(prv_node)
fx_graph.erase_node(user)
if len(node.users) == 0:
fx_graph.erase_node(node)
num_nodes = len(fx_graph.nodes)
uf = UnionFind(num_nodes)
def get_color(n):
if hasattr(n, 'color'): # Current node is part of a MKL subgraph
return uf.find(n.color)
if hasattr(n, 'start_color'): # Current node is input to MKL subgraph
return uf.find(n.start_color)
return None
# This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
# of input nodes (which are only `to_mkldnn` calls), output nodes
# (`to_dense` calls), and intermediate nodes, which are run entirely on
# MKLDNN layout tensors.
#
# Specifically, this code does a flood fill on a directed acyclic graph
# (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
# If every node only had one input, this would be sufficient. However, in
# the case that a node has multiple inputs coming from different start
# nodes (i.e. colors), we need to join these 2 colors into 1. That's done
# using a Disjoint Set Union.
for cur_idx, node in enumerate(fx_graph.nodes):
if node.op == 'call_method' and node.target == 'to_mkldnn':
node.start_color = cur_idx
uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense':
assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0])
else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
if len(cur_colors) == 0:
continue
assert not any(i is None for i in cur_colors)
cur_colors = sorted(cur_colors)
node.color = cur_colors[0]
for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color)
mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes:
if hasattr(node, 'color'):
mkldnn_graphs[uf.find(node.color)].nodes.append(node)
if hasattr(node, 'start_color'):
mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
if hasattr(node, 'end_color'):
mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
# Now that we have all the subgraphs, we need to decide which MKLDNN
# subgraphs we actually want to keep in MKLDNN.
for graph in mkldnn_graphs.values():
if not use_mkl_heuristic(graph):
for node in graph.start_nodes + graph.end_nodes:
prv = node.args[0]
node.replace_all_uses_with(prv)
fx_graph.erase_node(node)
reset_modules(graph.nodes, modules, old_modules)
mkldnn_conversions = 0
for node in fx_graph.nodes:
if node.target == 'to_mkldnn' or node.target == 'to_dense':
mkldnn_conversions += 1
logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
fx_graph.lint()
result = fx.GraphModule(model, fx_graph)
return result
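

# --- Minimal usage sketch, not part of the original file ---
# Conv/BN fusion requires eval() mode (fuse_conv_bn_eval needs frozen running
# stats); the MKLDNN layout pass is skipped here via pass_config since it also
# assumes float32 CPU modules and an MKLDNN-enabled build.
if __name__ == "__main__":
    class _SmallCNN(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.drop = nn.Dropout(0.5)

        def forward(self, x):
            return self.drop(torch.relu(self.bn(self.conv(x))))

    model = _SmallCNN().eval()
    fused = fuse(model)            # Conv2d + BatchNorm2d folded into one Conv2d
    slim = remove_dropout(fused)   # Dropout calls replaced by their input
    optimized = optimize_for_inference(model, pass_config={"mkldnn_layout_optimize": False})
    print(fused.graph, slim.graph, optimized.graph, sep="\n")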

View File

@ -0,0 +1,318 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import NamedTuple, Dict, List, Set
from torch.fx.node import Node, map_arg
class Partition:
"""Partition class contains all the information about an individual partition.
    It also provides the necessary methods for manipulating the partition.
"""
def __init__(self, partition_id: int) -> None:
self.nodes: Set[Node] = set()
self.partition_id = partition_id
self.parents: Set[Partition] = set()
self.children: Set[Partition] = set()
self.bfs_level: int = -1
self.used_mem_bytes: int = 0
self.logical_device_ids: List[int] = []
def __str__(self):
return str(self.partition_id)
def recalculate_mem_size(self):
self.used_mem_bytes = 0
for node in self.nodes:
self.used_mem_bytes += get_extra_size_of(node, self.nodes)
def add_node(self, node):
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Add current node's input nodes if they are placeholder or constants
for n in input_nodes:
if n.op in {"placeholder", "get_attr"}:
self.nodes.add(n)
self.nodes.add(node)
self.recalculate_mem_size()
def remove_node(self, node):
# Remove a node only if the node is in the partition
if node in self.nodes:
self.nodes.remove(node)
# Collect the node's input nodes
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
            # If an input node is a placeholder or get_attr and is not used
            # by any other node in this partition, remove it as well
for input_node in input_nodes:
if all(
n not in self.nodes for n in input_node.users
) and input_node.op in {"placeholder", "get_attr"}:
self.nodes.remove(input_node)
self.recalculate_mem_size()
class Device(NamedTuple):
name: str
available_mem_bytes: int
logical_id: int
class NodeLatency(NamedTuple):
# Latency due to the memory bandwidth
mem_latency_sec: float
# Latency due to the computation
computer_latency_sec: float
class PartitionLatency(NamedTuple):
# Sum of all nodes' memory latency on the critical path
mem_latency_sec: float
# Sum of all nodes' compute latency on the critical path
computer_latency_sec: float
# Latency of the critical path
overall_latency_sec: float
class PartitionMode(Enum):
size_based = 0
sparse_nn = 1
cost_aware = 2
kl_based = 3
aot_based = 4
class PartitionerConfig(NamedTuple):
devices: List[Device]
mode: PartitionMode = PartitionMode.size_based
transfer_rate_bytes_per_sec: float = 0.0
node_to_latency_mapping: Dict[Node, NodeLatency] = {}
node_to_partition_mapping: Dict[Node, int] = {}
partition_to_logical_device_mapping: Dict[int, List[int]] = {}
# Saturate host by replicating partitions to the remaining idle devices.
saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
"""Given a node and a set of nodes,
    this function returns the extra size that is needed
if this node is included in this set.
"""
# Find all its input nodes
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Calculate total size of related nodes
total_size_of_input_nodes = 0
for n in input_nodes:
# Make sure this node hasn't been in this set yet
if n not in nodes:
size_bytes = getattr(n, "size_bytes", None)
if size_bytes:
total_size_of_input_nodes += size_bytes.output_size
else:
raise RuntimeError("node has no size_bytes attr")
# Don't forget the op node itself
size_bytes = getattr(node, "size_bytes", None)
if size_bytes:
total_size_of_input_nodes += size_bytes.total_size
else:
raise RuntimeError("node has no size_bytes attr")
return total_size_of_input_nodes
def get_latency_of_one_partition(
partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
"""Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
def get_top_nodes(partition: Partition) -> List[Node]:
"""Given a partition, return a list of nodes on the top bfs level"""
top_nodes: List[Node] = []
for node in partition.nodes:
# Skip placeholder and get_attr nodes
if node.op in {"placeholder", "get_attr"}:
continue
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# If a node has no input nodes in this partition,
# or its input nodes in this partition are placeholders and get_attrs
# this node is on the top bfs level in this partition
if not any(
n in partition.nodes and n.op not in {"placeholder", "get_attr"}
for n in input_nodes
):
top_nodes.append(node)
return top_nodes
def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
"""Given a top node of a partition, this function returns
the latency of the critical path in the partition
"""
node_latency = node_to_latency_mapping[node]
# Calculate the current overall latency of the partition
overall_latency_sec = partition_latency.overall_latency_sec + max(
node_latency.computer_latency_sec, node_latency.mem_latency_sec
)
# Update the mem latency of this path
mem_latency_sec = (
partition_latency.mem_latency_sec + node_latency.mem_latency_sec
)
# Update the compute latency of this path
computer_latency_sec = (
partition_latency.computer_latency_sec + node_latency.computer_latency_sec
)
# Get all users of this node that are in this partition
users = set(node.users).intersection(partition.nodes)
if users:
max_latency = PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
)
for n in users:
# Get new partition latency recursively
new_partition_latency = dfs_helper(
n,
PartitionLatency(
mem_latency_sec, computer_latency_sec, overall_latency_sec
),
)
if (
new_partition_latency.overall_latency_sec
> max_latency.overall_latency_sec
):
max_latency = new_partition_latency
return max_latency
# If there is no user, the node is at bottom of the partition
return PartitionLatency(
mem_latency_sec, computer_latency_sec, overall_latency_sec
)
# Main part starts
# Get all top level nodes of this partition
top_nodes = get_top_nodes(partition)
critical_path_latency = PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
)
    # Go through all top nodes and find the largest latency (critical path latency)
for node in top_nodes:
partition_latency = dfs_helper(
node,
PartitionLatency(
mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
),
)
if (
partition_latency.overall_latency_sec
> critical_path_latency.overall_latency_sec
):
critical_path_latency = partition_latency
return critical_path_latency
def get_partition_to_latency_mapping(
partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
"""Given all the partitions and node_to_latency_mapping dictionary,
return a mapping dictionary of each partition to its overall latency
"""
partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
# Go through each partition and get its latency
for partition in partitions:
partition_latency = get_latency_of_one_partition(
partition, node_to_latency_mapping
)
partition_to_latency_mapping[partition] = partition_latency
return partition_to_latency_mapping
def get_comm_latency_between(
parent_partition: Partition,
child_partition: Partition,
transfer_rate_bytes_per_sec: float,
):
"""Given two partitions (parent and child),
calculate the communication latency between the two.
"""
# If two partitions are on the same device, the comm latency is 0.
if (
parent_partition.logical_device_ids != []
and child_partition.logical_device_ids != []
and parent_partition.logical_device_ids == child_partition.logical_device_ids
):
return 0.0
# Keep tracking the communication size between parent and child
comm_size = 0
# Keep tracking all the counted node
visited_nodes = set()
# Go through all nodes in the child partition
# If a node has input nodes from the parent partition,
# the output size of those input nodes will be counted
# and added to comm_size
for node in child_partition.nodes:
input_nodes: Dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
for n in input_nodes:
if n in parent_partition.nodes and n not in visited_nodes:
size_bytes = getattr(n, "size_bytes", None)
if size_bytes is not None:
comm_size += size_bytes.output_size
visited_nodes.add(n)
return comm_size / transfer_rate_bytes_per_sec
def get_latency_of_partitioned_graph(
partitions: List[Partition],
partition_to_latency_mapping: Dict[Partition, PartitionLatency],
transfer_rate_bytes_per_sec: float,
):
"""Given all partitions in a graph, find the critical path among all partitions
and return its latency as the latency of the whole graph
"""
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
latency_so_far_sec += partition_to_latency_mapping[
partition
].overall_latency_sec
children = partition.children
if partition.children:
max_latency_sec = 0.0
for child in partition.children:
# Calculate latency between
comm_latency_sec = get_comm_latency_between(
partition, child, transfer_rate_bytes_per_sec
)
new_latency_sec = dfs_helper(
child, latency_so_far_sec + comm_latency_sec
)
if new_latency_sec > max_latency_sec:
max_latency_sec = new_latency_sec
return max_latency_sec
return latency_so_far_sec
def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
"""This function is to return all the partitions without parents
as the starting points of all the paths
"""
top_partitions = []
for partition in partitions:
# If a partition has no parents, then it is a top partition
if len(partition.parents) == 0:
top_partitions.append(partition)
return top_partitions
top_partitions = get_top_partitions(partitions)
critical_path_latency_sec = 0.0
for partition in top_partitions:
latency_sec = dfs_helper(partition, 0.0)
if latency_sec > critical_path_latency_sec:
critical_path_latency_sec = latency_sec
return critical_path_latency_sec
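

# --- Minimal usage sketch, not part of the original file ---
# The device names and numbers below are made up; they only show how the
# NamedTuples above are filled in before handing them to a partitioner.
if __name__ == "__main__":
    devices = [
        Device(name="dev_0", available_mem_bytes=2 * 1024 ** 3, logical_id=0),
        Device(name="dev_1", available_mem_bytes=2 * 1024 ** 3, logical_id=1),
    ]
    config = PartitionerConfig(
        devices=devices,
        mode=PartitionMode.sparse_nn,
        transfer_rate_bytes_per_sec=1e9,
    )
    latency = NodeLatency(mem_latency_sec=1e-4, computer_latency_sec=2e-4)
    print(config.mode, latency)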

File diff suppressed because it is too large

View File

@ -0,0 +1,512 @@
# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
log = logging.getLogger(__name__)
trace_shape_events_log = torch._logging.getArtifactLogger(
__name__, "trace_shape_events"
)
__all__ = [
"ShapeEnvEvent",
"record_shapeenv_event",
"replay_shape_env_events",
"FakeTensorMeta",
"shape_env_check_state_equal",
"NotEqualError",
]
# [Note: Recording ShapeEnv Events]
# =================================
#
# What is a ShapeEnv event?
# -------------------------
# We consider a ShapeEnv event every function call (ShapeEnv method or
# independent function) that modifies the state of the ShapeEnv instance.
# Such calls are recorded alongside their positional and keyword arguments,
# so that they may be replayed over a different ShapeEnv instance.
#
# See [Note: ShapeEnv State Equality] for what is considered the state
# of a ShapeEnv instance.
#
# What is it for?
# ---------------
# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
# arbitrary state in time.
#
# Being able to replay events arbitrarily like this is useful, mainly for
# translation validation bisection, i.e. if a ValidationException has been
# raised, finding the earliest point in time where translation validation
# fails.
#
# Besides that, it also allows us to inspect the given instance and,
# for example, check the guards that would actually be issued at that point.
#
# What kind of arguments can be stored in an event?
# -------------------------------------------------
# There's no specific rule for what cannot be used as an argument.
# That said, pay special attention to the following cases:
#
# 1. Tensor inputs: there are some tests that check whether the inputs
# were garbage collected after execution. These will fail if there's
# an event that is holding a reference to those inputs.
#
# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
# will be automatically replaced by the new given ShapeEnv instance.
#
# 3. SymTypes arguments: they also hold references to ShapeEnv. So,
# whenever we see them, we create a new instance, replacing the
# ShapeEnv reference.
#
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
# shapes. That argument must be replaced when replaying the event at
# ShapeEnvEvent.run, since it has to reference a node from the given
# instance, and not from the recorded instance.
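#
# A quick sketch of the intended workflow (illustrative only; the relevant
# helpers are defined below in this file):
#
#   env = ShapeEnv()                               # the constructor is event 0
#   ...                                            # symbols created and guards
#                                                  # evaluated; each mutation is
#                                                  # appended to env.events
#   new_env = replay_shape_env_events(env.events)  # rebuild an equivalent env
#   # new_env should now issue the same guards as env.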
# Event class for reconstructing ShapeEnv at an arbitrary point in time.
#
# Represents a method call that mutates ShapeEnv in a way that affects the
# issued guards, when ShapeEnv.produce_guards is called.
@dataclass
class ShapeEnvEvent:
# ShapeEnv method.
f: Callable
# Arguments and keyword arguments called with.
args: Optional[List[Any]] = None
kwargs: Optional[Dict[str, Any]] = None
# List of tracked_fakes at the time the method was called.
tracked_fakes: Optional[List[Any]] = None
# Name of the captured event.
# Used for special handling of particular methods.
name: Optional[str] = None
# Replay itself, but using shape_env as self.
def run(self, shape_env=None) -> Any:
from torch.fx.experimental.symbolic_shapes import (
is_symbolic,
ShapeEnv,
SymTypes,
)
# Special handling for the constructor event.
if self.f is ShapeEnv:
assert shape_env is None and self.args is None and self.kwargs is not None
return ShapeEnv(**self.kwargs)
assert shape_env is not None
args = list(self.args or [])
kwargs = dict(self.kwargs or {})
# Replace any argument of type ShapeEnv by the given one.
args, kwargs = pytree.tree_map_only(
ShapeEnv, lambda _: shape_env, (args, kwargs)
)
# Replace any argument of type SymTypes by a new instance,
# replacing its ShapeEnv reference.
args, kwargs = pytree.tree_map_only(
lambda x: isinstance(x, SymTypes) and is_symbolic(x),
lambda a: type(a)(a.node.with_shape_env(shape_env)),
(args, kwargs),
)
        # Converts FX nodes using the name-to-node mapping of the given shape_env.
def maybe_convert_node(x: Any) -> Any:
if not isinstance(x, torch.fx.Node):
# Don't do anything to x if it's not an FX node.
return x
# If, at some point, we created an FX node, it means that translation validation is on.
# It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
# we are tracking node names at shape_env.name_to_node.
assert hasattr(shape_env, "name_to_node")
name_to_node = shape_env.name_to_node # type: ignore[attr-defined]
assert x.name in name_to_node
return name_to_node[x.name]
        # Replaces the value of a specific argument by the result of fn.
def replacearg(index: int, key: str, fn: Callable):
if index < len(args):
args[index] = fn(args[index])
if key in kwargs:
kwargs[key] = fn(kwargs[key])
if self.is_create_fx_call_function():
# ShapeEnv.create_fx_call_function:
# "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
# They must be replaced, since a "call_function" FX node with this tuple as argument
# will be added to the FX graph of the new shape_env.
replacearg(
index=2,
key="args",
fn=lambda args: tuple(maybe_convert_node(a) for a in args),
)
if self.is_evaluate_expr() or self.is_defer_runtime_assert():
# ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
# "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
            # It must be replaced, since it will be part of a "call_function" FX node for
# torch._assert, which will be added to the FX graph of the new shape_env.
replacearg(index=3, key="fx_node", fn=maybe_convert_node)
# Actually call the method with the converted arguments.
return self.f(*args, **kwargs)
def __str__(self) -> str:
name = self.name if self.name is not None else self.f.__name__
return f"event: {name} ({self.args}, {self.kwargs})"
def is_create_fx_call_function(self) -> bool:
return self.name == "_create_fx_call_function"
def is_evaluate_expr(self) -> bool:
return self.name == "evaluate_expr"
def is_defer_runtime_assert(self) -> bool:
return self.name == "defer_runtime_assert"
# Current nesting depth of recorded calls; used to indent trace_shape_events logs.
NEST = 0
# Extracts a ShapeEnv instance inside args and kwargs.
# Specifically, it looks for:
# 1. ShapeEnv arguments
# 2. SymInt, SymFloat, or SymBool arguments
# If we find more than one object of any of the above types, we
# also check that the ShapeEnv instance is the same for all of them.
def _extract_shape_env_and_assert_equal(args, kwargs):
from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
if old is not None:
assert old is new, "call with different ShapeEnv"
return new
shape_env = None
for val in itertools.chain(args, kwargs.values()):
if isinstance(val, ShapeEnv):
shape_env = assert_equal(shape_env, val)
if isinstance(val, SymTypes) and is_symbolic(val):
shape_env = assert_equal(shape_env, val.node.shape_env)
return shape_env
# Decorator for recording the given function as a replayable event.
#
# This decorator should be used at every function that mutates the state of
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
# ShapeEnv.produce_guards is called).
#
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
#
# When to save the list of TrackedFake?
# =====================================
# We should save the list of TrackedFake whenever the translation validation
# bisection may actually stop and call the produce_guards method at the moment
# right after the recorded function was played. In other words, since the
# bisection bisects through torch._assert calls, we should save it in all methods
# that add a torch._assert call to the symbolic shapes FX graph.
#
# At the moment, there are 2 methods that save the list:
# - ShapeEnv.evaluate_expr
# - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
def decorator(fn: Callable) -> Callable:
assert callable(fn)
args = inspect.getfullargspec(fn).args
assert args and args[0] == "self", (
"record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
"code so that it calls into a method on ShapeEnv"
)
name = fn.__name__
@functools.wraps(fn)
def wrapper(*args, **kwargs):
from torch.fx.experimental.symbolic_shapes import ShapeEnv
assert isinstance(args[0], ShapeEnv)
global NEST
trace_shape_events_log.debug(
"%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
)
NEST += 1
def retlog(r):
trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
return r
try:
if args[0].is_recording: # type: ignore[has-type]
# If ShapeEnv is already recording an event, call the wrapped
# function directly.
#
# NB: here, we skip the check of whether all ShapeEnv instances
# are equal, in favor of a faster dispatch.
return retlog(fn(*args, **kwargs))
# Retrieve an instance of ShapeEnv.
# Assumption: the collection of args and kwargs may not reference
# different ShapeEnv instances.
self = _extract_shape_env_and_assert_equal(args, kwargs)
# If we are calling this function without any ShapeEnv instance
# alive in its arguments, we don't record and call the original.
if self is None:
return retlog(fn(*args, **kwargs))
# Otherwise, start recording and call the function.
with self._recording():
# Take a snapshot of the current tracked_fakes.
tracked_fakes = (
self._snapshot_tracked_fakes() if save_tracked_fakes else None
)
# Record the event for 'fn'.
event = ShapeEnvEvent(
fn, list(args), kwargs, tracked_fakes, name=fn.__name__
)
# Play the event on this ShapeEnv.
# NB: It's important to put the event first, because running
# the event can trigger internal events that must be ordered
# after this event. However, if an exception happens, we do
# NOT want to have the event in the list, so pop it off from
# the record if an error happened
self.events.append(event)
try:
return retlog(event.run(self))
except Exception:
self.events.pop()
raise
except Exception:
log.error( # noqa: G201
"failed while running %s(*%s, **%s)",
name,
args[1:],
kwargs,
exc_info=log.isEnabledFor(logging.INFO),
)
raise
finally:
NEST -= 1
return wrapper
return decorator
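# Hedged illustration (hypothetical signature, not defined in this file): a
# state-mutating ShapeEnv method would typically be wrapped like this so that
# every call becomes a replayable event:
#
#   class ShapeEnv:
#       @record_shapeenv_event(save_tracked_fakes=True)
#       def evaluate_expr(self, orig_expr, hint=None, fx_node=None):
#           ...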
# Replays the ShapeEnvEvents list.
# It assumes the first event is the constructor call.
#
# FX nodes stored in the recorded events are remapped onto the graph of the
# newly created ShapeEnv as each event is run (see ShapeEnvEvent.run).
def replay_shape_env_events(events):
from torch.fx.experimental.symbolic_shapes import ShapeEnv
constructor_event = events[0]
assert constructor_event.f == ShapeEnv
# Constructs the new ShapeEnv.
shape_env = constructor_event.run()
for event in events[1:]:
try:
            # Actually replay each event. Any FX nodes it references are
            # remapped inside ShapeEnvEvent.run, since the node list may
            # change after each event is replayed.
event.run(shape_env)
except Exception as e:
log.error("failed when running event: %s", event)
raise
return shape_env
# FakeTensor metadata.
# This is to be used in place of FakeTensor placeholders when calling
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
tensor_size: Tuple[Union[int, torch.SymInt], ...]
tensor_stride: Tuple[Union[int, torch.SymInt], ...]
tensor_storage_offset: Union[int, torch.SymInt]
is_nested: bool
def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
return self.tensor_size
def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
return self.tensor_stride
def storage_offset(self) -> Union[int, torch.SymInt]:
return self.tensor_storage_offset
def dim(self) -> int:
return len(self.tensor_size)
@staticmethod
def from_fake(fake) -> "FakeTensorMeta":
return FakeTensorMeta(
fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
)
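# Hedged example (not part of the original API surface): FakeTensorMeta can be
# built from any tensor-like object exposing size/stride/storage_offset/is_nested,
# e.g. a plain torch.Tensor, when only this metadata is needed.
def _example_fake_tensor_meta() -> "FakeTensorMeta":
    t = torch.empty(2, 3)
    meta = FakeTensorMeta.from_fake(t)
    assert meta.dim() == 2 and tuple(meta.size()) == (2, 3)
    return meta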
# [Note: ShapeEnv State Equality]
# ===============================
#
# What is considered ShapeEnv state?
# ----------------------------------
# We consider the state of a ShapeEnv instance to be everything that is
# not listed in the inline tuple inside the remove_nonstate_variables
# function. That is: the fields within ShapeEnv that modify the flow of
# execution of the program.
#
# So, for example: the replacements field might influence how an
# expression is simplified. That, in turn, may result in a guard being
# statically known (i.e. not added).
#
# On the other hand, var_to_stack only changes what is printed on the
# screen, i.e. it is used only for debugging purposes. Therefore, we
# should not consider it when comparing states.
#
# What to do on NotEqualError?
# ----------------------------
# Here are a few possible causes for getting a NotEqualError raised:
#
# 1. New field that does not belong in the ShapeEnv state.
# For example: log field of type ShapeEnvLoggerAdapter. Different
# ShapeEnv instances will always have different ShapeEnvLoggerAdapter
# instances, i.e. equality comparison would fail.
# Solution: add it to the inlined tuple inside remove_nonstate_variables
# function inside check_equal method.
#
# 2. New field that is not directly comparable across instances.
# For example: guards field of type List[ShapeGuard]. More specifically,
# the ShapeGuard type holds an expression and stack information
# for debugging purposes. When replaying the event on a new ShapeEnv
# instance, the stack would be different, which would trigger this error.
# Solution: add a special case to the map_value function inside
# check_equal function.
#
# 3. Mutation of ShapeEnv inside a function that is not recorded.
# If a mutation of the state of ShapeEnv happens inside a function
# that is not recorded (and no caller in the stack is recorded either),
# then the replayed ShapeEnv won't capture that mutation.
# Solution: decorate the function with record_shapeenv_event.
# Checks whether the states of two ShapeEnvs are equal w.r.t. the guards
# returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
# Collect and remove variables that don't necessarily represent the state
# of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
# instance itself.
env1_vars = vars(env1).copy()
env2_vars = vars(env2).copy()
for v in non_state_variable_names:
if v in env1_vars:
env1_vars.pop(v)
if v in env2_vars:
env2_vars.pop(v)
    # Function for transforming the mismatched values into strings.
    # Needed, since dict and set entry order might not be the same every time.
def value_to_str(value: Any) -> str:
if isinstance(value, dict):
return (
"{"
+ ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
+ "}"
)
if isinstance(value, set):
return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
return str(value)
# Compares env1_vars with env2_vars.
# Here, we allow the value of each field to be mapped, so that we appropriately
# compare the two values.
def compare_vars(
map_value: Callable[[str, Any], Any]
) -> List[Tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars)
# First, compare the set of keys in each vars dictionary.
if env1_set != env2_set:
raise NotEqualError(
"field set mismatch:",
[
(
"found unique fields:",
str(sorted(env1_set - env2_set)),
str(sorted(env2_set - env1_set)),
),
],
)
# Then, sort the keys, and compare the mapped values of each key.
sorted_keys = list(env1_set)
sorted_keys.sort()
mapped_dict = [
(k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
for k in sorted_keys
]
# Return a list of tuples representing the fields that did not match
# alongside their respective mapped values.
return [
(f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
for k, val1, val2 in mapped_dict
if val1 != val2
]
# Accumulate the mismatching fields.
errors = compare_vars(map_value)
if len(errors) > 0:
raise NotEqualError("field values don't match:", errors)
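# Hedged usage sketch: the field names and the identity map_value below are
# illustrative, not the values used elsewhere in PyTorch.
def _example_shape_env_state_check(env1, env2) -> None:
    shape_env_check_state_equal(
        env1,
        env2,
        # Assumed debug-only fields to ignore; adjust to the real non-state set.
        non_state_variable_names=("log", "var_to_stack"),
        # Identity mapping: compare raw field values.
        map_value=lambda field, value: value,
    )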
class NotEqualError(Exception):
def __init__(
self,
msg: str,
mismatched: List[Tuple[str, str, str]],
) -> None:
details = "\n".join(
[
"\n".join(
[
f"==> {inner_msg}",
f" > Left: {str1}",
f" > Right: {str2}",
]
)
for inner_msg, str1, str2 in mismatched
]
)
super().__init__(
f"""\
ShapeEnv not equal: {msg}
{details}
"""
)

View File

@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
class Equality:
def __init__(self, lhs, rhs):
self.lhs = lhs
self.rhs = rhs
def __str__(self):
return f'{self.lhs} = {self.rhs}'
def __repr__(self):
return f'{self.lhs} = {self.rhs}'
def __eq__(self, other):
if isinstance(other, Equality):
return self.lhs == other.lhs and self.rhs == other.rhs
else:
return False

View File

@ -0,0 +1,128 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import copy
import functools
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):
"""
Take a FunctionType object representing a `forward` method, then
perform an AST rewrite to swap out nodes that are not symbolically
traceable with a callsite to the FX alternative.
To support swapping out an AST node, define a new `visit` method on
that node. For more details, see:
https://docs.python.org/3/library/ast.html#ast.NodeTransformer
"""
# This function checks for new keys added in the globals dict. TorchDynamo
# can insert new keys in the global dict and upset the check. Therefore, put
    # a disable here. This function is an optimization pass and not really
    # suitable for dynamo tracing anyway.
@torch._dynamo.disable
def rewrite(self, fn: FunctionType):
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
normalized_str = textwrap.dedent(source)
# Rewrite the original AST
source_ast = ast.parse(normalized_str)
dest_ast = ast.fix_missing_locations(self.visit(source_ast))
# Pull out the compiled function from the newly-created Module
code = compile(dest_ast, "", "exec")
globals_dict = copy.copy(fn.__globals__)
keys_before = set(globals_dict.keys())
exec(code, globals_dict)
new_keys = list(set(globals_dict.keys()) - keys_before)
assert len(new_keys) == 1
fn_compiled = globals_dict[new_keys[0]]
# return the compiled function with the original globals
def change_func_globals(f, globals):
"""Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
            # __globals__ is a read-only attribute of the function object,
            # so we have to copy the function f and all of its members, except f.__globals__
g = FunctionType(
f.__code__,
globals,
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__,
)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
return g
# Return the correct FunctionType object
return change_func_globals(fn_compiled, globals=fn.__globals__)
def visit_Assert(self, node):
"""
Swap out the Assert node (Python's `assert`) with a callsite to the
symbolically-traceable torch._assert function
"""
# Create the Call node
n = ast.parse('torch._assert()', mode='eval')
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
msg = node.msg if node.msg else ast.Constant(value="", kind=None)
call_node.args = [node.test, msg]
# Ensure that the new node conforms to the Python AST grammar
expr_wrapper = ast.Expr(value=call_node)
# Return the new Call node to signify that we want to use it as
# a replacement for the original _assert node
return ast.copy_location(expr_wrapper, node)
def visit_AnnAssign(self, node):
"""
Swap out Python's AnnAssign with an Assign node where the annotation function is called.
Example:
Original:
y: Tensor_Type(1,2,3, Dyn) = f2(x)
Output:
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
"""
return ast.Assign(targets=[node.target], value=ast.Call(
func=ast.Name(id='annotate', ctx=ast.Load()),
args=[node.value, node.annotation], keywords=[]))
class RewritingTracer(Tracer):
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
return super().trace(_rewrite(root), concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
if isinstance(fn, torch.nn.Module):
# Rewrite this module's `forward` as well as the `forward`s of
# all of this module's recursive descendents. Return the new,
# rewritten module hierarchy.
def rewrite_module(m : torch.nn.Module):
class RewrittenModule(torch.nn.Module):
def __init__(self, orig):
super().__init__()
for k, v in orig.__dict__.items():
if isinstance(v, torch.nn.Module):
self.__dict__[k] = copy.copy(rewrite_module(v))
else:
self.__dict__[k] = copy.copy(v)
RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
return RewrittenModule(m)
return rewrite_module(fn)
else:
# Rewrite this single free function
return AST_Rewriter().rewrite(cast(FunctionType, fn))
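# Hedged usage sketch: tracing a module whose forward contains a plain `assert`;
# the rewriter swaps it for torch._assert so symbolic tracing can proceed.
def _example_rewriting_trace():
    from torch.fx import GraphModule

    class M(torch.nn.Module):
        def forward(self, x):
            assert x.shape[0] > 0, "batch must be non-empty"
            return x + 1

    mod = M()
    graph = RewritingTracer().trace(mod)
    return GraphModule(mod, graph)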

View File

@ -0,0 +1,112 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target
from torch._jit_internal import boolean_dispatched
from torch.fx.operator_schemas import _torchscript_type_to_python_type
from torch.fx import Transformer
class AnnotateTypesWithSchema(Transformer):
"""
Use Python function signatures to annotate types for `Nodes` within an FX graph.
This pulls out Python function signatures for:
1. Standard `torch.nn` Module calls
2. `torch.nn.functional` calls
3. Attribute fetches via `get_attr`
Example usage:
m = torchvision.models.resnet18()
traced = torch.fx.symbolic_trace(m)
traced = AnnotateTypesWithSchema(traced).transform()
"""
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
annotate_modules : bool = True, annotate_get_attrs : bool = True):
super().__init__(module)
self.annotate_functionals = annotate_functionals
self.annotate_modules = annotate_modules
self.annotate_get_attrs = annotate_get_attrs
def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
# a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
# branches of the dispatch have exactly the same signature. If they do, use the `true`
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
# TODO: can we emit the union of these? What are the implications on TorchScript
# compilation?
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
return super().call_function(target, args, kwargs)
target_for_analysis = if_true
python_ret_type = self._extract_python_return_type(target_for_analysis)
return_proxy = super().call_function(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
python_ret_type = None
assert isinstance(target, str)
submod = self.fetch_attr(target)
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
python_ret_type = self._extract_python_return_type(submod.forward)
return_proxy = super().call_module(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return return_proxy
def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
attr_proxy = super().get_attr(target, args, kwargs)
if self.annotate_get_attrs:
module_itr = self.module
assert isinstance(target, str)
atoms = target.split('.')
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
                    raise RuntimeError(f'Node referenced nonexistent target {".".join(atoms[:i])}!')
module_itr = getattr(module_itr, atom)
maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
if maybe_inferred_ts_type.success():
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
return attr_proxy
def _extract_python_return_type(self, target : Target) -> Optional[Any]:
"""
Given a Python call target, try to extract the Python return annotation
if it is available, otherwise return None
Args:
target (Callable): Python callable to get return annotation for
Returns:
Optional[Any]: Return annotation from the `target`, or None if it was
not available.
"""
assert callable(target)
try:
sig = inspect.signature(target)
except (ValueError, TypeError):
return None
return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,4 @@
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403

View File

@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
from collections.abc import Iterator # type: ignore[import]
from functools import partial
from .unification_tools import assoc # type: ignore[import]
from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch
__all__ = ["reify", "unify"]
###############
# Reification #
###############
@dispatch(Iterator, dict)
def _reify(t, s):
return map(partial(reify, s=s), t)
# return (reify(arg, s) for arg in t)
_reify
@dispatch(tuple, dict) # type: ignore[no-redef]
def _reify(t, s):
return tuple(reify(iter(t), s))
_reify
@dispatch(list, dict) # type: ignore[no-redef]
def _reify(t, s):
return list(reify(iter(t), s))
_reify
@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return {k: reify(v, s) for k, v in d.items()}
_reify
@dispatch(object, dict) # type: ignore[no-redef]
def _reify(o, s):
return o # catch all, just return the object
def reify(e, s):
""" Replace variables of expression with substitution
>>> # xdoctest: +SKIP
>>> x, y = var(), var()
>>> e = (1, x, (3, y))
>>> s = {x: 2, y: 4}
>>> reify(e, s)
(1, 2, (3, 4))
>>> e = {1: x, 3: (y, 5)}
>>> reify(e, s)
{1: 2, 3: (4, 5)}
"""
if isvar(e):
return reify(s[e], s) if e in s else e
return _reify(e, s)
###############
# Unification #
###############
seq = tuple, list, Iterator
@dispatch(seq, seq, dict)
def _unify(u, v, s):
if len(u) != len(v):
return False
for uu, vv in zip(u, v): # avoiding recursion
s = unify(uu, vv, s)
if s is False:
return False
return s
#
# @dispatch((set, frozenset), (set, frozenset), dict)
# def _unify(u, v, s):
# i = u & v
# u = u - i
# v = v - i
# return _unify(sorted(u), sorted(v), s)
#
#
# @dispatch(dict, dict, dict)
# def _unify(u, v, s):
# if len(u) != len(v):
# return False
# for key, uval in iteritems(u):
# if key not in v:
# return False
# s = unify(uval, v[key], s)
# if s is False:
# return False
# return s
#
#
# @dispatch(object, object, dict)
# def _unify(u, v, s):
# return False # catch all
@dispatch(object, object, dict)
def unify(u, v, s): # no check at the moment
""" Find substitution so that u == v while satisfying s
>>> x = var('x')
>>> unify((1, x), (1, 2), {})
{~x: 2}
"""
u = walk(u, s)
v = walk(v, s)
if u == v:
return s
if isvar(u):
return assoc(s, u, v)
if isvar(v):
return assoc(s, v, u)
return _unify(u, v, s)
unify
@dispatch(object, object) # type: ignore[no-redef]
def unify(u, v):
return unify(u, v, {})
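# Hedged sketch tying the two halves together: unify a pattern against a
# concrete term, then reify another pattern with the resulting substitution.
def _example_unify_then_reify():
    from .variable import var  # logic variables live in the sibling module
    x = var("x")
    s = unify((1, x), (1, 2), {})   # -> {~x: 2}
    return reify((x, (3, x)), s)    # -> (2, (3, 2))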

View File

@ -0,0 +1,6 @@
from functools import partial
from .multipledispatch import dispatch # type: ignore[import]
namespace = {} # type: ignore[var-annotated]
dispatch = partial(dispatch, namespace=namespace)

View File

@ -0,0 +1,122 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
class Dispatcher:
def __init__(self, name):
self.name = name
self.funcs = {}
self.ordering = []
def add(self, signature, func):
self.funcs[freeze(signature)] = func
self.ordering = ordering(self.funcs)
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
return func(*args, **kwargs)
def resolve(self, args):
n = len(args)
for signature in self.ordering:
if len(signature) != n:
continue
s = unify(freeze(args), signature)
if s is not False:
result = self.funcs[signature]
return result, s
raise NotImplementedError("No match found. \nKnown matches: "
+ str(self.ordering) + "\nInput: " + str(args))
def register(self, *signature):
def _(func):
self.add(signature, func)
return self
return _
class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names
>>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x')
>>> @d.register('inc', x)
... def f(x):
... return x + 1
>>> @d.register('double', x)
... def f(x):
... return x * 2
>>> d('inc', 10)
11
>>> d('double', 10)
20
"""
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = {k.token: v for k, v in s.items()}
return func(**d)
global_namespace = {} # type: ignore[var-annotated]
def match(*signature, **kwargs):
namespace = kwargs.get('namespace', global_namespace)
dispatcher = kwargs.get('Dispatcher', Dispatcher)
def _(func):
name = func.__name__
if name not in namespace:
namespace[name] = dispatcher(name)
d = namespace[name]
d.add(signature, func)
return d
return _
def supercedes(a, b):
""" ``a`` is a more specific match than ``b`` """
if isvar(b) and not isvar(a):
return True
s = unify(a, b)
if s is False:
return False
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
if reify(a, s) == a:
return True
if reify(b, s) == b:
return False
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
if supercedes(b, a):
return tie_breaker(a) > tie_breaker(b)
else:
return True
return False
# Taken from multipledispatch
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(first, edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
return _toposort(edges)
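# Hedged sketch (illustrative names): `match` registers pattern-dispatched
# implementations in a shared namespace; logic variables act as wildcards.
def _example_match():
    from .variable import var
    x = var("x")
    @match("double", x)
    def rule(*args):  # the plain Dispatcher forwards positional args as-is
        _tag, value = args
        return value * 2
    return rule("double", 21)  # -> 42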

View File

@ -0,0 +1,118 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
>>> # xdoctest: +SKIP
>>> class A(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var('x')
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
{~x: 2}
"""
_unify.add((cls, cls, dict), unify_object)
_reify.add((cls, dict), reify_object)
return cls
#########
# Reify #
#########
def reify_object(o, s):
""" Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> print(f)
Foo(1, ~x)
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
def _reify_object_dict(o, s):
obj = object.__new__(type(o))
d = reify(o.__dict__, s)
if d == o.__dict__:
return o
obj.__dict__.update(d)
return obj
def _reify_object_slots(o, s):
attrs = [getattr(o, attr) for attr in o.__slots__]
new_attrs = reify(attrs, s)
if attrs == new_attrs:
return o
else:
newobj = object.__new__(type(o))
for slot, attr in zip(o.__slots__, new_attrs):
setattr(newobj, slot, attr)
return newobj
@dispatch(slice, dict)
def _reify(o, s):
""" Reify a Python ``slice`` object """
return slice(*reify((o.start, o.stop, o.step), s))
#########
# Unify #
#########
def unify_object(u, v, s):
""" Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
{~x: 2}
"""
if type(u) != type(v):
return False
if hasattr(u, '__slots__'):
return unify([getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s)
else:
return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
""" Unify a Python ``slice`` object """
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)

View File

@ -0,0 +1,3 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)

View File

@ -0,0 +1,121 @@
# mypy: allow-untyped-defs
from .utils import _toposort, groupby
from .variadic import isvariadic
import operator
__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
"edge", "ordering"]
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
elif len(a) == len(b):
return all(map(issubclass, a, b))
else:
# len(a) > len(b)
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not (isvariadic(cur_a) or isvariadic(cur_b)):
if not issubclass(cur_a, cur_b):
return False
p1 += 1
p2 += 1
elif isvariadic(cur_a):
assert p1 == len(a) - 1
return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
elif isvariadic(cur_b):
assert p2 == len(b) - 1
if not issubclass(cur_a, cur_b):
return False
p1 += 1
return p2 == len(b) - 1 and p1 == len(a)
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
# Need to check for empty args
if not a:
return not b or isvariadic(b[0])
if not b:
return not a or isvariadic(a[0])
# Non-empty args check for mutual subclasses
if len(a) == len(b):
return all(issubclass(aa, bb) or issubclass(bb, aa)
for aa, bb in zip(a, b))
else:
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
return False
if not (isvariadic(cur_a) or isvariadic(cur_b)):
p1 += 1
p2 += 1
elif isvariadic(cur_a):
p2 += 1
elif isvariadic(cur_b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return {(a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures)}
def super_signature(signatures):
""" A signature that would break ambiguities """
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(operator.itemgetter(0), edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined]
return _toposort(edges)
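# Hedged sketch: how these predicates rank concrete signatures.
def _example_signature_ordering():
    assert supercedes((int,), (object,))        # int is strictly more specific
    assert not ambiguous((int,), (object,))     # one side wins, so no ambiguity
    # `ordering` puts the more specific signature first.
    return ordering([(int,), (object,)])        # -> [(int,), (object,)]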

View File

@ -0,0 +1,84 @@
# mypy: allow-untyped-defs
import inspect
import sys
from .dispatcher import Dispatcher, MethodDispatcher
global_namespace = {} # type: ignore[var-annotated]
__all__ = ["dispatch", "ismethod"]
def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
    If ambiguous type signatures occur, a warning is raised when the function is
    defined, suggesting an additional method to break the ambiguity.
Example:
>>> # xdoctest: +SKIP
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> # xdoctest: +SKIP
>>> f(3)
4
>>> f(3.0)
2.0
>>> # Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = {}
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
>>> # Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
>>> MyClass([1, 2, 3]).data
[1, 2, 3]
>>> MyClass(3).data
[3]
"""
namespace = kwargs.get('namespace', global_namespace)
types = tuple(types)
def _df(func):
name = func.__name__
if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr]
name, # type: ignore[union-attr]
MethodDispatcher(name),
)
else:
if name not in namespace:
namespace[name] = Dispatcher(name)
dispatcher = namespace[name]
dispatcher.add(types, func)
return dispatcher
return _df
def ismethod(func):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if hasattr(inspect, "signature"):
signature = inspect.signature(func)
return signature.parameters.get('self', None) is not None
else:
if sys.version_info.major < 3:
spec = inspect.getargspec(func) # type: ignore[attr-defined]
else:
spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment]
return spec and spec.args and spec.args[0] == 'self'

View File

@ -0,0 +1,427 @@
# mypy: allow-untyped-defs
from warnings import warn
import inspect
from typing_extensions import deprecated
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
from .variadic import Variadic, isvariadic
import itertools as itl
__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
"variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
@deprecated(
"`halt_ordering` is deprecated, you can safely remove this call.",
category=FutureWarning,
)
def halt_ordering():
"""Deprecated interface to temporarily disable ordering."""
@deprecated(
"`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
"you should call the `reorder()` method on each dispatcher.",
category=FutureWarning,
)
def restart_ordering(on_ambiguity=ambiguity_warn):
"""Deprecated interface to temporarily resume ordering."""
def variadic_signature_matches_iter(types, full_signature):
"""Check if a set of input types matches a variadic signature.
Notes
-----
The algorithm is as follows:
Initialize the current signature to the first in the sequence
For each type in `types`:
If the current signature is variadic
If the type matches the signature
yield True
Else
Try to get the next signature
If no signatures are left we can't possibly have a match
so yield False
Else
yield True if the type matches the current signature
Get the next signature
"""
sigiter = iter(full_signature)
sig = next(sigiter)
for typ in types:
matches = issubclass(typ, sig)
yield matches
if not isvariadic(sig):
# we're not matching a variadic argument, so move to the next
# element in the signature
sig = next(sigiter)
else:
try:
sig = next(sigiter)
except StopIteration:
assert isvariadic(sig)
yield True
else:
                # We have signature items left over, so our arguments
                # haven't matched the full signature
yield False
def variadic_signature_matches(types, full_signature):
# No arguments always matches a variadic signature
assert full_signature
return all(variadic_signature_matches_iter(types, full_signature))
class Dispatcher:
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> # xdoctest: +SKIP("bad import name")
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc
self._cache = {}
def register(self, *types, **kwargs):
""" register dispatcher with new implementation
>>> # xdoctest: +SKIP
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _df(func):
self.add(types, func, **kwargs) # type: ignore[call-arg]
return func
return _df
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return sig.parameters.values()
@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional parameters
"""
params = cls.get_func_params(func)
if params:
Parameter = inspect.Parameter
params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))
annotations = tuple(
param.annotation
for param in params)
if all(ann is not Parameter.empty for ann in annotations):
return annotations
def add(self, signature, func):
""" Add new types/method pair to dispatcher
>>> # xdoctest: +SKIP
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
        >>> # When ``add`` detects an ambiguity it calls the ``on_ambiguity`` callback
>>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
>>> # as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations
# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
new_signature = []
for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
f"In signature: <{str_sig}>\n"
f"In function: {self.name}")
# handle variadic signatures
if isinstance(typ, list):
if index != len(signature):
raise TypeError(
'Variadic signature must be the last element'
)
if len(typ) != 1:
raise TypeError(
'Variadic signature must contain exactly one element. '
'To use a variadic union type place the desired types '
'inside of a tuple, e.g., [(int, str)]'
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
self.funcs[tuple(new_signature)] = func
self._cache.clear()
try:
del self._ordering
except AttributeError:
pass
@property
def ordering(self):
try:
return self._ordering
except AttributeError:
return self.reorder()
def reorder(self, on_ambiguity=ambiguity_warn):
self._ordering = od = ordering(self.funcs)
amb = ambiguities(self.funcs)
if amb:
on_ambiguity(self, amb)
return od
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError as e:
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
self._cache[types] = func
try:
return func(*args, **kwargs)
except MDNotImplementedError as e:
funcs = self.dispatch_iter(*types)
next(funcs) # burn first
for func in funcs:
try:
return func(*args, **kwargs)
except MDNotImplementedError:
pass
raise NotImplementedError(
"Matching functions for "
f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
def __str__(self):
return f"<dispatched {self.name}>"
__repr__ = __str__
def dispatch(self, *types):
"""Determine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> # xdoctest: +SKIP
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``multipledispatch.conflict`` - module to determine resolution order
"""
if types in self.funcs:
return self.funcs[types]
try:
return next(self.dispatch_iter(*types))
except StopIteration:
return None
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
result = self.funcs[signature]
yield result
elif len(signature) and isvariadic(signature[-1]):
if variadic_signature_matches(types, signature):
result = self.funcs[signature]
yield result
@deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
def resolve(self, types):
""" Determine appropriate implementation for this type signature
.. deprecated:: 0.4.4
Use ``dispatch(*types)`` instead
"""
return self.dispatch(*types)
def __getstate__(self):
return {'name': self.name,
'funcs': self.funcs}
def __setstate__(self, d):
self.name = d['name']
self.funcs = d['funcs']
self._ordering = ordering(self.funcs)
self._cache = {}
@property
def __doc__(self):
docs = [f"Multiply dispatched method: {self.name}"]
if self.doc:
docs.append(self.doc)
other = []
for sig in self.ordering[::-1]:
func = self.funcs[sig]
if func.__doc__:
s = f'Inputs: <{str_signature(sig)}>\n'
s += '-' * len(s) + '\n'
s += func.__doc__.strip()
docs.append(s)
else:
other.append(str_signature(sig))
if other:
docs.append('Other signatures:\n ' + '\n '.join(other))
return '\n\n'.join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
print(self._help(*args))
def _source(self, *args):
func = self.dispatch(*map(type, args))
if not func:
raise TypeError("No function found")
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))
def source(func):
s = f'File: {inspect.getsourcefile(func)}\n\n'
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__ = ('obj', 'cls')
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return ', '.join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
text = f"\nAmbiguities exist in dispatched function {name}\n\n"
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + \
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ f')\ndef {name}(...)' for s in amb])
return text
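# Hedged sketch: raising MDNotImplementedError inside an implementation makes
# the dispatcher fall through to the next matching signature in the ordering.
def _example_md_fallthrough():
    f = Dispatcher('fallthrough_example')
    f.add((object,), lambda x: "generic")
    def nonnegative_only(x):
        if x < 0:
            raise MDNotImplementedError  # decline; defer to the (object,) impl
        return "non-negative int"
    f.add((int,), nonnegative_only)
    return f(5), f(-5)  # -> ("non-negative int", "generic")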

View File

@ -0,0 +1,126 @@
# mypy: allow-untyped-defs
from collections import OrderedDict
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
def expand_tuples(L):
"""
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
>>> # Closely follows the wikipedia page [2]
>>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
>>> # Communications of the ACM
>>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val))
for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    :note: dict order is not deterministic. As we iterate over the
    input dict, the output of this function depends on the dict order,
    so the output order of this function should be considered
    non-deterministic.
"""
result = OrderedDict() # type: ignore[var-annotated]
for key in d:
for val in d[key]:
result[val] = result.get(val, ()) + (key,)
return result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d = OrderedDict() # type: ignore[var-annotated]
for item in seq:
key = func(item)
if key not in d:
d[key] = []
d[key].append(item)
return d
def typename(type):
"""Get the name of `type`.
Parameters
----------
type : Union[Type, Tuple[Type]]
Returns
-------
str
The name of `type` or a tuple of the names of the types in `type`.
Examples
--------
>>> typename(int)
'int'
>>> typename((int, float))
'(int, float)'
"""
try:
return type.__name__
except AttributeError:
if len(type) == 1:
return typename(*type)
return f"({', '.join(map(typename, type))})"

View File

@ -0,0 +1,92 @@
# mypy: allow-untyped-defs
from .utils import typename
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
class VariadicSignatureType(type):
# checking if subclass is a subclass of self
def __subclasscheck__(cls, subclass):
other_type = (subclass.variadic_type if isvariadic(subclass)
else (subclass,))
return subclass is cls or all(
issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
)
def __eq__(cls, other):
"""
Return True if other has the same variadic type
Parameters
----------
other : object (type)
The object (type) to check
Returns
-------
bool
Whether or not `other` is equal to `self`
"""
return (isvariadic(other) and
set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
def __hash__(cls):
return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
def isvariadic(obj):
"""Check whether the type `obj` is variadic.
Parameters
----------
obj : type
The type to check
Returns
-------
bool
Whether or not `obj` is variadic
Examples
--------
>>> # xdoctest: +SKIP
>>> isvariadic(int)
False
>>> isvariadic(Variadic[int])
True
"""
return isinstance(obj, VariadicSignatureType)
class VariadicSignatureMeta(type):
"""A metaclass that overrides ``__getitem__`` on the class. This is used to
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def __getitem__(cls, variadic_type):
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
raise ValueError("Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]")
if not isinstance(variadic_type, tuple):
variadic_type = variadic_type,
return VariadicSignatureType(
f'Variadic[{typename(variadic_type)}]',
(),
dict(variadic_type=variadic_type, __slots__=())
)
class Variadic(metaclass=VariadicSignatureMeta):
"""A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
--------
>>> # xdoctest: +SKIP
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
>>> issubclass(int, Variadic[int])
True
>>> issubclass(int, Variadic[(int, str)])
True
>>> issubclass(str, Variadic[(int, str)])
True
>>> issubclass(float, Variadic[(int, str)])
False
"""

View File

@ -0,0 +1,396 @@
# mypy: allow-untyped-defs
import collections
import operator
from functools import reduce
from collections.abc import Mapping
__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
'valfilter', 'keyfilter', 'itemfilter',
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']
def _get_factory(f, kwargs):
factory = kwargs.pop('factory', dict)
if kwargs:
raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
return factory
def merge(*dicts, **kwargs):
""" Merge a collection of dictionaries
>>> merge({1: 'one'}, {2: 'two'})
{1: 'one', 2: 'two'}
Later dictionaries have precedence
>>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
{1: 2, 3: 3, 4: 4}
See Also:
merge_with
"""
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
dicts = dicts[0]
factory = _get_factory(merge, kwargs)
rv = factory()
for d in dicts:
rv.update(d)
return rv
def merge_with(func, *dicts, **kwargs):
""" Merge dictionaries and apply function to combined values
A key may occur in more than one dict, and all values mapped from the key
will be passed to the function as a list, such as func([val1, val2, ...]).
>>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
{1: 11, 2: 22}
>>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
{1: 1, 2: 2, 3: 30}
See Also:
merge
"""
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
dicts = dicts[0]
factory = _get_factory(merge_with, kwargs)
result = factory()
for d in dicts:
for k, v in d.items():
if k not in result:
result[k] = [v]
else:
result[k].append(v)
return valmap(func, result, factory)
def valmap(func, d, factory=dict):
""" Apply function to values of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> valmap(sum, bills) # doctest: +SKIP
{'Alice': 65, 'Bob': 45}
See Also:
keymap
itemmap
"""
rv = factory()
rv.update(zip(d.keys(), map(func, d.values())))
return rv
def keymap(func, d, factory=dict):
""" Apply function to keys of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> keymap(str.lower, bills) # doctest: +SKIP
{'alice': [20, 15, 30], 'bob': [10, 35]}
See Also:
valmap
itemmap
"""
rv = factory()
rv.update(zip(map(func, d.keys()), d.values()))
return rv
def itemmap(func, d, factory=dict):
""" Apply function to items of dictionary
>>> accountids = {"Alice": 10, "Bob": 20}
>>> itemmap(reversed, accountids) # doctest: +SKIP
{10: "Alice", 20: "Bob"}
See Also:
keymap
valmap
"""
rv = factory()
rv.update(map(func, d.items()))
return rv
def valfilter(predicate, d, factory=dict):
""" Filter items in dictionary by value
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> valfilter(iseven, d)
{1: 2, 3: 4}
See Also:
keyfilter
itemfilter
valmap
"""
rv = factory()
for k, v in d.items():
if predicate(v):
rv[k] = v
return rv
def keyfilter(predicate, d, factory=dict):
""" Filter items in dictionary by key
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> keyfilter(iseven, d)
{2: 3, 4: 5}
See Also:
valfilter
itemfilter
keymap
"""
rv = factory()
for k, v in d.items():
if predicate(k):
rv[k] = v
return rv
def itemfilter(predicate, d, factory=dict):
""" Filter items in dictionary by item
>>> def isvalid(item):
... k, v = item
... return k % 2 == 0 and v < 4
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
>>> itemfilter(isvalid, d)
{2: 3}
See Also:
keyfilter
valfilter
itemmap
"""
rv = factory()
for item in d.items():
if predicate(item):
k, v = item
rv[k] = v
return rv
def assoc(d, key, value, factory=dict):
""" Return a new dict with new key value pair
New dict has d[key] set to value. Does not modify the initial dictionary.
>>> assoc({'x': 1}, 'x', 2)
{'x': 2}
>>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
{'x': 1, 'y': 3}
"""
d2 = factory()
d2.update(d)
d2[key] = value
return d2
def dissoc(d, *keys, **kwargs):
""" Return a new dict with the given key(s) removed.
New dict has d[key] deleted for each supplied key.
Does not modify the initial dictionary.
>>> dissoc({'x': 1, 'y': 2}, 'y')
{'x': 1}
>>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
{}
>>> dissoc({'x': 1}, 'y') # Ignores missing keys
{'x': 1}
"""
factory = _get_factory(dissoc, kwargs)
d2 = factory()
if len(keys) < len(d) * .6:
d2.update(d)
for key in keys:
if key in d2:
del d2[key]
else:
remaining = set(d)
remaining.difference_update(keys)
for k in remaining:
d2[k] = d[k]
return d2
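# Note on `dissoc` above: the `len(keys) < len(d) * .6` branch is a heuristic.
# When relatively few keys are removed, copying the whole dict and deleting the
# unwanted keys is cheaper; when most keys are removed, rebuilding the result
# from the remaining keys is cheaper.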
def assoc_in(d, keys, value, factory=dict):
""" Return a new dict with new, potentially nested, key value pair
>>> purchase = {'name': 'Alice',
... 'order': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
"""
return update_in(d, keys, lambda x: value, value, factory)
def update_in(d, keys, func, default=None, factory=dict):
""" Update value in a (potentially) nested dictionary
inputs:
d - dictionary on which to operate
keys - list or tuple giving the location of the value to be changed in d
func - function to operate on that value
If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
original dictionary with v replaced by func(v), but does not mutate the
original dictionary.
If k0 is not a key in d, update_in creates nested dictionaries to the depth
specified by the keys, with the innermost value set to func(default).
>>> inc = lambda x: x + 1
>>> update_in({'a': 0}, ['a'], inc)
{'a': 1}
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
>>> # updating a value when k0 is not in d
>>> update_in({}, [1, 2, 3], str, default="bar")
{1: {2: {3: 'bar'}}}
>>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
{1: 'foo', 2: {3: {4: 1}}}
"""
ks = iter(keys)
k = next(ks)
rv = inner = factory()
rv.update(d)
for key in ks:
if k in d:
d = d[k]
dtemp = factory()
dtemp.update(d)
else:
d = dtemp = factory()
inner[k] = inner = dtemp
k = key
if k in d:
inner[k] = func(d[k])
else:
inner[k] = func(default)
return rv
def get_in(keys, coll, default=None, no_default=False):
""" Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
``no_default`` is specified, in which case it raises KeyError or IndexError.
``get_in`` is a generalization of ``operator.getitem`` for nested data
structures such as dictionaries and lists.
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> get_in(['purchase', 'items', 0], transaction)
'Apple'
>>> get_in(['name'], transaction)
'Alice'
>>> get_in(['purchase', 'total'], transaction)
>>> get_in(['purchase', 'items', 'apple'], transaction)
>>> get_in(['purchase', 'items', 10], transaction)
>>> get_in(['purchase', 'total'], transaction, 0)
0
>>> get_in(['y'], {}, no_default=True)
Traceback (most recent call last):
...
KeyError: 'y'
See Also:
itertoolz.get
operator.getitem
"""
try:
return reduce(operator.getitem, keys, coll)
except (KeyError, IndexError, TypeError):
if no_default:
raise
return default
def getter(index):
if isinstance(index, list):
if len(index) == 1:
index = index[0]
return lambda x: (x[index],)
elif index:
return operator.itemgetter(*index)
else:
return lambda x: ()
else:
return operator.itemgetter(index)
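# Implementation note for `groupby` below: `defaultdict(lambda: [].append)` maps
# each key to the bound `append` method of a fresh list, so `d[key(item)](item)`
# appends without an extra attribute lookup; the underlying list is recovered
# afterwards through the bound method's `__self__` attribute.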
def groupby(key, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
Non-callable keys imply grouping on a member.
>>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
... {'name': 'Bob', 'gender': 'M'},
... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
{'F': [{'gender': 'F', 'name': 'Alice'}],
'M': [{'gender': 'M', 'name': 'Bob'},
{'gender': 'M', 'name': 'Charlie'}]}
Not to be confused with ``itertools.groupby``
See Also:
countby
"""
if not callable(key):
key = getter(key)
d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
for item in seq:
d[key(item)](item)
rv = {}
for k, v in d.items():
rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
return rv
def first(seq):
""" The first element in a sequence
>>> first('ABC')
'A'
"""
return next(iter(seq))

View File

@ -0,0 +1,106 @@
# mypy: allow-untyped-defs
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
def hashable(x):
try:
hash(x)
return True
except TypeError:
return False
def transitive_get(key, d):
""" Transitive dict.get
>>> d = {1: 2, 2: 3, 3: 4}
>>> d.get(1)
2
>>> transitive_get(1, d)
4
"""
while hashable(key) and key in d:
key = d[key]
return key
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> # xdoctest: +SKIP
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = ({v for v in edges if v not in incoming_edges})
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
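# Illustrative example for `_toposort` above: cyclic inputs are rejected.
#
# >>> _toposort({1: (2,), 2: (1,)})
# Traceback (most recent call last):
#     ...
# ValueError: Input has cycles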
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order is not deterministic. As we iterate over the
input dict, the output of this function depends on that order,
so the output order should be considered non-deterministic.
"""
result = {} # type: ignore[var-annotated]
for key in d:
for val in d[key]:
result[val] = result.get(val, ()) + (key,)
return result
def xfail(func):
try:
func()
raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002
except Exception:
pass
def freeze(d):
""" Freeze container to hashable form
>>> freeze(1)
1
>>> freeze([1, 2])
(1, 2)
>>> freeze({1: 2}) # doctest: +SKIP
frozenset([(1, 2)])
"""
if isinstance(d, dict):
return frozenset(map(freeze, d.items()))
if isinstance(d, set):
return frozenset(map(freeze, d))
if isinstance(d, (tuple, list)):
return tuple(map(freeze, d))
return d
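# Illustrative example for `freeze`/`hashable` above: freezing turns an
# unhashable nested container into something usable as a dict key or set member.
#
# >>> hashable({1: [2, 3]})
# False
# >>> hashable(freeze({1: [2, 3]}))
# True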

View File

@ -0,0 +1,86 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from .utils import hashable
from .dispatch import dispatch
_global_logic_variables = set() # type: ignore[var-annotated]
_glv = _global_logic_variables
class Var:
""" Logic Variable """
_id = 1
def __new__(cls, *token):
if len(token) == 0:
token = f"_{Var._id}" # type: ignore[assignment]
Var._id += 1
elif len(token) == 1:
token = token[0]
obj = object.__new__(cls)
obj.token = token # type: ignore[attr-defined]
return obj
def __str__(self):
return "~" + str(self.token) # type: ignore[attr-defined]
__repr__ = __str__
def __eq__(self, other):
return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined]
def __hash__(self):
return hash((type(self), self.token)) # type: ignore[attr-defined]
def var():
return lambda *args: Var(*args)
def vars():
return lambda n: [var() for i in range(n)]
@dispatch(Var)
def isvar(v):
return True
isvar
@dispatch(object) # type: ignore[no-redef]
def isvar(o):
return not not _glv and hashable(o) and o in _glv
@contextmanager
def variables(*variables):
"""
Context manager for logic variables
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> from __future__ import with_statement
>>> with variables(1):
... print(isvar(1))
True
>>> print(isvar(1))
False
>>> # Normal approach
>>> from unification import unify
>>> x = var('x')
>>> unify(x, 1)
{~x: 1}
>>> # Context Manager approach
>>> with variables('x'):
... print(unify('x', 1))
{'x': 1}
"""
old_global_logic_variables = _global_logic_variables.copy()
_global_logic_variables.update(set(variables))
try:
yield
finally:
_global_logic_variables.clear()
_global_logic_variables.update(old_global_logic_variables)

View File

@ -0,0 +1,121 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
def infer_symbolic_types_single_pass(traced):
"""
Calls our symbolic inferencer once.
"""
r = Refine(traced)
r.refine()
mgu = unify_eq(r.constraints)
substitute_all_types(traced.graph, mgu)
def infer_symbolic_types(traced):
"""
Calls our symbolic inferencer twice.
This is useful when one pass is not enough
to infer all the information, as is the case
for broadcasting.
"""
r = Refine(traced)
r.refine()
mgu = unify_eq(r.constraints)
substitute_all_types(traced.graph, mgu)
r = Refine(traced)
r.refine()
mgu = unify_eq(r.constraints)
substitute_all_types(traced.graph, mgu)
r.symbolic_relations()
def convert_eq(list_of_eq):
"""
Convert equality constraints into the format
expected by the unification library.
"""
lhs = []
rhs = []
for eq in list_of_eq:
lhs.append(eq.lhs)
rhs.append(eq.rhs)
return tuple(lhs), tuple(rhs)
def unify_eq(list_of_eq):
"""
Apply unification to a set of
equality constraints
"""
lhs, rhs = convert_eq(list_of_eq)
return unify(lhs, rhs)
def substitute_solution_one_type(mapping, t):
"""
Apply the most general unifier to a type
"""
if isinstance(t, Var):
if t in mapping.keys():
return mapping[t]
else:
return t
elif isinstance(t, TensorType):
new_type = []
for typ in t.__args__:
if typ in mapping.keys():
new_type.append(mapping[typ])
else:
new_type.append(typ)
return TensorType(tuple(new_type))
elif isinstance(t, list):
new_type = []
for typ in t:
new_type.append(substitute_solution_one_type(mapping, typ))
return new_type
elif isinstance(t, tuple):
new_type = []
for typ in t:
new_type.append(substitute_solution_one_type(mapping, typ))
return tuple(new_type)
else:
return t
def substitute_all_types(graph, mapping):
"""
Apply the most general unifier to all types in a graph
until reaching a fixed point. If the input and output graphs
are the same, we have converged.
"""
flag = True
while flag:
flag = False
for k in mapping:
old_mapping_val = mapping[k]
if mapping[k] in mapping.keys():
new_key = mapping[k]
mapping[k] = mapping[new_key]
if old_mapping_val != mapping[k]:
flag = True
for n in graph.nodes:
n.type = substitute_solution_one_type(mapping, n.type)
def check_for_type_equality(g1, g2):
"""
An equality check to be used in the fixed-point computation.
We do not use graph equality but instead type
equality.
"""
for n, m in zip(g1.nodes, g2.nodes):
if n.type != m.type:
return False
return True
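# Illustrative usage sketch (assumptions: `my_module` is a hypothetical nn.Module
# whose forward arguments are annotated with TensorType; actual results depend on
# the traced graph):
#
# >>> # xdoctest: +SKIP
# >>> import torch.fx
# >>> traced = torch.fx.symbolic_trace(my_module)
# >>> infer_symbolic_types(traced)            # two refine/unify passes, see above
# >>> [n.type for n in traced.graph.nodes]    # node types now carry inferred dims
#
# Note on `substitute_all_types` above: the `while flag` loop repeatedly collapses
# chains in `mapping` (if s1 -> s2 and s2 -> s3, rewrite s1 -> s3) until a fixed
# point is reached, then the resolved mapping is applied to every node type.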

View File

@ -0,0 +1,787 @@
# mypy: allow-untyped-defs
import functools
import logging
import math
import operator
import sympy
import builtins
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch._dynamo.exc import TorchDynamoException
from torch.fx.node import Argument, Target
from torch.utils._sympy.interp import sympy_interp
from torch._dynamo.utils import dynamo_timed
log = logging.getLogger(__name__)
try:
import z3 # type: ignore[import]
# Translation Validation for Dynamo guards
# ========================================
#
# Checks whether optimizations applied to the collected guards are
# valid. In other words, whether the guard function we actually run
# does not have false positives (unsound).
#
# In order to do so, we build the guards using 2 different information
# attached to each 'SymNode':
# 1. SymPy expressions
# 2. FX nodes
#
# SymPy expressions have implicit optimizations baked within itself,
# which may have a few bugs. On the other hand, we build the FX graph
# manually, with no optimizations enabled. This gives us access to
# the "ground truth".
#
# We then convert into Z3 expressions both the SymPy expressions
# (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
# and the FX nodes (see [Note: PopulateValidator]) that go through
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
# (see [Note: TranslationValidator])
# Better Z3 to string implementation (for a small fraction of Z3).
#
# Here are the things we clean before showing the Z3 expression:
# - Rename a few ops (e.g. "Distinct" ==> "!=")
#
# - Ignore ToInt and ToReal operations:
# usually they don't really matter
#
# - Transform (ToInt (/ ...)) into (idiv ...):
# this is the pattern for floor division
#
# - Collect a chain of the same operations into one
def z3str(e: z3.ExprRef) -> str:
assert z3.is_expr(e), f"unsupported expression type: {e}"
def get_args_str(e: z3.ExprRef) -> List[str]:
return [z3str(e.arg(i)) for i in range(e.num_args())]
# First, we simplify the given expression.
# This is done using rewriting rules, so shouldn't take long.
e = z3.simplify(e)
# Only support function applications.
# Even Z3 "variables" are, in fact, function applications.
if not z3.is_app(e):
raise ValueError(f"can't print Z3 expression: {e}")
if z3.is_int_value(e) or z3.is_rational_value(e):
return e.as_string() # type: ignore[attr-defined]
decl = e.decl()
kind = decl.kind()
op = str(decl)
args = get_args_str(e)
if kind == z3.Z3_OP_POWER:
op = "pow"
elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
# Collect the arguments of chains of ADD and MUL.
# This is safe, since they are associative.
def collect_str_args(e):
if not (z3.is_app(e) and e.decl().kind() == kind):
return [z3str(e)]
else:
return [
x
for i in range(e.num_args())
for x in collect_str_args(e.arg(i))
]
args = collect_str_args(e)
elif kind == z3.Z3_OP_NOT:
# Revert some conversions that z3.simplify applies:
# - a != b ==> (Not (== a b)) ==> (!= a b)
# - a < b ==> (Not (<= b a)) ==> (> b a)
# - a > b ==> (Not (<= a b)) ==> (> a b)
assert e.num_args() == 1
arg = e.arg(0)
assert z3.is_app(arg)
argkind = arg.decl().kind()
logic_inverse = {
z3.Z3_OP_EQ: "!=",
z3.Z3_OP_LE: ">",
z3.Z3_OP_GE: "<",
}
if argkind in logic_inverse:
op = logic_inverse[argkind]
args = get_args_str(arg)
elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
assert e.num_args() == 1
argstr = z3str(e.arg(0))
# Check if it's the floor division pattern.
if argstr.startswith("(/"):
return "(idiv" + argstr[2:]
# Otherwise, just ignore it.
return argstr
elif kind == z3.Z3_OP_UNINTERPRETED:
assert e.num_args() == 0
return str(decl)
string = op + " " + " ".join(args)
return f"({string.rstrip()})"
# Implementation of Python semantics as Z3 expressions.
#
# Z3 Real-Int theory has operators with semantics that differ from those of
# Python. Therefore, in order to get it right, we need to implement
# the (Python) semantics we are relying on in Z3.
@dataclass
class _Z3Ops:
# Validator used for adding assertions as needed.
# e.g. div(a, b) requires b != 0.
validator: "TranslationValidator"
# The 2 functions below are used for conditionally casting between
# integer and reals.
#
# Returns a real expression from 'x'.
@staticmethod
def to_real(x: z3.ArithRef) -> z3.ArithRef:
return x if x.is_real() else z3.ToReal(x)
# Returns an integer expression from 'x'.
@staticmethod
def to_int(x: z3.ArithRef) -> z3.ArithRef:
return x if x.is_int() else z3.ToInt(x)
# Implements Python division semantics.
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
self.validator.add_assertion(denominator != 0) # type: ignore[arg-type]
return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)
def floor(self, number: z3.ArithRef) -> z3.ArithRef:
# Z3 ToInt function rounds a real number towards negative infinity.
return _Z3Ops.to_int(number)
# Python semantics for 'FloorDiv' states that before applying the floor
# function, the operands are converted to their common type.
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
cast_result_to_real = numerator.is_real() or denominator.is_real()
result = _Z3Ops.to_int(self.div(numerator, denominator))
# Since the 'result' is already an integer, we just have to check
# whether we should cast it to real.
return _Z3Ops.to_real(result) if cast_result_to_real else result
def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(
self.floor(number) < number,
self.floor(number + 1),
number
) # type: ignore[return-value]
def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
return z3.If(a > b, a, b) # type: ignore[return-value]
def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
return z3.If(a < b, a, b) # type: ignore[return-value]
# Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
# It should work with both integer and reals.
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
return p - self.floordiv(p, q) * q
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
# Z3 can't handle complex numbers very well.
self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type]
return base ** exp
def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
# Square-root:
# 1. Only works with reals
number = _Z3Ops.to_real(number)
# 2. The number should be positive or zero.
# Otherwise, Z3 returns 'unknown'.
self.validator.add_assertion(number >= 0)
return number ** 0.5
def abs(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.Abs(number)
def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
# Python's builtin 'round' implements the 'round half to even' strategy
# See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
# z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
# floating point numbers, which is different from real numbers that we are dealing with here.
# Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and
# 'round half down' (ceil(x - 0.5)).
# Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ...
# to round down, i.e. use the 'round half down' strategy
return z3.If(
self.mod(number, z3.IntVal(2)) == 0.5,
self.ceil(number - 0.5),
self.floor(number + 0.5),
)
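# Reference values for the Python semantics that _Z3Ops above must reproduce on
# the Z3 side (plain CPython behavior, shown only as an illustration):
#
# >>> -7 / 2          # true division always yields a real
# -3.5
# >>> -7 // 2         # floor division rounds towards negative infinity
# -4
# >>> -7 % 2          # p % q == p - (p // q) * q
# 1
# >>> round(0.5), round(1.5), round(2.5)   # round half to even
# (0, 2, 2)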
# Lifts a callable to be used in Z3.
#
# This function replaces the given 'op' by a function that:
#
# 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
#
# 2. Calls an operation that corresponds to 'op', but works with Z3
# inhabitants (left as is if it works as is)
def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
# Operations that have booleans as their argument.
# This is needed because the arguments of some FX nodes were
# literal integers instead of booleans. So, whenever this flag
# is set, we also convert ints to booleans.
boolean_ops = {operator.not_, operator.and_, operator.or_}
as_bool = op in boolean_ops
# Lifts the function into 'z3.ExprRef' domain.
def lift(func):
def wrap(a) -> z3.ExprRef:
if isinstance(a, (z3.ArithRef, z3.BoolRef)):
return a
# Convert it into a Z3 value, if it is some of the supported
# types below.
if isinstance(a, bool) or (as_bool and isinstance(a, int)):
return z3.BoolVal(bool(a))
if isinstance(a, (int, sympy.Integer)):
return z3.IntVal(int(a))
if isinstance(a, (float, sympy.Float)):
return z3.RealVal(float(a))
raise ValueError(f"can't lift type: {type(a)}")
@functools.wraps(func)
def wrapper(*args):
# Lifts the arguments into a list of Z3 inhabitants.
wrapped_args = (wrap(a) for a in args)
# Run the function on the Z3 expressions.
return func(*wrapped_args)
return wrapper
ops = _Z3Ops(validator)
replacement_map = {
# Operator module.
operator.not_: lift(z3.Not),
operator.and_: lift(z3.And),
operator.or_: lift(z3.Or),
operator.floordiv: lift(ops.floordiv),
operator.truediv: lift(ops.div),
operator.mod: lift(ops.mod),
operator.abs: lift(ops.abs),
builtins.round: lift(ops.round_to_int),
# Math module.
math.ceil: lift(ops.ceil),
math.floor: lift(ops.floor),
# Torch module.
torch.sym_float: lift(ops.to_real),
torch.sym_max: lift(ops.max),
torch.sym_min: lift(ops.min),
torch.sym_ite: lift(lambda b, t, f: t if b else f),
torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined]
# Not lifted because we only use this function as a
# marker for adding the expression as validator input.
torch._assert: torch._assert,
}
return replacement_map[op] if op in replacement_map else lift(op)
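# Illustrative sketch of the int-to-bool lifting done by z3op (the exact repr of
# the resulting Z3 expression may differ):
#
# >>> v = TranslationValidator()             # defined further below  # doctest: +SKIP
# >>> lifted_and = z3op(operator.and_, v)
# >>> lifted_and(1, 0)                       # ints become BoolVal for boolean ops
# And(True, False)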
# Processes an FX graph, populating the given validator.
#
# [Note: PopulateValidator]
# This class walks through each node in the FX graph, translating
# them into the Z3 world.
#
# Then, whenever it finds an 'torch._assert' call_function operation,
# it adds the Z3 expression corresponding to the argument as validator
# input.
class PopulateValidator(torch.fx.Interpreter):
def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
# Reference to the translation validator.
self.validator = validator
# Build the graph module and call `Interpreter` constructor.
module = torch.fx.GraphModule(root={}, graph=graph)
super().__init__(module, garbage_collect_values=True)
def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
symbol = fx_traceback.get_current_meta()["symbol"]
return self.validator.z3var(symbol)
def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
if target != torch._assert:
# Lift and run the node's target function
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
# Adds the Z3 expression corresponding to the first argument
# as a validator input.
assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} "
self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
# Translates SymPy expressions into Z3 expressions.
#
# [Note: SympyToZ3]
# At the time of the translation, all free variables present in the
# SymPy expression being translated must be already mapped to a Z3
# integer variable.
class SympyToZ3:
OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
def __init__(
self,
validator: "TranslationValidator",
) -> None:
self._validator = validator
self._ops = _Z3Ops(self._validator)
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
# TODO: Probably OK to relax this and allow lower precision
if dtype is torch.int64:
return z3.IntVal(int(value))
if dtype is torch.double:
return z3.RealVal(float(value))
if dtype is torch.bool:
return z3.BoolVal(bool(value))
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
if dtype == torch.float64:
return z3.ToReal(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return z3.ToInt(x)
def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.round_to_int(x)
def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.div(numerator, denominator)
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.floordiv(numerator, denominator)
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
return self._ops.floordiv(numerator, denominator)
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
return self._ops.pow(base, exp)
def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
return self._ops.pow(base, exp)
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
return self._ops.mod(p, q)
def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.ceil(x)
def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
return self._ops.floor(x)
def __getattr__(self, name: str) -> Any:
REPLACEMENT = {
"and_": z3.And,
"or_": z3.Or,
"not_": z3.Not,
"floor": self._ops.floor,
"ceil": self._ops.ceil,
"minimum": self._ops.min,
"maximum": self._ops.max,
}
if name in REPLACEMENT:
return REPLACEMENT[name]
if name in self.OPERATOR_HANDLES:
return getattr(operator, name)
raise AttributeError(f"unhandled operator: {name}")
def run(self, expr: sympy.Basic) -> z3.ExprRef:
return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type]
# Dynamo guards translation validator.
#
# [Note: TranslationValidator]
# Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
# That is: whether those (target) guards only yield TRUE whenever the original,
# unoptimized, (source) guards yield TRUE.
#
# More concretely, given 'source' and 'target' guard expressions, we wish to
# check whether the following expression is satisfiable:
#
# Not(And(source)) AND And(target)
#
# i.e. whether there is an assignment of the free variables where the opposite
# happens: target is TRUE, but source is FALSE.
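#
# For example, if source = {s0 > 2} and a simplification weakened the target to
# {s0 > 1}, then the assignment s0 == 2 makes the target TRUE and the source
# FALSE, so the formula above is satisfiable and validation must fail.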
class TranslationValidator:
def __init__(self) -> None:
log.debug("new instance")
# Mapping of SymPy symbols to Z3 variables.
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
# Set of source Z3 expressions.
# They represent the generated guards without any kind of
# simplification or transformation.
self._source_exprs: Set[z3.BoolRef] = set()
# Set of target Z3 expressions.
# They represent the actual checked guards at runtime. They might
# be simplified or transformed versions of the source guards.
self._target_exprs: Set[z3.BoolRef] = set()
# Set of Z3 expressions representing assertions over both the
# source and target expressions.
self._assertions: Set[z3.BoolRef] = set()
# Retrieves the corresponding Z3 variable.
def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
return self.symbols[symbol]
# Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist.
def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef:
if symbol in self.symbols:
return self.symbols[symbol]
log.debug("new variable: %s (%s)", symbol.name, type.__name__)
if type is int:
var = z3.Int(symbol.name)
# If 'symbol' is positive (SymPy assumption), we have to
# convey it to Z3 as well.
if symbol.is_positive: # type: ignore[attr-defined]
self._target_exprs.add(var > 0)
elif type is float:
var = z3.Real(symbol.name)
elif type is bool:
var = z3.Bool(symbol.name)
else:
raise RuntimeError(f"unsupported type for Z3 variable: {type}")
self.symbols[symbol] = var
return var
# Checks whether all symbols were already added.
def _check_freesymbols(self, e: sympy.Basic) -> None:
for s in e.free_symbols:
assert isinstance(s, sympy.Symbol)
# Call 'z3var' just to check whether there's already a
# Z3 variable corresponding to 's'.
self.z3var(s)
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
z3expr = SympyToZ3(self).run(e)
assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}"
return z3expr
def add_source_expr(self, e: z3.BoolRef) -> None:
if e not in self._source_exprs:
log.debug("add source guard: %s", z3str(e))
self._source_exprs.add(e)
def add_target_expr(self, e: sympy.Expr) -> None:
self._check_freesymbols(e)
z3expr = self.to_z3_boolean_expr(e)
if e not in self._target_exprs:
log.debug("add target guard: %s", z3str(z3expr))
self._target_exprs.add(z3expr)
def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
if isinstance(e, sympy.Basic):
self._check_freesymbols(e)
ref = self.to_z3_boolean_expr(e)
else:
ref = e
assert isinstance(ref, z3.BoolRef)
if ref not in self._assertions:
log.debug("add assertion: %s", z3str(ref))
self._assertions.add(ref)
def validate(self) -> None:
with dynamo_timed("TranslationValidator.validate"):
return self._validate()
def _validate(self) -> None:
if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
# If there are no source/target expressions, there's nothing we really
# wish to prove. So, we just return.
return None
# Here, we use "QF_NRA" logic for the solver:
# "Quantifier-free Non-linear Real Arithmetic".
#
# Most of the guards expressions have:
# 1. arithmetic between integer and reals
# 2. no quantifiers
# 3. potentially non-linear.
#
# Although there's also "QF_NIRA" (mixed integer-real arithmetic),
# "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
solver = z3.SolverFor("QF_NRA")
# Set a timeout for finding a solution.
solver.set(timeout=translation_validation_timeout())
# Add all the assertions to the solver.
for assertion in self._assertions:
solver.add(assertion)
# "Is there any case where it's TRUE for the target expressions,
# but FALSE for the source expressions?"
solver.add(z3.Not(z3.And(*self._source_exprs)))
solver.add(*self._target_exprs)
log.debug("translation validation: start")
r = solver.check()
if r == z3.sat:
# Target expressions are unsound.
# Log the found model and the source expressions that failed.
model = solver.model()
raise ValidationException(
model, self._assertions, self._target_exprs,
failed_source_exprs=[
inp for inp in self._source_exprs if not model.evaluate(inp)
]
)
else:
if r == z3.unknown:
# Could not find a solution. It didn't fail, but it also
# didn't succeed. Canceling the validation execution (keyboard
# interrupt) also gets to this branch.
log.warning("translation validation: could not validate: got z3.unknown")
else:
# Target expressions are sound.
assert r == z3.unsat
log.debug("translation validation: success")
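# Illustrative usage sketch for TranslationValidator (a minimal, standalone
# example; in practice it is driven by ShapeEnv while producing guards and by
# the bisect() helper below):
#
# >>> # xdoctest: +SKIP
# >>> v = TranslationValidator()
# >>> s0 = sympy.Symbol("s0", positive=True, integer=True)
# >>> v.add_var(s0, int)
# >>> v.add_source_expr(v.z3var(s0) > 2)     # original guard
# >>> v.add_target_expr(sympy.Gt(s0, 1))     # weakened guard: unsound
# >>> v.validate()                           # raises ValidationException (e.g. s0 == 2)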
except ImportError:
_HAS_Z3 = False
__all__ = [
"translation_validation_enabled", "translation_validation_timeout",
"ValidationException", "BisectValidationException",
]
else:
_HAS_Z3 = True
__all__ = [
"z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator",
"translation_validation_enabled", "translation_validation_timeout",
"ValidationException", "BisectValidationException",
]
from torch.fx.experimental import _config as config
def translation_validation_enabled() -> bool:
# Checks every time this function is called, in case the Dynamo
# option is set but Z3 is not installed.
_assert_z3_installed_if_tv_set()
return _HAS_Z3 and config.translation_validation
def translation_validation_timeout() -> int:
return config.translation_validation_timeout
def _assert_z3_installed_if_tv_set():
assert _HAS_Z3 or not config.translation_validation, (
"translation validation requires Z3 package. Please, either install "
"z3-solver or disable translation validation."
)
class ValidationException(TorchDynamoException):
def __init__(self, model, assertions, target_exprs, failed_source_exprs):
assert _HAS_Z3
def symbolstr(sym) -> str:
return f"{sym}: {model[sym]}"
def joinlines(xs) -> str:
return "\n".join(f" ==> {x}" for x in xs)
model_str = joinlines(sorted(map(symbolstr, model)))
assertions_str = joinlines(sorted(map(z3str, assertions)))
target_exprs_str = joinlines(sorted(map(z3str, target_exprs)))
failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs)))
self.msg = "translation validation failed."
self.details = f"""\
Model:
{model_str}
Assertions:
{assertions_str}
Target Expressions:
{target_exprs_str}
Failed Source Expressions:
{failed_source_exprs_str}"""
def __str__(self):
return f"{self.msg}\n\n{self.details}"
class BisectValidationException(TorchDynamoException):
def __init__(self, validation_exc, expr, failed_action, traced_node):
self.msg = f"translation validation failed when {failed_action}: {expr}"
self.details = f"""\
Failure occurred while running node:
{traced_node.format_node()}
{validation_exc.details}"""
def __str__(self):
return f"{self.msg}\n\n{self.details}"
# Checks when this module is loaded.
_assert_z3_installed_if_tv_set()
# Translation validation bisection.
#
# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
# the earliest ValidationException.
#
# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors
# may happen silently. This function tries to pin down exactly at which
# point things went wrong from a validation perspective.
def bisect(shape_env):
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events
events = shape_env.events
# Retrieves the ShapeEnvEvent associated with node.
def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
assert SHAPEENV_EVENT_KEY in node.meta
return events[node.meta[SHAPEENV_EVENT_KEY]]
# Creates a new instance of fake, but updating every symbolic value's ShapeEnv
# reference to the one given as argument.
#
# This is needed so as not to simplify a symbolic expression using a ShapeEnv
# "from the future", where it may have a different set of replacements.
def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
if isinstance(fake, int):
return fake
if isinstance(fake, torch.SymInt):
return torch.SymInt(fake.node.with_shape_env(shape_env))
assert isinstance(fake, FakeTensorMeta)
return FakeTensorMeta(
tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
new_with_shape_env(shape_env, fake.storage_offset()),
fake.is_nested,
)
# Checks whether the given shape_env fails when produce_guards is called.
def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
assert tracked_fakes is not None
try:
# This produce_guards call is a best-effort replication, since we
# don't populate EqualityConstraint list. Reason: we would also have
# to save OutputGraph.tracked_fakes_id_to_source.
shape_env.produce_guards(
[new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
[a.source for a in tracked_fakes],
input_contexts=[a.symbolic_context for a in tracked_fakes],
)
return None
except ValidationException as e:
return e
# Checks whether the ShapeEnv reconstructed by replaying the events until
# node is created fails when produce_guards is called.
def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
number = node.meta[SHAPEENV_EVENT_KEY]
# Reconstruct shape_env until the event at event_number.
shape_env = replay_shape_env_events(events[:number + 1])
shape_env.graph.lint()
return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())
if not last_exception:
# We don't actually fail due to a produce_guards call.
# Stop and don't bisect.
log.info("translation validation succeeded: no errors found.")
return
if not shape_env.should_record_events or config.translation_validation_no_bisect:
# Bisection is off.
# Return the last ValidationException we got.
raise last_exception
# Cache the raised exception (if any) at each bisection point.
exception = {}
# Bisection happens on the assertion nodes of the recorded FX graph for
# dynamic shapes.
assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]
# Preparing the indices for binary search.
left, mid, right = 0, 0, len(assert_nodes) - 1
while left < right:
mid = (left + right) // 2
node = assert_nodes[mid]
log.debug("bisecting at %s: %s", mid, get_node_event(node))
# Check whether the new shape_env raises a ValidationException or not.
exception[mid] = check_node_fails(node)
if exception[mid]:
right = mid
else:
left = mid + 1
assert left in exception and isinstance(exception[left], ValidationException)
node = assert_nodes[left]
event = get_node_event(node)
if event.is_evaluate_expr():
failed_action = "evaluating"
else:
assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
failed_action = "adding runtime assert"
args = event.args
assert args is not None
assert len(args) >= 2, (
f"bisecting expects {event.name} to have at least 2 positional arguments. "
f"Got: {len(args)}"
)
assert isinstance(args[1], sympy.Basic), (
f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
f"Got: {type(args[1])}"
)
raise BisectValidationException(
exception[left],
expr=args[1],
failed_action=failed_action,
traced_node=node.meta[CURRENT_NODE_KEY],
)