I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,12 @@
from . import graph_drawer
from . import graph_manipulation
from . import net_min_base
from . import operator_support
from . import param_fetch
from . import reinplace
from . import runtime_assert
from . import shape_prop
from . import split_module
from . import split_utils
from . import splitter_base
from . import tools_common


@@ -0,0 +1,44 @@
import operator
import torch
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
"""
Annotate the type of getitem nodes, inferred from the type of sequence node.
If sequence node is not annotated with a type, do nothing.
Currently supports getitem nodes whose sequence node is a Tuple, List, or NamedTuple.
This is helpful since annotations on local names within a function are lost during FX transforms.
Adding back known type annotations for getitem nodes improves jit scriptability.
Args:
graph (Graph): The graph to be annotated
"""
for node in graph.nodes:
if node.target == operator.getitem:
sequence_node, index_node = node.args
if not sequence_node.type:
continue
# container types
if hasattr(sequence_node.type, "_name"):
parameterized_types = sequence_node.type.__args__
if sequence_node.type._name == "Tuple":
if len(parameterized_types) == 2 and isinstance(
parameterized_types[1], type(...)
):
node.type = parameterized_types[0]
else:
assert len(parameterized_types) > index_node
node_type = parameterized_types[index_node]
node.type = node_type
elif sequence_node.type._name == "List":
assert len(parameterized_types) == 1
node.type = parameterized_types[0]
# NamedTuple type
elif hasattr(sequence_node.type, "__annotations__"):
if sequence_node.type == torch.Tensor:
continue
sequence_node_field_types = sequence_node.type.__annotations__
field_name = sequence_node.type._fields[index_node]
node.type = sequence_node_field_types[field_name]
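# Hedged usage sketch (illustrative addition, not part of the original pass).
# It traces a hypothetical function whose argument carries a Tuple annotation,
# runs annotate_getitem_nodes, and reads the propagated types back off the
# getitem nodes. It assumes symbolic_trace records parameter annotations as
# placeholder node types.
def _example_annotate_getitem_nodes():
    from typing import Tuple

    def f(pair: Tuple[torch.Tensor, int]):
        # indexing a Proxy produces a call_function node targeting operator.getitem
        return pair[0]

    gm = torch.fx.symbolic_trace(f)
    annotate_getitem_nodes(gm.graph)
    # The getitem node should now be annotated with torch.Tensor.
    return [n.type for n in gm.graph.nodes if n.target == operator.getitem]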


@@ -0,0 +1,57 @@
# mypy: allow-untyped-defs
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils import _pytree as pytree
import operator
class CudaGraphsSupport(OperatorSupport):
# TODO: why is submodules passed here
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
if node.op not in CALLABLE_NODE_OPS:
return False
if node.target in [torch.ops.aten.embedding_dense_backward.default]:
return False
if node.target in [operator.getitem]:
return True
found_not_cuda = False
def meta_fk(meta):
return meta["val"] if "val" in meta else meta["fake_result"]
def find_not_cuda(t):
nonlocal found_not_cuda
if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
found_not_cuda = True
for n in node.all_input_nodes:
pytree.tree_map_(find_not_cuda, meta_fk(n.meta))
pytree.tree_map_(find_not_cuda, meta_fk(node.meta))
# NB: factory function is accounted for because the result would be
# cpu or cuda
return not found_not_cuda
def partition_cudagraphs(gm, inputs):
"""
Partition an FX graph into sub-GraphModules that can be validly run under
CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations
must involve only CUDA tensors.
"""
FakeTensorProp(gm).propagate(*inputs)
supported_ops = CudaGraphsSupport()
# TODO: single node partition may be wrong due to the pessimization
# from copying in and out the data. Check in benchmarks, perhaps
partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
partitions = partitioner.propose_partitions()
fused_graph = partitioner.fuse_partitions(partitions)
return fused_graph
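# Hedged usage sketch (illustrative addition, requires a CUDA device). It
# traces a hypothetical function to an aten-level graph with make_fx and
# partitions it into CUDA-graph-safe submodules.
def _example_partition_cudagraphs():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.relu(x) + 1

    inp = torch.randn(8, device="cuda")
    gm = make_fx(f)(inp)
    # Every proposed partition is fused into its own submodule of the result.
    return partition_cudagraphs(gm, [inp])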


@@ -0,0 +1,113 @@
# mypy: allow-untyped-defs
from typing import Dict, Tuple, Any
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._pytree import tree_flatten
from torch.fx import GraphModule, Graph
from torch.fx import Node
aten = torch.ops.aten
# stateful ops are banned from CSE
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
return rand_ops.union(inplace_ops)
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
def __init__(self, banned_ops=None):
"""
This version of the CSE pass aims to be dialect agnostic; it is implemented purely in terms of the connectivity between fx.Node objects.
For functional dialects, users only need to list the random ops in the ban list.
Warning: the CSE pass cannot be safely applied to an FX graph in a non-functional dialect.
If your dialect contains stateful operators, please customize banned_ops accordingly.
"""
if banned_ops is None:
banned_ops = set()
self.banned_ops = banned_ops
super().__init__()
def call(self, graph_module: GraphModule) -> PassResult:
"""
Return a new copy of torch.fx.GraphModule with CSE applied to the input graph
Example usage:
from torch.fx.experimental.proxy_tensor import make_fx
def f(a):
b = a * a
c = a * a
return b+c
p = CSEPass()
traced_graph = make_fx(f)(torch.tensor(1))
print(traced_graph)
result = p(traced_graph)
print(result.graph_module)
"""
def get_aten_target(node):
if hasattr(node.target, 'overloadpacket'):
return node.target.overloadpacket
return node.target
modified = False
new_graph = Graph()
env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph
hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph
token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token
for n in graph_module.graph.nodes:
# The placeholder, output, and get_attr nodes are copied to the new graph without change
# do not CSE away random operations
if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node
else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
# substitute args and kwargs members to their mapping in env if exists
# specs can be used to reconstruct nested list/dictionaries
def substitute(arg_list):
arg_list, spec = tree_flatten(arg_list)
for i in range(len(arg_list)):
v = arg_list[i]
if isinstance(v, Node) and v in env:
arg_list[i] = env[v]
return tuple(arg_list), spec
args, args_spec = substitute(n.args)
kwargs, kwargs_spec = substitute(n.kwargs)
# each token corresponds to a unique node
# nodes with the same token can be substituted
token = {"target": n.target, "args": args, "args_spec": args_spec,
"kwargs": kwargs, "kwargs_spec": kwargs_spec}
# hash substituted args to a number, do not hash specs because specs are not hashable
hash_arg = hash((args, kwargs))
hash_val = (n.target, hash_arg)
# check if a node has a substitute and can be eliminated
hash_val_in_hash_env = hash_val in hash_env
if hash_val_in_hash_env and token_map[hash_val] == token:
modified = True # substitution happens and the graph is modified
env[n] = hash_env[hash_val]
continue
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node
if not hash_val_in_hash_env:
hash_env[hash_val] = new_node
token_map[hash_val] = token
csed_gm = GraphModule(graph_module, new_graph)
return PassResult(csed_gm, modified)


@@ -0,0 +1,70 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch.fx
from torch.fx import Node
from torch.fx.node import map_aggregate
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types
__all__ = ['FakeTensorProp']
@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and record a fake tensor representing
the metadata for the node. Unlike ShapeProp, (1) this propagation
is cheap--it does the propagation with meta tensors which do not actually
store data, and (2) the fake tensors have much more fine grained information,
e.g., they have accurate alias information that can be consulted by looking
at the storages.
Args:
module (GraphModule): The module to be executed
mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
"""
def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
super().__init__(module)
if mode is None:
mode = FakeTensorMode()
self._mode = mode
mode.epoch += 1
mode.reset_nt_tensor_id_counter()
def run_node(self, n: Node):
from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings
result = super().run_node(n)
rebind_unbacked(self._mode.shape_env, n, result)
def extract_val(obj):
if isinstance(obj, FakeTensor):
return snapshot_fake(obj)
elif isinstance(obj, torch.Tensor):
# TODO: How is it possible that we get a non fake tensor? We
# should be running under the mode...
return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
elif isinstance(obj, py_sym_types):
return obj
else:
return None
meta = map_aggregate(result, extract_val)
if meta is not None:
n.meta['val'] = meta
if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)):
n.meta["unbacked_bindings"] = symbol_to_path
return result
def propagate(self, *args):
fake_args = [
self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
for a in args
]
return self.propagate_dont_convert_inputs(*fake_args)
def propagate_dont_convert_inputs(self, *args):
with self._mode:
return super().run(*args)
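# Hedged usage sketch (illustrative addition). It propagates fake tensors
# through a hypothetical traced module so that every node carries a
# FakeTensor in node.meta['val'] describing its shape, dtype and device.
def _example_fake_tensor_prop():
    class _ExampleMod(torch.nn.Module):
        def forward(self, x):
            return x.relu() + 1

    gm = torch.fx.symbolic_trace(_ExampleMod())
    FakeTensorProp(gm).propagate(torch.randn(2, 3))
    return [node.meta.get("val") for node in gm.graph.nodes]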


@@ -0,0 +1,443 @@
# mypy: allow-untyped-defs
import hashlib
from itertools import chain
from typing import Any, Dict, Optional, TYPE_CHECKING
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import _parse_stack_trace
from torch.fx.node import _format_arg, _get_qualified_name
from torch.fx.operator_schemas import normalize_function
from torch.fx.passes.shape_prop import TensorMetadata
try:
import pydot
HAS_PYDOT = True
except ModuleNotFoundError:
HAS_PYDOT = False
pydot = None
__all__ = ["FxGraphDrawer"]
_COLOR_MAP = {
"placeholder": '"AliceBlue"',
"call_module": "LemonChiffon1",
"get_param": "Yellow2",
"get_attr": "LightGrey",
"output": "PowderBlue",
}
_HASH_COLOR_MAP = [
"CadetBlue1",
"Coral",
"DarkOliveGreen1",
"DarkSeaGreen1",
"GhostWhite",
"Khaki1",
"LavenderBlush1",
"LightSkyBlue",
"MistyRose1",
"MistyRose2",
"PaleTurquoise2",
"PeachPuff1",
"Salmon",
"Thistle1",
"Thistle3",
"Wheat1",
]
_WEIGHT_TEMPLATE = {
"fillcolor": "Salmon",
"style": '"filled,rounded"',
"fontcolor": "#000000",
}
if HAS_PYDOT:
@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
"""
Visualize a torch.fx.Graph with graphviz
Basic usage:
g = FxGraphDrawer(symbolic_traced, "resnet18")
g.get_dot_graph().write_svg("a.svg")
"""
def __init__(
self,
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool = False,
ignore_parameters_and_buffers: bool = False,
skip_node_names_in_args: bool = True,
parse_stack_trace: bool = False,
dot_graph_shape: Optional[str] = None,
normalize_args: bool = False,
):
self._name = name
self.dot_graph_shape = (
dot_graph_shape if dot_graph_shape is not None else "record"
)
self.normalize_args = normalize_args
_WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape
self._dot_graphs = {
name: self._to_dot(
graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
)
}
for node in graph_module.graph.nodes:
if node.op != "call_module":
continue
leaf_node = self._get_leaf_node(graph_module, node)
if not isinstance(leaf_node, torch.fx.GraphModule):
continue
self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
leaf_node,
f"{name}_{node.target}",
ignore_getattr,
ignore_parameters_and_buffers,
skip_node_names_in_args,
parse_stack_trace,
)
def get_dot_graph(self, submod_name=None) -> pydot.Dot:
"""
Visualize a torch.fx.Graph with graphviz
Example:
>>> # xdoctest: +REQUIRES(module:pydot)
>>> # xdoctest: +REQUIRES(module:ubelt)
>>> # define module
>>> class MyModule(torch.nn.Module):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.linear = torch.nn.Linear(4, 5)
>>> def forward(self, x):
>>> return self.linear(x).clamp(min=0.0, max=1.0)
>>> module = MyModule()
>>> # trace the module
>>> symbolic_traced = torch.fx.symbolic_trace(module)
>>> # setup output file
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
>>> fpath = dpath / 'linear.svg'
>>> # draw the graph
>>> g = FxGraphDrawer(symbolic_traced, "linear")
>>> g.get_dot_graph().write_svg(fpath)
"""
if submod_name is None:
return self.get_main_dot_graph()
else:
return self.get_submod_dot_graph(submod_name)
def get_main_dot_graph(self) -> pydot.Dot:
return self._dot_graphs[self._name]
def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
return self._dot_graphs[f"{self._name}_{submod_name}"]
def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
return self._dot_graphs
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
template = {
"shape": self.dot_graph_shape,
"fillcolor": "#CAFFE3",
"style": '"filled,rounded"',
"fontcolor": "#000000",
}
if node.op in _COLOR_MAP:
template["fillcolor"] = _COLOR_MAP[node.op]
else:
# Use a random color for each node; based on its name so it's stable.
target_name = node._pretty_print_target(node.target)
target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
return template
def _get_leaf_node(
self, module: torch.nn.Module, node: torch.fx.Node
) -> torch.nn.Module:
py_obj = module
assert isinstance(node.target, str)
atoms = node.target.split(".")
for atom in atoms:
if not hasattr(py_obj, atom):
raise RuntimeError(
str(py_obj) + " does not have attribute " + atom + "!"
)
py_obj = getattr(py_obj, atom)
return py_obj
def _typename(self, target: Any) -> str:
if isinstance(target, torch.nn.Module):
ret = torch.typename(target)
elif isinstance(target, str):
ret = target
else:
ret = _get_qualified_name(target)
# Escape "{" and "}" to prevent dot files like:
# https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
# which triggers `Error: bad label format (...)` from dot
return ret.replace("{", r"\{").replace("}", r"\}")
# shorten path to avoid drawing long boxes
# for full path = '/home/weif/pytorch/test.py'
# return short path = 'pytorch/test.py'
def _shorten_file_name(
self,
full_file_name: str,
truncate_to_last_n: int = 2,
):
splits = full_file_name.split('/')
if len(splits) >= truncate_to_last_n:
return '/'.join(splits[-truncate_to_last_n:])
return full_file_name
def _get_node_label(
self,
module: torch.fx.GraphModule,
node: torch.fx.Node,
skip_node_names_in_args: bool,
parse_stack_trace: bool,
) -> str:
def _get_str_for_args_kwargs(arg):
if isinstance(arg, tuple):
prefix, suffix = r"|args=(\l", r",\n)\l"
arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
elif isinstance(arg, dict):
prefix, suffix = r"|kwargs={\l", r",\n}\l"
arg_strs_list = [
f"{k}: {_format_arg(v, max_list_len=8)}"
for k, v in arg.items()
]
else: # Fall back to nothing in unexpected case.
return ""
# Strip out node names if requested.
if skip_node_names_in_args:
arg_strs_list = [a for a in arg_strs_list if "%" not in a]
if len(arg_strs_list) == 0:
return ""
arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
if len(arg_strs_list) == 1:
arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
return arg_strs.replace("{", r"\{").replace("}", r"\}")
label = "{" + f"name=%{node.name}|op_code={node.op}\n"
if node.op == "call_module":
leaf_module = self._get_leaf_node(module, node)
label += r"\n" + self._typename(leaf_module) + r"\n|"
extra = ""
if hasattr(leaf_module, "__constants__"):
extra = r"\n".join(
[f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr]
)
label += extra + r"\n"
else:
label += f"|target={self._typename(node.target)}" + r"\n"
if self.normalize_args:
try:
args, kwargs = normalize_function( # type: ignore[misc]
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
)
except Exception:
# Fallback to not normalizing if there's an exception.
# Some functions need overloads specified to normalize.
args, kwargs = node.args, node.kwargs
else:
args, kwargs = node.args, node.kwargs
if len(args) > 0:
label += _get_str_for_args_kwargs(args)
if len(kwargs) > 0:
label += _get_str_for_args_kwargs(kwargs)
label += f"|num_users={len(node.users)}" + r"\n"
tensor_meta = node.meta.get('tensor_meta')
label += self._tensor_meta_to_label(tensor_meta)
# for original fx graph
# print buf=buf0, n_origin=6
buf_meta = node.meta.get('buf_meta', None)
if buf_meta is not None:
label += f"|buf={buf_meta.name}" + r"\n"
label += f"|n_origin={buf_meta.n_origin}" + r"\n"
# for original fx graph
# print file:lineno code
if parse_stack_trace and node.stack_trace is not None:
parsed_stack_trace = _parse_stack_trace(node.stack_trace)
fname = self._shorten_file_name(parsed_stack_trace.file)
label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"
return label + "}"
def _tensor_meta_to_label(self, tm) -> str:
if tm is None:
return ""
elif isinstance(tm, TensorMetadata):
return self._stringify_tensor_meta(tm)
elif isinstance(tm, list):
result = ""
for item in tm:
result += self._tensor_meta_to_label(item)
return result
elif isinstance(tm, dict):
result = ""
for v in tm.values():
result += self._tensor_meta_to_label(v)
return result
elif isinstance(tm, tuple):
result = ""
for item in tm:
result += self._tensor_meta_to_label(item)
return result
else:
raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
result = ""
if not hasattr(tm, "dtype"):
print("tm", tm)
result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
if tm.is_quantized:
assert tm.qparams is not None
assert "qscheme" in tm.qparams
qscheme = tm.qparams["qscheme"]
if qscheme in {
torch.per_tensor_affine,
torch.per_tensor_symmetric,
}:
result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
elif qscheme in {
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
}:
result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
else:
raise RuntimeError(f"Unsupported qscheme: {qscheme}")
result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
return result
def _get_tensor_label(self, t: torch.Tensor) -> str:
return str(t.dtype) + str(list(t.shape)) + r"\n"
# when parse_stack_trace=True
# print file:lineno code
def _to_dot(
self,
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool,
ignore_parameters_and_buffers: bool,
skip_node_names_in_args: bool,
parse_stack_trace: bool,
) -> pydot.Dot:
"""
Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
If ignore_parameters_and_buffers is True, the parameters and buffers
created with the module will not be added as nodes and edges.
"""
# "TB" means top-to-bottom rank direction in layout
dot_graph = pydot.Dot(name, rankdir="TB")
buf_name_to_subgraph = {}
for node in graph_module.graph.nodes:
if ignore_getattr and node.op == "get_attr":
continue
style = self._get_node_style(node)
dot_node = pydot.Node(
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
)
current_graph = dot_graph
buf_meta = node.meta.get('buf_meta', None)
if buf_meta is not None and buf_meta.n_origin > 1:
buf_name = buf_meta.name
if buf_name not in buf_name_to_subgraph:
buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
current_graph = buf_name_to_subgraph.get(buf_name)
current_graph.add_node(dot_node)
def get_module_params_or_buffers():
for pname, ptensor in chain(
leaf_module.named_parameters(), leaf_module.named_buffers()
):
pname1 = node.name + "." + pname
label1 = (
pname1 + "|op_code=get_" + "parameter"
if isinstance(ptensor, torch.nn.Parameter)
else "buffer" + r"\l"
)
dot_w_node = pydot.Node(
pname1,
label="{" + label1 + self._get_tensor_label(ptensor) + "}",
**_WEIGHT_TEMPLATE,
)
dot_graph.add_node(dot_w_node)
dot_graph.add_edge(pydot.Edge(pname1, node.name))
if node.op == "call_module":
leaf_module = self._get_leaf_node(graph_module, node)
if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
get_module_params_or_buffers()
for subgraph in buf_name_to_subgraph.values():
subgraph.set('color', 'royalblue')
subgraph.set('penwidth', '2')
dot_graph.add_subgraph(subgraph)
for node in graph_module.graph.nodes:
if ignore_getattr and node.op == "get_attr":
continue
for user in node.users:
dot_graph.add_edge(pydot.Edge(node.name, user.name))
return dot_graph
else:
if not TYPE_CHECKING:
@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
def __init__(
self,
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool = False,
ignore_parameters_and_buffers: bool = False,
skip_node_names_in_args: bool = True,
parse_stack_trace: bool = False,
dot_graph_shape: Optional[str] = None,
normalize_args: bool = False,
):
raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
'pydot through your favorite Python package manager.')


@@ -0,0 +1,111 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, NamedTuple, Optional
import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
map_arg,
Node,
Target,
)
from torch.fx.passes.shape_prop import ShapeProp
__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
'get_size_of_node']
@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(
fx_module: GraphModule,
old_op: str,
old_target: Target,
new_op: str,
new_target: Target,
):
"""Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
and updates them to match the new op code and target"""
new_graph = Graph()
val_map: Dict[Node, Node] = {}
for node in fx_module.graph.nodes:
if node.op == old_op and node.target == old_target:
args = map_arg(node.args, lambda n: val_map[n])
kwargs = map_arg(node.kwargs, lambda n: val_map[n])
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
val_map[node] = new_graph.create_node(
new_op, new_target, args, kwargs, node.name
)
else:
val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
fx_module.graph = new_graph
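# Hedged usage sketch (illustrative addition). It rewrites every call_function
# node targeting torch.relu in a hypothetical traced module into an equivalent
# call to torch.sigmoid.
def _example_replace_target_nodes_with():
    class _M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(_M())
    replace_target_nodes_with(gm, "call_function", torch.relu, "call_function", torch.sigmoid)
    gm.recompile()  # regenerate forward() from the rewritten graph
    return gm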
@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
output_size: int
total_size: int
@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
) -> None:
"""Given a fx graph module, update each node with its total size (weights + bias + output)
and its output_size(output). For a non-module node, the total size is the output size.
return total size"""
if args is not None:
# Mark shape and dtype for each node (node.shape and node.dtype)
ShapeProp(fx_module).propagate(*args)
# Calculate the total size of the whole fx graph
total_size_of_graph = 0.0
for node in fx_module.graph.nodes:
if node.op == "output":
break
node.size_bytes = get_size_of_node(fx_module, node)
return
@compatibility(is_backward_compatible=False)
def get_tensor_meta(node: Node) -> Any:
tensor_meta = node.meta.get("tensor_meta")
if not tensor_meta:
raise RuntimeError(
f"Node {node} has no tensor metadata associated with it! "
f"Check that shape propagation has run."
)
return tensor_meta
@compatibility(is_backward_compatible=False)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
"""Given a node with node.dtype and node.shape, return its total size and its output size.
total_size = weights + bias + output_size
"""
# Total num of elements
total_num_of_elems = 0
# For a module, consider all parameters
if node.op == "call_module":
submodule_dict = dict(fx_module.named_modules())
submodule = submodule_dict[node.target]
parameters = submodule.named_parameters()
# Parameters are named tuples
for name, p in parameters:
total_num_of_elems += p.numel()
# Don't forget the output size
# node.shape is the shape of this node's output
tensor_meta = get_tensor_meta(node)
output_elem = tensor_meta.shape.numel()
total_num_of_elems += output_elem
# Assume for now if it's quantized then it's qint8 or quint8
if tensor_meta.is_quantized:
size_per_elem_bytes = torch._empty_affine_quantized(
[], dtype=tensor_meta.dtype
).element_size()
else:
size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
total_size = size_per_elem_bytes * total_num_of_elems
output_size = size_per_elem_bytes * output_elem
return size_bytes(output_size, total_size)
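# Hedged usage sketch (illustrative addition). It runs shape propagation via
# get_size_of_all_nodes on a hypothetical module and reads back the per-node
# size_bytes tuples (output_size, total_size), both in bytes.
def _example_get_size_of_all_nodes():
    class _Lin(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x)

    gm = torch.fx.symbolic_trace(_Lin())
    get_size_of_all_nodes(gm, [torch.randn(2, 4)])
    return [(n.name, n.size_bytes) for n in gm.graph.nodes if n.op != "output"]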


@@ -0,0 +1,91 @@
# mypy: allow-untyped-defs
import os
from typing import Optional
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from .graph_drawer import FxGraphDrawer
__all__ = ["GraphTransformObserver"]
@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
__pass_count = 0
def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None):
# If log_url is None, we don't log anything
self.log_url = log_url
if self.log_url is None:
return
GraphTransformObserver.__pass_count += 1
self.gm = gm
self.passname = passname
self.input_dot_graph = FxGraphDrawer(
self.gm,
self.passname,
ignore_getattr=True,
ignore_parameters_and_buffers=True,
).get_dot_graph()
@classmethod
def get_current_pass_count(cls):
return cls.__pass_count
def __enter__(self):
if self.log_url is None or self.gm is None:
return self
self.erased_nodes = set()
self.created_nodes = set()
self.gm._register_create_node_hook(self.on_node_creation)
self.gm._register_erase_node_hook(self.on_node_erase)
return self
def __exit__(self, type, value, tb):
if self.log_url is None or self.gm is None:
return
self.gm._unregister_create_node_hook(self.on_node_creation)
self.gm._unregister_erase_node_hook(self.on_node_erase)
if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0:
for e in self.input_dot_graph.get_node_list():
if e.get_name() in self.erased_nodes:
e.obj_dict["attributes"]["fillcolor"] = "yellow"
else:
e.obj_dict["attributes"]["fillcolor"] = "grey"
self.input_dot_graph.write(
os.path.join(
self.log_url,
f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.dot",
)
)
output_dot_graph = FxGraphDrawer(
self.gm,
self.passname,
ignore_getattr=True,
ignore_parameters_and_buffers=True,
).get_dot_graph()
for e in output_dot_graph.get_node_list():
if e.get_name() in self.created_nodes:
e.obj_dict["attributes"]["fillcolor"] = "yellow"
else:
e.obj_dict["attributes"]["fillcolor"] = "grey"
output_dot_graph.write(
os.path.join(
self.log_url,
f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.dot",
)
)
def on_node_creation(self, node):
self.created_nodes.add(node.name)
def on_node_erase(self, node):
self.erased_nodes.add(node.name)
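# Hedged usage sketch (illustrative addition, requires pydot for the dot
# dumps). It wraps a simple graph mutation in the observer so the pass input
# and output graphs are written to `log_dir` whenever nodes were created or
# erased inside the context.
def _example_graph_transform_observer(gm: GraphModule, log_dir: str) -> GraphModule:
    with GraphTransformObserver(gm, "example_pass", log_url=log_dir):
        # eliminate_dead_code erases nodes, which triggers the erase-node hook
        gm.graph.eliminate_dead_code()
        gm.recompile()
    return gm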


@@ -0,0 +1,2 @@
from . import pass_manager


@@ -0,0 +1,335 @@
# mypy: allow-untyped-defs
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
import collections
import itertools
import logging
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
class Partition:
def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
self.id = id
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}
def __repr__(self) -> str:
return str(self.nodes)
def add_node(self, node: Node):
self.nodes.update({node: None})
def remove_node(self, node: Node):
del self.nodes[node]
def size(self):
return len(self.nodes)
class _DependencyViewer:
def __init__(self, graph_module: GraphModule):
self.upstreams = collections.defaultdict(set)
self.downstreams = collections.defaultdict(set)
for node in graph_module.graph.nodes:
for input_node in node.all_input_nodes:
# add input_node and input_node's upstream dependency
self.upstreams[node].add(input_node)
self.upstreams[node].update(self.upstreams[input_node])
for node in reversed(graph_module.graph.nodes):
for output_node in node.users:
# add output_node and output_node's downstream dependency
self.downstreams[node].add(output_node)
self.downstreams[node].update(self.downstreams[output_node])
def downstreams_of(self, node: Node) -> Set[Node]:
return self.downstreams[node]
def upstreams_of(self, node: Node) -> Set[Node]:
return self.upstreams[node]
class CapabilityBasedPartitioner:
def __init__(self,
graph_module: GraphModule,
operator_support: OperatorSupportBase,
allows_single_node_partition: bool = False,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
) -> None:
self.graph_module = graph_module
self.operator_support = operator_support
self.allows_single_node_partition = allows_single_node_partition
self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
self.allowed_single_node_partition_ops = (
allowed_single_node_partition_ops
if allowed_single_node_partition_ops is not None
else []
)
self.dependency_viewer = _DependencyViewer(graph_module)
def __is_node_supported(self, node: Node) -> bool:
return (
self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
)
def propose_partitions(self) -> List[Partition]:
# partition_map is a mapping from partition id to a set of partition id's.
# The value set contains all the partition ids that can be reached by doing a
# DFS starting from the partition id in the key.
partition_map : Dict[int, Set] = collections.defaultdict(set)
# assumption: nodes in the candidate list are sorted in topological order
assignment: Dict[Node, int] = {} # mapping from node to partition_id
partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition
new_partition_id = itertools.count()
# try to merge partition other_id into partition self_id
# merge only happens if the end graph doesn't contain cyclic dependency
# returns `True` when merge happens, `False` otherwise.
def maybe_merge_partition(self_id: int, other_id: int):
# merged_nodes is the union of the nodes in the two partitions to be merged
merged_nodes = copy(partitions_by_id[self_id].nodes)
merged_nodes.update(partitions_by_id[other_id].nodes)
def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
for user_node in all_user_nodes:
visited_partition_ids = set()
for path_node in self.dependency_viewer.downstreams_of(user_node):
# If any of the nodes in the dfs path of this node are in the merged_nodes
# list then there is a cycle in the graph.
if path_node in merged_nodes:
return True
# If any of the nodes in the dfs path of this node are in the assignment
# map then we have to make sure that the partitions that these nodes belong
# to do not form a cycle with the current partitions being merged. This means
# iterating through all the nodes in all the partitions that are traversed in
# the dfs path and checking if they are in the merged_nodes list.
if path_node in assignment:
partition_id = assignment[path_node]
# If the partition id has already been visited then we know that it doesn't
# form a cycle with the current partitions being merged.
if partition_id in visited_partition_ids:
continue
p_map = partition_map[partition_id]
if self_id in p_map or other_id in p_map:
return True
visited_partition_ids.add(partition_id)
return False
# check if merge would create cyclic dependency.
all_user_nodes = set()
for node in merged_nodes:
for user_node in node.users:
if user_node not in merged_nodes:
all_user_nodes.add(user_node)
if dfs_iter_find_cycle(all_user_nodes):
# return false indicating cyclic dependency found and
# merge is aborted
return False
# no cyclic dependency found, move forward with the merge
# updating partition nodes
partitions_by_id[self_id].nodes = merged_nodes
# updating assignment map
for node in partitions_by_id[other_id].nodes:
assignment[node] = self_id
# delete other partition
del partitions_by_id[other_id]
partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
del partition_map[other_id]
return True
def merge_single_node(node: Node, id: Optional[int]):
def _update_partition_map(node: Node, id: int):
# Iterate through all the downstream nodes of this node and update the partition map
# to indicate that there is a path from the partition id of this node to the target
# partition id.
downstream_nodes = self.dependency_viewer.downstreams_of(node)
for curr_node in downstream_nodes:
target_id = assignment.get(curr_node, None)
if target_id is not None:
partition_map[id].add(target_id)
# Iterate through all the upstream nodes of this node and update the partition map
# to indicate that there is a path from the partition id of the upstream node to the
# current node's partition id.
upstream_nodes = self.dependency_viewer.upstreams_of(node)
for curr_node in upstream_nodes:
source_id = assignment.get(curr_node, None)
if source_id is not None:
partition_map[source_id].add(id)
if node in assignment:
partitions_by_id[assignment[node]].remove_node(node)
if id is None:
assignment.pop(node)
elif id not in partitions_by_id:
assignment[node] = id
partitions_by_id[id] = Partition(id=id, nodes=[node])
_update_partition_map(node, id)
else:
assignment[node] = id
partitions_by_id[id].add_node(node)
_update_partition_map(node, id)
logger.debug("Proposing partitions...")
for node in reversed(self.graph_module.graph.nodes):
# use a Dict as an ordered set to ensure a deterministic partitioning result; the values are ignored
merge_candidates: Dict[int, None] = {}
# Note a limited horizontal fusion is enabled:
# when `node` is not supported, the code below attempts to fuse consumer of `node`.
#
# I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
# the fusion by adding an `else` block here to skip horizontal fusion.
if self.__is_node_supported(node) and node not in assignment:
partition_id = next(new_partition_id)
merge_single_node(node, partition_id)
merge_candidates[partition_id] = None
# merge all possible partitions
for node in assignment:
merge_candidates[assignment[node]] = None
merge_candidates_list = list(merge_candidates.keys())
if len(merge_candidates_list) > 1:
self_id = merge_candidates_list[0]
for other_id in merge_candidates_list[1:]:
# note: merge partition `other_id` into partition `self_id` if
# it doesn't create cyclic dependency in the graph, otherwise,
# this is a no-op
maybe_merge_partition(self_id, other_id)
# post processing to re-assign "getitem" nodes into upstream partition
logger.debug("Reassigning getitem nodes to its producer node's partition...")
nodes_reassignment: Dict[Node, int] = {}
for node in self.graph_module.graph.nodes:
is_tuple_output = True
for user in node.users:
if user.op != "call_function" or \
_get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type]
is_tuple_output = False
break
# node has tuple outputs, re-assign all following getitem nodes into node's partition
if is_tuple_output:
id = assignment.get(node, None) # type: ignore[arg-type]
for user in node.users:
if assignment.get(user, None) != id: # type: ignore[arg-type]
nodes_reassignment[user] = id # type: ignore[assignment]
for node, id in nodes_reassignment.items():
merge_single_node(node, id)
# filter out single node partitions
if not self.allows_single_node_partition:
logger.debug("Filtering out single node partitions...")
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
partitions_to_remove: List[int] = []
for id, partition in partitions_by_id.items():
compute_node_count = 0
for node in partition.nodes:
if node.op == "call_function":
assert callable(node.target)
if _get_qualified_name(node.target) not in non_compute_ops:
compute_node_count += 1
if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
compute_node_count += 1
if compute_node_count <= 1:
partitions_to_remove.append(id)
for id in partitions_to_remove:
del partitions_by_id[id]
logger.debug("Partitions proposed:")
for id, partition in partitions_by_id.items():
logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])
return [partition for partition in partitions_by_id.values() if partition.size() > 0]
def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
logger.debug("Fusing partitions...")
# fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
return fuse_by_partitions(
self.graph_module,
[list(partition.nodes) for partition in partitions],
prefix=prefix,
)
# remove non-compute ops that sit at the boundary of a partition.
def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
non_compute_ops = set(self.non_compute_ops)
def is_non_compute_node(node: Node):
return node.op == "call_function" and \
_get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
# cache transparent nodes
transparent_input_nodes: Dict[Node, bool] = {}
transparent_output_nodes: Dict[Node, bool] = {}
def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
return True
if node in transparent_input_nodes:
return transparent_input_nodes[node]
if is_non_compute_node(node):
for input_n in node.all_input_nodes:
if not is_transparent_input_node(input_n, partition, removed_nodes):
transparent_input_nodes[node] = False
return False
transparent_input_nodes[node] = True
return True
transparent_input_nodes[node] = False
return False
def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
return True
if node in transparent_output_nodes:
return transparent_output_nodes[node]
if is_non_compute_node(node):
for output_n in node.users:
if not is_transparent_output_node(output_n, partition, removed_nodes):
transparent_output_nodes[node] = False
return False
transparent_output_nodes[node] = True
return True
transparent_output_nodes[node] = False
return False
for partition in partitions:
# Note it's ok to use `set` here, since we only query whether a node
# has been removed. We NEVER iterate over the nodes inside the set.
remove_node: Set[Node] = set()
for node in partition.nodes:
if is_non_compute_node(node) and \
(is_transparent_input_node(node, set(partition.nodes), remove_node) or
is_transparent_output_node(node, set(partition.nodes), remove_node)):
remove_node.add(node)
if len(remove_node) != 0:
for node in remove_node:
partition.nodes.pop(node, None)
def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
partitions = self.propose_partitions()
fused_gm = self.fuse_partitions(partitions, prefix=prefix)
return fused_gm
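# Hedged usage sketch (illustrative addition). It partitions a hypothetical
# traced module using a trivial OperatorSupport that accepts every
# call_function node, then fuses the proposed partitions into submodules.
def _example_capability_based_partitioner() -> GraphModule:
    import torch
    from torch.fx.passes.operator_support import OperatorSupport

    class _Add(torch.nn.Module):
        def forward(self, x):
            return x + x + 1

    class _AllCallFunctionsSupported(OperatorSupport):
        def is_node_supported(self, submodules, node: Node) -> bool:
            return node.op == "call_function"

    gm = torch.fx.symbolic_trace(_Add())
    partitioner = CapabilityBasedPartitioner(
        gm, _AllCallFunctionsSupported(), allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    return partitioner.fuse_partitions(partitions)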


@@ -0,0 +1,73 @@
# mypy: allow-untyped-defs
import abc
from collections import namedtuple
from typing import Optional
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
__all__ = ['PassResult', 'PassBase']
@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
"""
Result of a pass:
graph_module: The modified graph module
modified: A flag indicating whether the pass modified the graph module
"""
def __new__(cls, graph_module, modified):
return super().__new__(cls, graph_module, modified)
@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
"""
Base interface for implementing passes.
It is required to implement the `call` function so that we can pass
instances of the Pass directly to the PassManager and call them as
functions.
We can directly pass an instance of a class implementing this interface into
the PassManager's `passes` attribute.
"""
def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
"""
Runs the precondition check, the pass itself, and the postcondition check.
"""
self.requires(graph_module)
res = self.call(graph_module)
self.ensures(graph_module)
return res
@abc.abstractmethod
def call(self, graph_module: GraphModule) -> Optional[PassResult]:
"""
The pass that is run through the given graph module. To implement a
pass, it is required to implement this function.
Args:
graph_module: The graph module we will run a pass on
"""
def requires(self, graph_module: GraphModule) -> None: # noqa: B027
"""
This function will be called before the pass is run and will check that
the given graph module contains the preconditions needed to run the
pass. It is not required to implement this function.
Args:
graph_module: The graph module we will run checks on
"""
def ensures(self, graph_module: GraphModule) -> None: # noqa: B027
"""
This function will be called after the pass is run and will check that
the given graph module contains the postconditions needed to run the
pass. It is not required to implement this function.
Args:
graph_module: The graph module we will run checks on
"""


@@ -0,0 +1,302 @@
# mypy: allow-untyped-defs
import inspect
import logging
from queue import Queue
from functools import wraps
from typing import Callable, Dict, List
import torch.nn as nn
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
"""
Wrapper for passes which currently do not return a PassResult.
This wrapper makes them return a PassResult containing the modified object
and True for the "modified" flag.
Args:
fn (Callable[Module, Any])
Returns:
wrapped_fn (Callable[Module, PassResult])
"""
if fn is None:
return None
@wraps(fn)
def wrapped_fn(gm):
res = fn(gm)
if res is None:
return PassResult(gm, True)
if isinstance(res, PassResult):
return res
elif isinstance(res, nn.Module):
return PassResult(res, True)
if not inspect.isfunction(fn):
wrapped_fn.__name__ = type(fn).__name__
return wrapped_fn
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
if constraint(a, b):
continue
raise RuntimeError(
f"pass schedule constraint violated. Expected {a} before {b}"
f" but found {a} at index {i} and {b} at index{j} in pass"
f" list."
)
def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
"""
Args
passes: Passes that we are ordering
constraints: Constraints applied on these passes
Returns
The passes sorted topologically according to the constraints. A
RuntimeError is raised if a circular dependency exists.
"""
if len(constraints) == 0:
return passes
# Construct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
for b in passes:
if a == b:
continue
for constraint in constraints:
if not constraint(a, b):
graph[b].append(a)
indegree_map[a] += 1
if indegree_map[a] == 0:
candidates.put(a)
visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: List[Callable] = []
while not candidates.empty():
p = candidates.get()
sorted_passes.append(p)
visited[p] = True
for n in graph[p]:
if not visited[n]:
indegree_map[n] -= 1
if indegree_map[n] == 0:
candidates.put(n)
# Check if there are unvisited nodes (aka cycles in the graph)
cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
if len(cycle_passes) != 0:
error = f"Circular dependency detected within the following passes: {cycle_passes}"
raise RuntimeError(error)
return sorted_passes
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
"""
Defines a partial order ('depends on' function) where `this` must occur
before `that`.
For example, the following pass list and constraint list would be invalid.
```
passes = [pass_b, pass_a]
constraints = [
this_before_that_pass_constraint(pass_a, pass_b)
]
```
Args:
this (Callable): pass which should occur first
that (Callable): pass which should occur later
Returns:
depends_on (Callable[[Object, Object], bool])
"""
def depends_on(a: Callable, b: Callable):
return a != that or b != this
return depends_on
@compatibility(is_backward_compatible=False)
class PassManager:
"""
Construct a PassManager.
Collects passes and constraints. This defines the pass schedule, manages
pass constraints and pass execution.
Args:
passes (Optional[List[Callable]]): List of passes. A pass is a
callable which modifies an object and returns a PassResult
constraint (Optional[List[Callable]]): List of constraints. A
constraint is a callable which takes two passes (A, B) and returns
True if A depends on B and False otherwise. See implementation of
`this_before_that_pass_constraint` for example.
steps (int): Max number of times we run the passes (default = 1).
run_checks_after_each_pass (bool): Whether to run checks and linting
after each pass
suppress_check_failures (bool): Whether to raise errors when running
checks
"""
passes: List[Callable[[nn.Module], PassResult]]
constraints: List[Callable[[Callable, Callable], bool]]
_validated: bool = False
steps: int = 1
def __init__(
self,
passes=None,
constraints=None,
steps=None,
run_checks_after_each_pass: bool = False,
suppress_check_failures: bool = False,
):
self.passes = passes or []
self.constraints = constraints or []
if steps:
self.steps = steps
self.run_checks_after_each_pass = run_checks_after_each_pass
self.suppress_check_failures = suppress_check_failures
def add_pass(self, _pass: Callable):
"""
Adds a pass into the current list of passes.
"""
self.passes.append(_pass)
self._validated = False
def add_constraint(self, constraint: Callable):
"""
Adds a constraint into the current list of constraints.
"""
self.constraints.append(constraint)
self._validated = False
def validate_constraints(self):
"""
Validates that current pass schedule defined by `self.passes` is valid
according to all constraints in `self.constraints`
"""
if self._validated:
return
for constraint in self.constraints:
_validate_pass_schedule_constraint(constraint, self.passes)
self._validated = True
def solve_constraints(self):
"""
Finds a valid traversal order based on the given constraints and orders
the passes based on this order.
If a circular dependency exists between the constraints, an error is
raised, since a single scheduling of the passes (steps = 1) cannot satisfy
it; only re-running the passes (steps != 1) could accommodate such
circular dependencies.
"""
self.passes = _topological_sort_passes(self.passes, self.constraints)
self._validated = True
def add_checks(self, check: Callable) -> None:
"""
Adds a function which runs various checks on a given graph module.
This function is run before and after each pass if the
`run_checks_after_each_pass` flag is enabled.
"""
sig = inspect.signature(check)
if len(list(sig.parameters.values())) != 1:
raise TypeError("PassManager check function should only take in one variable, a module")
setattr(self, "check", check) # noqa: B010
def check(self, module: nn.Module) -> None:
pass
def __call__(self, module: nn.Module) -> PassResult:
"""
Runs a list of passes in the order based on `self.passes` on the given
graph module. Each time a pass is run, checks and linting will be run on
the graph module if `run_checks_after_each_pass` is set.
If the module is a graph module, we will run the list of passes until
the graph stops changing, or until `steps` number of times.
"""
# Order the passes based on the constraints
if not self._validated:
self.solve_constraints()
# Check graph invariants
self.check(module)
# Run the set of passes `steps` number of times or until the graph stops
# changing
overall_modified = False
for _ in range(self.steps):
modified = False
# Run the set of passes on the graph module
for i, fn in enumerate(self.passes):
fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
logger.debug("Running pass '%s'", fn_name)
try:
res = fn(module)
if not isinstance(res, PassResult) and not hasattr(
res, "graph_module"
):
raise TypeError(
f"The result of the pass {fn_name} should be type PassResult."
+ "Please wrap it with pass_result_wrapper()"
)
module = res.graph_module
modified = modified or res.modified
if isinstance(module, GraphModule):
logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
module.recompile()
# Check graph invariants
if self.run_checks_after_each_pass:
self.check(module)
except Exception as e:
prev_pass_names = [
p.__name__ if inspect.isfunction(p) else type(p).__name__
for p in self.passes[:i]
]
msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
raise Exception(msg) from e # noqa: TRY002
# If the graph no longer changes, then we can stop running these passes
overall_modified = overall_modified or modified
if not modified:
break
return PassResult(module, overall_modified)
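# Hedged usage sketch (illustrative addition). Two hypothetical function
# passes are scheduled with a constraint so that dead-code elimination runs
# before the recompile pass, regardless of registration order.
def _example_pass_manager(gm: GraphModule) -> PassResult:
    def dead_code_pass(module):
        module.graph.eliminate_dead_code()
        return PassResult(module, True)

    def recompile_pass(module):
        module.recompile()
        return PassResult(module, False)

    pm = PassManager(
        passes=[recompile_pass, dead_code_pass],
        constraints=[this_before_that_pass_constraint(dead_code_pass, recompile_pass)],
    )
    # solve_constraints() (called lazily on the first run) reorders the passes
    # so dead_code_pass executes before recompile_pass.
    return pm(gm)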


@@ -0,0 +1,924 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
CALLABLE_NODE_OPS,
FxNetAccFusionsFinder,
Names,
NodeList,
NodeSet,
TensorOrTensors,
Tensors,
)
__all__ = [
"FxNetMinimizerBadModuleError",
"FxNetMinimizerRunFuncError",
"FxNetMinimizerResultMismatchError",
]
_LOGGER = logging.getLogger(__name__)
@compatibility(is_backward_compatible=False)
class FxNetMinimizerBadModuleError(Exception):
"""
Raised if failed to split out a minimize module
"""
@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
"""
Raised if error occurs during run_a or run_b functions
"""
@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
"""
Raised if comparing function thinks the results are mismatching.
"""
@dataclass
class _MinimizerSettingBase:
"""
Args:
`accumulate_error`: Instead of using run_a's inputs for both submodules being
compared, use the previous outputs of each submodule as inputs so that
errors accumulate.
`traverse_method`: "sequential", "binary", or "accumulate".
Determines how the nodes in the FX module are traversed.
`find_all`: Minimizer will go through the entire model and return all problematic nodes.
`return_intermediate`: If true, when using the `run_nodes()` function to run the
model, the intermediate results of all the ops will be returned as output.
"""
accumulate_error: bool = False
traverse_method: str = "sequential"
find_all: bool = False
return_intermediate: bool = False
def __str__(self):
settings_str = "FX Minimizer Settings:\n"
for k, v in vars(self).items():
settings_str += f"\t{k}: {v}\n"
return settings_str
class _MinimizerBase:
"""
This class is used to automatically find problematic nodes in a model. It takes an FX
GraphModule and generates submodules while traversing the graph. Then two functions,
`run_a` and `run_b`, are used to run the same submodule, and a function `compare_fn`
is used to compare the results.
Currently we provide two ways to traverse the graph and generate submodules:
1. Sequential traversal: this will traverse the graph node by node and generate
one submodule with a single node.
2. Binary searching: this will do a binary search style traversal on the graph.
For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
"""
def __init__(
self,
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
],
settings: _MinimizerSettingBase,
module_exporter: Optional[
Callable[
[Tensors, torch.fx.GraphModule, str],
None
]
] = None,
exclusion_fn: Optional[
Callable[[NodeList, int, int], None]
] = None,
):
assert isinstance(module, torch.fx.GraphModule)
self.module = module
self.sample_input = sample_input
self.compare_fn = compare_fn
self.module_exporter = module_exporter
self.settings = settings
self.exclusion_fn = exclusion_fn
# Stores outputs of run_a function
self.a_outputs: Dict[str, Any] = {}
# Stores outputs of run_b function
self.b_outputs: Dict[str, Any] = {}
# Stores the results of compare_fn
self.results: Dict[Any, Any] = {}
# Stores the report for the runs
self.reports: List[List[str]] = []
# Current iteration
self.iteration: int = 0
callable_nodes = {
node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
}
ShapeProp(self.module).propagate(*self.sample_input)
self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
# Check that the number of inputs in sample_input matches the number of placeholders
placeholders = [
node.name for node in self.module.graph.nodes if node.op == "placeholder"
]
assert len(placeholders) == len(self.sample_input)
# Store sample_input
for i, name in enumerate(placeholders):
self.a_outputs[name] = sample_input[i]
self.b_outputs[name] = sample_input[i]
def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_b().
"""
raise RuntimeError("run_a() is not implemented.")
def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_a().
"""
raise RuntimeError("run_b() is not implemented.")
def _store_outputs(
self,
a_result: TensorOrTensors,
b_result: TensorOrTensors,
submodule: torch.fx.GraphModule,
):
"""
Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
self.b_outputs, so that we can use them when execute preceding nodes that
use those outputs as inputs.
Args:
a_result: Output of self.run_a(). Could be a tensor or tensors.
b_result: Output of self.run_b(). Could be a tensor or tensors.
submodule: The module that generates a_result and b_result.
"""
output_node = next(
node for node in submodule.graph.nodes if node.op == "output"
)
# Only one output
if isinstance(output_node.args[0], torch.fx.Node):
self.a_outputs[output_node.args[0].name] = a_result
self.b_outputs[output_node.args[0].name] = b_result
# Multiple outputs
else:
for i, arg in enumerate(output_node.args[0]):
self.a_outputs[arg.name] = a_result[i]
self.b_outputs[arg.name] = b_result[i]
def _get_submod_inputs(
self, main_module: torch.fx.GraphModule, submod_path: str
) -> Tuple[Tensors, Tensors]:
"""
Try to get submodule inputs from stored outputs. If they are not found, use
torch_glow.get_submod_inputs to get the inputs.
If accumulate_error is False, use a_input for both run_a() and run_b();
otherwise use a_input for run_a() and b_input for run_b().
Args:
main_module: Top-level fx module.
submod_path: Path to the submodule we want to run and compare results.
Returns:
a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
"""
a_input = []
b_input = []
submodule = getattr(main_module, submod_path)
placeholders = [
node.name for node in submodule.graph.nodes if node.op == "placeholder"
]
# If all placeholders can be found in stored outputs, use the stored
# outputs as inputs. Otherwise, capture the inputs with a forward
# pre-hook on the submodule while running the main module.
if set(placeholders) <= self.a_outputs.keys():
for name in placeholders:
a_input.append(self.a_outputs[name])
b_input.append(self.b_outputs[name])
else:
if self.settings.accumulate_error:
print(f"Can't find previous stored outputs named {placeholders}!")
def get_inputs(self: torch.nn.Module, inputs: Any):
nonlocal a_input
a_input = inputs
# Use forward hook to get the inputs to the submodule
handle = submodule.register_forward_pre_hook(get_inputs)
main_module(*self.sample_input)
handle.remove()
b_input = a_input
if not self.settings.accumulate_error:
return a_input, a_input
return a_input, b_input
def _tag_nodes(self, selected_nodes: NodeSet):
"""
Tag selected nodes with the tag "minimize". Nodes with the same tag will
be split into the same submodule afterwards.
Args:
selected_nodes: Nodes that we want to minimize. We will tag those nodes
with "minimize", all preceding nodes with "main_0" and all following
nodes with "main_1".
"""
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
if node in selected_nodes:
node.tag = "minimize"
elif any(
n.tag in {"minimize", "main_1"}
for n in node.all_input_nodes
if n.op in CALLABLE_NODE_OPS
):
node.tag = "main_1"
else:
node.tag = "main_0"
def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
"""
Split self.module so that one submodule consists of `nodes` and only `nodes`.
Args:
nodes: Nodes that we want to include in the minimize submodule.
Returns:
split_module (torch.fx.GraphModule): the module after split.
submodule_name (str): the name of the submodule that consists of `nodes`.
"""
# Color provided nodes
self._tag_nodes(nodes)
# Split module based on coloring
split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
# Find submodule containing colored nodes
submodule_name: str = ""
for child_name, _ in split_module.named_children(): # type: ignore[union-attr]
# Skip submodules we're not interested in at the moment
if "minimize" not in child_name:
continue
if submodule_name == "":
submodule_name = child_name
else:
raise FxNetMinimizerBadModuleError(
f"Expected only one minimize submodule with nodes {nodes}"
)
if submodule_name == "":
raise FxNetMinimizerBadModuleError(
f"Minimize submodule was not found with nodes {nodes}"
)
return split_module, submodule_name # type: ignore[return-value]
def _run_and_compare(
self,
split_module: torch.fx.GraphModule,
submod_name: str,
output_names: Names,
report_idx: int = -1
):
"""
Run the submodule in `split_module` that has name `submod_name`
using `self.run_a` and `self.run_b` and compare their results.
Args:
split_module: Main module that contains the minimize submodule.
submod_name: Name of the minimize submodule.
output_names: Names of the nodes we want to output. If empty or None,
we will use the original output.
"""
submodule = getattr(split_module, submod_name)
a_input, b_input = self._get_submod_inputs(split_module, submod_name)
if len(self.reports) == 0:
self.reports.append([])
self.iteration = 1
report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1]
report.append("Run and compare ...")
if output_names:
output_nodes: NodeList = []
for node in submodule.graph.nodes:
if node.op == "output":
submodule.graph.erase_node(node)
if node.name in output_names:
output_nodes.append(node)
submodule.graph.output(
output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
)
submodule.graph.lint()
submodule.recompile()
# Use name of args in output node as key to store comparison result
for node in submodule.graph.nodes:
if node.op == "output":
result_key = map_arg(node.args, lambda x: x.name)
try:
a_result = self.run_a(submodule, a_input, report_idx)
b_result = self.run_b(submodule, b_input, report_idx)
self._store_outputs(a_result, b_result, submodule)
except Exception as e:
report.append(f"Exception raised when running {submod_name}: {e}")
raise FxNetMinimizerRunFuncError( # noqa: B904
f"Exception raised when running {submod_name}: {e}"
)
# Compare results
names: Names = output_names
if output_names is None:
names = [str(v) for v in result_key] # type: ignore[possibly-undefined]
numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
report.append(f"Result mismatch for {result_key}")
if self.module_exporter:
self.module_exporter(
a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index]
)
self.module_exporter(
b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
def _binary_search_impl(
self, all_nodes: NodeList, start_idx: int, end_idx: int
) -> NodeSet:
"""
Recursive binary search implementation.
"""
culprits: NodeSet = set()
nodes: NodeList = all_nodes[start_idx:end_idx]
report: List[str] = []
if self.exclusion_fn is not None:
self.exclusion_fn(nodes, start_idx, end_idx)
if len(nodes) == 0:
report = ["All nodes are excluded by user"]
self.reports.append(report)
return culprits
first_node_name = nodes[0].name
output_node_name = nodes[-1].name
self.iteration += 1
self.reports.append(report)
report.append(f"Binary search iteration {self.iteration}")
report.append(
f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. "
f"Size of the interested node list is {len(nodes)}"
)
cur_nodes: NodeSet = set(nodes)
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [output_node_name])
except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
if len(nodes) == 1:
report.append(
f"This is the last node in the sub-module. "
f"Search in the current branch is successful with culprit = {cur_nodes}."
)
self.print_report(report)
return cur_nodes
report.append(
"Proceed to split and lower the halves of the current "
"sub-module individually."
)
self.print_report(report)
mid = len(nodes) // 2
culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
if len(culprits) != 0 and not self.settings.find_all:
return culprits
culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
if len(culprits) == 0:
report.append(
f"Further split and lowering found no errors. "
f"Unable to minimize the submodule with list of nodes: {nodes}"
)
self.print_report(report)
return culprits
else:
report.append("No discrepancy found.")
self.print_report(report)
return set()
def _binary_traverse(self, nodes: NodeList) -> NodeSet:
"""
Binary search on `nodes` for culprit.
"""
return self._binary_search_impl(nodes, 0, len(nodes))
def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
"""
Traverse `nodes` one by one and determine if any of them is a culprit.
"""
culprits: NodeSet = set()
for node in nodes:
report: List[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Sequential traverse iteration {self.iteration}.")
report.append(f"Visit node: {node.name}")
_LOGGER.info("Visit node: %s", node.name)
node_list: NodeList = [node]
if self.exclusion_fn is not None:
self.exclusion_fn(node_list, -1, -1)
if len(node_list) == 0:
report.append(f"User exclusion : {node.name}")
self.print_report(report)
if not self.settings.find_all:
return culprits
else:
continue
cur_nodes: NodeSet = {node}
if node in self.fusions:
cur_nodes = self.fusions[node]
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [node.name])
self.print_report(report)
except (FxNetMinimizerResultMismatchError):
culprits.add(node)
report.append(f"Found culprit from numeric error: {node}")
self.print_report(report)
if not self.settings.find_all:
return culprits
except (FxNetMinimizerRunFuncError):
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {node}")
self.print_report(report)
if not self.settings.find_all:
return culprits
return culprits
def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int:
"""
Recursive block search implementation.
find_last_node: If True, search for the last node that results in a numerical
difference; if False, search for the first node in the sorted node list.
"""
report: List[str] = []
mid = (start_idx + end_idx) // 2
cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:]
if self.exclusion_fn:
self.exclusion_fn(cur_nodes_list, -1, -1)
cur_nodes = set(cur_nodes_list)
first_node_name = cur_nodes_list[0].name
last_node_name = cur_nodes_list[-1].name
target_node_name = last_node_name if find_last_node else first_node_name
self.iteration += 1
self.reports.append(report)
report.extend(
[
"=" * 30,
f"Block search iteration {self.iteration}",
]
)
report.extend(
[
f"Search for {'last' if find_last_node else 'first'} node in culprits",
f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ",
f"Subgraph constructed by {first_node_name} to {last_node_name}",
f"Targeting node: {target_node_name}",
f"Size of the interested node list is {end_idx - start_idx + 1}",
]
)
report_idx = len(self.reports) - 1
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [last_node_name], report_idx)
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
report.append(f"Culprits found from node {first_node_name} to {last_node_name}.")
if start_idx == mid:
report.extend(
[
"This is the last node in the sub-module. ",
"Search in the current branch is successful with node :",
f"{start_idx}, node name: {nodes[start_idx].name}."
]
)
self.print_report(report)
return start_idx
report.append(
"Proceed to split and lower the halves of the current "
"sub-module individually."
)
self.print_report(report)
if find_last_node:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
else:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
else:
report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.")
if start_idx == mid:
report.extend(
[
"This is the last node in the sub-module. ",
"Search in the current branch is successful with node",
f"{start_idx}, node name: {nodes[start_idx].name}.",
]
)
self.print_report(report)
return start_idx + 1 if find_last_node else start_idx - 1
report.append(
"Proceed to split and lower the halves of the current "
"sub-module individually."
)
self.print_report(report)
if find_last_node:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
else:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet:
"""
Traverse the topologically sorted node list.
Find the minimum block (start_idx, end_idx) which contains the culprit.
1st pass: search for end_idx by finding the last node in the culprit block
where numerical accuracy of (0, end_idx) > threshold.
2nd pass: search for start_idx by finding the first node in the culprit block
where numerical accuracy of (start_idx, end_idx) < threshold.
Form the minimum block as (start_idx - 1, end_idx).
"""
culprits: NodeSet = set()
first_node_name = nodes[0].name
last_node_name = nodes[-1].name
last_node_report = [f"Block search from {first_node_name} to {last_node_name}"]
last_node_report.append("*" * 50)
self.reports.append(last_node_report)
start_idx = 0
end_idx = len(nodes) - 1
run_both = True if find_last_node is None else False
# step 1: find (0, end_idx) of culprit block
if run_both or find_last_node:
last_node_report.append("Start searching for last node in culprit")
self.print_report(last_node_report)
end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
last_node_report.extend(
[
"Finish Pass 1",
f"Find end_idx = {end_idx}:{nodes[end_idx].name}"
]
)
self.print_report(last_node_report)
# step 2: reduce culprit block to (start_idx, end_idx)
if run_both or not find_last_node:
first_node_report = ["Start searching for first node in culprit"]
self.print_report(first_node_report)
start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False)
first_node_report.append("*" * 50)
self.reports.append(first_node_report)
first_node_report.extend(
[
"Finish Pass 2",
f"Find start_idx = {start_idx}:{nodes[start_idx].name}"
]
)
self.print_report(first_node_report)
# step 3: form module with minimum culprits
culprits.update(nodes[start_idx:end_idx + 1])
result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"]
self.reports.append(result_report)
self.print_report(result_report)
return culprits
def _defined_traverse(self, nodes: NodeList) -> NodeSet:
"""
Run the user-defined `nodes` and determine if they contain a culprit.
"""
culprits: NodeSet = set()
if self.exclusion_fn is not None:
self.exclusion_fn(nodes, -1, -1)
if len(nodes) == 0:
report = ["All nodes are excluded by user"]
self.reports.append(report)
return culprits
first_node_name = nodes[0].name
output_node_name = nodes[-1].name
report = [f"Defined graph from {first_node_name} to {output_node_name}"]
cur_nodes: NodeSet = set(nodes)
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [output_node_name])
self.print_report(report)
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
report.append(f"Found culprit {cur_nodes}")
self.print_report(report)
return culprits
return culprits
def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
culprits: NodeSet = set()
nodes_to_run: NodeSet = set()
# find_all is not supported for accumulate traversal because all the
# ops run on NNPI. So we return after the first op that raises an error.
if self.settings.find_all:
print("'Find All' mode is not supported in accumulate traversal.")
return culprits
for node in nodes:
report: List[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Accumulate traverse iteration {self.iteration}.")
nodes_to_run.add(node)
node_name = node.name
if node_name is not None and isinstance(node_name, tuple):
node_name = node_name[0]
assert node_name is not None and isinstance(
node_name, str
), f"minimize: node_name: {node_name}"
report.append(f"Add node: {node_name}")
try:
split_module, submod_name = self._build_submodule(nodes_to_run)
self._run_and_compare(split_module, submod_name, [node_name])
self.print_report(report)
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
culprits.add(node)
report.append(f"Found culprit {node}")
self.print_report(report)
return culprits
return culprits
def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
"""
Run and compare a contiguous block of nodes (the nodes between skipped nodes).
"""
culprits: NodeSet = set()
nodes: NodeList = all_nodes[start_idx:end_idx]
cur_nodes: NodeSet = set(nodes)
if self.exclusion_fn is not None:
self.exclusion_fn(nodes, start_idx, end_idx)
cur_nodes = set(nodes)
else:
for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])
report: List[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
report.append(
f"From node index {start_idx} to {end_idx-1}. "
f"Size of the interested node list is {len(nodes)}"
)
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [])
except (FxNetMinimizerResultMismatchError):
culprits.update(cur_nodes)
report.append(f"Found culprit from numeric error: {cur_nodes}")
self.print_report(report)
return culprits
except (FxNetMinimizerRunFuncError):
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {cur_nodes}")
self.print_report(report)
return culprits
else:
report.append("No discrepancy found.")
self.print_report(report)
return set()
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
start_idx = 0
num_nodes = len(all_nodes)
idx = 0
culprits = set()
while idx < num_nodes:
node = all_nodes[idx]
if (node.name in skip_nodes): # skip the node
if idx > start_idx:
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
start_idx = idx + 1
elif idx == num_nodes - 1 and start_idx <= idx: # last node
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
idx += 1
return culprits
def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
"""
Collect nodes in the model that are between the nodes named `start` and `end`.
These two nodes are also included.
"""
nodes: NodeList = []
add_node = start is None
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
if node.name == start:
add_node = True
if add_node:
nodes.append(node)
if node.name == end:
break
return nodes
def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
"""
Run part of the model from `start` node to `end` node. If `start` is None
then we start from the beginning of the model. If `end` is None then we
stop at the end of the model.
Args:
start: The name of the node which is the first node of the submodule
we want to run. If set to None, then we'll start with the first
node of the model.
end: The name of the node which is the last node of the submodule we
want to run. If set to None, we'll end with the last node of the
model.
"""
nodes = self._collect_nodes(start, end)
cur_nodes = set(nodes)
for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])
output_names = []
if self.settings.return_intermediate:
output_names = [node.name for node in nodes]
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, output_names)
except (
FxNetMinimizerRunFuncError,
FxNetMinimizerResultMismatchError,
) as e:
print(e)
def print_report(self, report: List[str]):
for i in range(len(report)):
if i > 0:
print(" . " + report[i])
else:
print(report[i])
def print_reports(self):
for report in self.reports:
self.print_report(report)
def minimize(
self,
start: Optional[str] = None,
end: Optional[str] = None,
skip_nodes: Optional[List] = None,
find_last_node: Optional[bool] = None,
) -> NodeSet:
"""
Minimize the model from the node named `start` to the node named `end` based
on self.settings. Find culprits that cause FxNetMinimizerRunFuncError or
FxNetMinimizerResultMismatchError errors.
Args:
start: The name of the node where we want to start minimizing. If set
to None, then we'll start with the first node of the model.
end: The name of the node where we want to terminate minimizing. If
set to None, we'll end with the last node of the model.
skip_nodes: The names of nodes that we want to skip during minimizing.
It'll create subgraphs without these skipped nodes under the hood.
Only applicable in mode "skip".
find_last_node: True if only the last node of the culprits is needed.
False if only the first node of the culprits is needed.
Only applicable in mode "block".
Returns:
nodes: A list of nodes that cause FxNetMinimizerRunFuncError or
FxNetMinimizerResultMismatchError errors during minimizing.
"""
print(self.settings)
print(self.module.graph)
nodes = self._collect_nodes(start, end)
if self.settings.traverse_method == "sequential":
return self._sequential_traverse(nodes)
if self.settings.traverse_method == "binary":
return self._binary_traverse(nodes)
if self.settings.traverse_method == "accumulate":
return self._accumulate_traverse(nodes)
if self.settings.traverse_method == "skip":
if (skip_nodes is None):
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
return self._skip_traverse(nodes, skip_nodes)
if self.settings.traverse_method == "defined":
return self._defined_traverse(nodes)
if self.settings.traverse_method == "block":
return self._block_traverse(nodes, find_last_node)
raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")

View File

@ -0,0 +1,215 @@
# mypy: allow-untyped-defs
import abc
import typing as t
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS
__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str
# Arguments' dtypes for a given node, see `OperatorSupport`
SupportedArgumentDTypes = t.Optional[
t.Tuple[
t.Sequence[t.Sequence[torch.dtype]],
t.Dict[str, t.Sequence[torch.dtype]],
]
]
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
"""Interface for determining if a fx.Node is supported by a backend"""
@abc.abstractmethod
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
raise NotImplementedError
@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
"""
`_support_dict` maps a node.target typename to its supported input dtypes.
The node.target typename is retrieved using the helper function `get_node_target()`.
If the supported input dtypes entry is None, any dtype is supported; otherwise
we should see a tuple like (([dtypes], ...), {"name": [dtypes], ...}).
The first tuple ([dtypes], ...) indicates what dtypes are supported for
inputs in node.args and the second dict {"name": [dtypes], ...} indicates
what dtypes are supported for inputs in node.kwargs.
For an input in args that we don't want to check, we can put None in its place,
e.g. (None, [torch.float]) indicates that we don't care about the type of
the first input in args. Inputs in kwargs that are not listed will not
be checked.
"""
_support_dict: SupportDict
def __init__(
self,
support_dict: t.Optional[SupportDict] = None
):
self._support_dict = support_dict or {}
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
"""
Args:
`submodules`: mapping from module name to the module. This can be
retrieved by calling model.named_modules().
`node`: a Fx node that we want to determine whether it's supported.
Returns:
`is_supported`: whether the arg `node` is supported.
"""
if node.op not in CALLABLE_NODE_OPS:
return True
target = get_node_target(submodules, node)
# Target not found in _support_dict means that we don't support this op at all
if target not in self._support_dict:
return False
# A None rule for the target means that we accept any dtype
if self._support_dict[target] is None:
return True
args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc]
# Check args dtypes
for i, dtypes in enumerate(args_dtypes):
if len(node.args) <= i:
break
# None indicates we don't care about the dtype of args[i]
if dtypes is None:
continue
# If arg is not a node then we don't check it
if not isinstance(node.args[i], torch.fx.Node):
continue
arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type]
if arg_dtype not in dtypes:
return False
# Check kwargs dtypes
for k, dtypes in kwargs_dtypes.items():
if k not in node.kwargs:
continue
# If arg is not a node then we don't check it
if not isinstance(node.kwargs[k], torch.fx.Node):
continue
kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type]
if kwarg_dtype not in dtypes:
return False
return True
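# Illustrative sketch (not part of the original file): building an OperatorSupport
# table. The target-name strings below are examples only; use `get_node_target()`
# on your own graph's nodes to see the exact names it produces. A None rule accepts
# any dtype; the tuple rule restricts the first positional tensor input to fp16/fp32.
def _example_operator_support() -> OperatorSupport:
    support_dict: SupportDict = {
        "torch.relu": None,
        "torch.add": (([torch.float16, torch.float32],), {}),
    }
    return OperatorSupport(support_dict)
    # Usage: _example_operator_support().is_node_supported(dict(gm.named_modules()), node)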
# ======================================================================
# Functional interfaces and utils for defining basic operator support logic
# and composing them into more complex ones
# ======================================================================
IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
"""Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
`IsNodeSupported` has the same call signature as
`OperatorSupportBase.is_node_supported`
"""
class FunctionalOperatorSupport(OperatorSupportBase):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
return is_node_supported(submodules, node)
return FunctionalOperatorSupport()
@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
"""Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
instance by evaluating each input `OperatorSupportBase` instance, and returns False if
any of it reports False.
"""
def _chain(submods, node) -> bool:
return all(
x.is_node_supported(submods, node)
for x in op_support
)
return create_op_support(_chain)
@compatibility(is_backward_compatible=False)
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
"""Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
instance by evaluating each input `OperatorSupportBase` instance, and returns True if
any of it reports True.
"""
def _any_chain(submods, node) -> bool:
return any(
x.is_node_supported(submods, node)
for x in op_support
)
return create_op_support(_any_chain)
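# Illustrative sketch (not part of the original file): composing support rules.
# A node is supported only if it is a call_function node AND none of its inputs
# are fp64. The dtype rule comes from OpSupports, defined just below; it is only
# resolved when the composed rule is called, so the forward reference is fine.
def _example_chained_support() -> OperatorSupportBase:
    def _only_call_function(submodules, node):
        return node.op == "call_function"

    return chain(
        create_op_support(_only_call_function),
        OpSupports.decline_if_input_dtype(torch.float64),
    )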
@compatibility(is_backward_compatible=False)
class OpSupports:
"""A set of atomic `OperatorSupportBase` instances that can be combined together
to form more complex operator support logic.
"""
@classmethod
def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
"""Report a node as non-supported, if any of its arguments is of dtype"""
def _decline_if_input_dtype(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
for arg in node.all_input_nodes:
arg_dtype = _get_arg_dtype(arg)
if arg_dtype == dtype:
return False
return True
return create_op_support(_decline_if_input_dtype)
@classmethod
def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
"""
If a node has a name that is in the disallow set, report it as non-supported.
"""
def _decline_if_node_in_names(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
return node.name not in disallow_set
return create_op_support(_decline_if_node_in_names)
def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
assert isinstance(arg, torch.fx.Node)
tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
return dtype

View File

@ -0,0 +1,66 @@
from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn
from torch.fx._compatibility import compatibility
__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
# A matching method maps the attribute name of the current version to the attribute name of `target_version`
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
"""Default matching method
"""
return name
# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch, the parameter names in the book might not match those of the nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
torch.nn.modules.conv.Conv2d: (
1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
),
torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
torch.nn.modules.pooling.MaxPool2d: (
1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
),
torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}
@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
after checking module's version is compatible with the `module_fetch_book`.
"""
attrs_for_lowering: Dict[str, Any] = {}
attrs_for_lowering["name"] = torch.typename(mod)
if type(mod) in module_fetch_book:
version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
if version < mod._version:
raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
"please upgrade the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
for attr in param_to_fetch:
attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
else:
raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
"please add it to the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
return attrs_for_lowering
@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
"""
submodules = dict(fx_module.named_modules())
for node in fx_module.graph.nodes:
if node.op == "call_module":
if isinstance(submodules[node.target], GraphModule):
lift_lowering_attrs_to_nodes(submodules[node.target])
else:
node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
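# Illustrative sketch (not part of the original file): lifting lowering attributes
# onto the nodes of a small traced module. The module below is an example only.
def _example_lift_lowering_attrs():
    class _M(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(_M())
    lift_lowering_attrs_to_nodes(gm)
    # Every call_module node now carries `attrs_for_lowering`; for the Linear
    # node that includes its typename plus "weight" and "bias".
    return [n.attrs_for_lowering for n in gm.graph.nodes if n.op == "call_module"]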

View File

@ -0,0 +1,254 @@
# mypy: allow-untyped-defs
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging
logger = logging.getLogger(__name__)
__all__ = [
"PassManager",
"inplace_wrapper",
"log_hook",
"loop_pass",
"this_before_that_pass_constraint",
"these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
"""
Convenience wrapper for passes which modify an object inplace. This
wrapper makes them return the modified object instead.
Args:
fn (Callable[Object, Any])
Returns:
wrapped_fn (Callable[Object, Object])
"""
@wraps(fn)
def wrapped_fn(gm):
val = fn(gm)
return gm
return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
"""
Logs callable output.
This is useful for logging output of passes. Note inplace_wrapper replaces
the pass output with the modified object. If we want to log the original
output, apply this wrapper before inplace_wrapper.
```
def my_pass(d: Dict) -> bool:
changed = False
if 'foo' in d:
d['foo'] = 'bar'
changed = True
return changed
pm = PassManager(
passes=[
inplace_wrapper(log_hook(my_pass))
]
)
```
Args:
fn (Callable[Type1, Type2])
level: logging level (e.g. logging.INFO)
Returns:
wrapped_fn (Callable[Type1, Type2])
"""
@wraps(fn)
def wrapped_fn(gm):
val = fn(gm)
logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
return val
return wrapped_fn
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
"""
Convenience wrapper for passes which need to be applied multiple times.
Exactly one of `n_iter` or `predicate` must be specified.
Args:
base_pass (Callable[Object, Object]): pass to be applied in loop
n_iter (int, optional): number of times to loop pass
predicate (Callable[Object, bool], optional):
"""
assert (n_iter is not None) ^ (
predicate is not None
), "Exactly one of `n_iter`or `predicate` must be specified."
@wraps(base_pass)
def new_pass(source):
output = source
if n_iter is not None and n_iter > 0:
for _ in range(n_iter):
output = base_pass(output)
elif predicate is not None:
while predicate(output):
output = base_pass(output)
else:
raise RuntimeError(
f"loop_pass must be given positive int n_iter (given "
f"{n_iter}) xor predicate (given {predicate})"
)
return output
return new_pass
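# Illustrative sketch (not part of the original file): applying a toy pass
# repeatedly, first a fixed number of times, then until a predicate fails.
def _example_loop_pass_usage():
    def halve(x):
        return x // 2

    fixed = loop_pass(halve, n_iter=3)                         # 40 -> 20 -> 10 -> 5
    until_small = loop_pass(halve, predicate=lambda x: x > 1)  # stops once x <= 1
    return fixed(40), until_small(40)                          # (5, 1)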
# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
if constraint(a, b):
continue
raise RuntimeError(
f"pass schedule constraint violated. Expected {a} before {b}"
f" but found {a} at index {i} and {b} at index{j} in pass"
f" list."
)
def this_before_that_pass_constraint(this: Callable, that: Callable):
"""
Defines a partial order ('depends on' function) where `this` must occur
before `that`.
"""
def depends_on(a: Callable, b: Callable):
return a != that or b != this
return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
"""
Defines a partial order ('depends on' function) where `these` must occur
before `those`. The inputs are 'unwrapped' before comparison.
For example, the following pass list and constraint list would be invalid.
```
passes = [
loop_pass(pass_b, 3),
loop_pass(pass_a, 5),
]
constraints = [
these_before_those_pass_constraint(pass_a, pass_b)
]
```
Args:
these (Callable): pass which should occur first
those (Callable): pass which should occur later
Returns:
depends_on (Callable[[Object, Object], bool])
"""
def depends_on(a: Callable, b: Callable):
return unwrap(a) != those or unwrap(b) != these
return depends_on
class PassManager:
"""
Construct a PassManager.
Collects passes and constraints. This defines the pass schedule, manages
pass constraints and pass execution.
Args:
passes (Optional[List[Callable]]): list of passes. A pass is a
callable which modifies an object and returns the modified object.
constraints (Optional[List[Callable]]): list of constraints. A
constraint is a callable which takes two passes (A, B) and returns
True if A depends on B and False otherwise. See the implementation of
`this_before_that_pass_constraint` for an example.
"""
passes: List[Callable]
constraints: List[Callable]
_validated: bool = False
def __init__(
self,
passes=None,
constraints=None,
):
self.passes = passes or []
self.constraints = constraints or []
@classmethod
def build_from_passlist(cls, passes):
pm = PassManager(passes)
# TODO(alexbeloi): add constraint management/validation
return pm
def add_pass(self, _pass: Callable):
self.passes.append(_pass)
self._validated = False
def add_constraint(self, constraint):
self.constraints.append(constraint)
self._validated = False
def remove_pass(self, _passes: List[str]):
if _passes is None:
return
passes_left = []
for ps in self.passes:
if ps.__name__ not in _passes:
passes_left.append(ps)
self.passes = passes_left
self._validated = False
def replace_pass(self, _target, _replacement):
passes_left = []
for ps in self.passes:
if ps.__name__ == _target.__name__:
passes_left.append(_replacement)
else:
passes_left.append(ps)
self.passes = passes_left
self._validated = False
def validate(self):
"""
Validates that current pass schedule defined by `self.passes` is valid
according to all constraints in `self.constraints`
"""
if self._validated:
return
for constraint in self.constraints:
_validate_pass_schedule_constraint(constraint, self.passes)
self._validated = True
def __call__(self, source):
self.validate()
out = source
for _pass in self.passes:
out = _pass(out)
return out
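# Illustrative sketch (not part of the original file): a PassManager over two toy
# passes with an ordering constraint saying add_one must run before double.
def _example_pass_manager():
    def add_one(x):
        return x + 1

    def double(x):
        return x * 2

    pm = PassManager(
        passes=[add_one, double],
        constraints=[this_before_that_pass_constraint(add_one, double)],
    )
    pm.validate()  # raises if the schedule violates a constraint
    return pm(3)   # (3 + 1) * 2 == 8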

View File

@ -0,0 +1,675 @@
# mypy: allow-untyped-defs
import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils import _pytree as pytree
from torch.multiprocessing.reductions import StorageWeakRef
import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict
__all__ = ['reinplace']
class _ViewType(Enum):
NonView = 0
SingleOutputView = 1
MultiOutputView = 2
def _is_view_op(tgt):
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
schema = tgt._schema
if len(schema.arguments) > 0:
first_arg = schema.arguments[0]
# check if op is a view
return first_arg.alias_info is not None and not first_arg.alias_info.is_write
def _get_view_type(tgt) -> _ViewType:
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
schema = tgt._schema
if len(schema.arguments) > 0:
first_arg = schema.arguments[0]
# check if op is a view
if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
# check if op is a multi-output view
if '*' in first_arg.alias_info.after_set:
return _ViewType.MultiOutputView
else:
return _ViewType.SingleOutputView
return _ViewType.NonView
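# Illustrative sketch (not part of the original file): how a few ATen overloads
# are classified by the helpers above.
def _example_view_type_classification():
    return (
        _get_view_type(torch.ops.aten.diagonal.default),  # _ViewType.SingleOutputView
        _get_view_type(torch.ops.aten.split.Tensor),      # _ViewType.MultiOutputView
        _get_view_type(torch.ops.aten.add.Tensor),        # _ViewType.NonView
    )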
# Stores a bunch of metadata related to functionalization for each node.
# Relevant metadata:
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
# The fake tensor output from running the current node
# n.meta['view_of']: Node
# If the current node n is a view of some base tensor, the 'view_of' field tells us which
# view node was used to generate the current node (a view tensor).
# This information actually makes `fake_result` redundant, but we can use `fake_result`
# to sanity check that our aliasing information is correct.
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
def run_node(self, node: Node):
self.node_counter += 1
result = super().run_node(node)
node.meta['fake_result'] = result
node.meta['node_idx'] = self.node_counter
# (1) Update metadata with the list of nodes that are used by this node
# copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
# We don't want to treat it as "being used as an input".
node_args = node.args
if node.target is torch.ops.aten.copy_.default:
node_args = node_args[1:]
# (2) Update metadata to track aliasing information about view tensor nodes.
if node.op == 'call_function':
view_type = _get_view_type(node.target)
if view_type == _ViewType.SingleOutputView:
assert isinstance(node.args[0], Node)
node.meta['view_of'] = node.args[0]
elif view_type == _ViewType.MultiOutputView:
self.multi_output_view_nodes[node] = node.args[0]
# Check if we returned a multi-output view,
# and we're now grabbing the individual views from the output.
#
# For multi-output views, we want to map each output view to the base,
# but this mapping involves two separate nodes in FX IR.
# e.g. "a, b = x_1.split(...)" becomes:
# %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
# %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
# %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
# And we'd like to set:
# getitem1.meta['view_of'] = x_1
elif node.target is _operator.getitem:
list_arg = node.args[0]
maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
if maybe_base_of_view is not None:
# Note: we could also track indexing info here for multi-output views.
# I don't think this metadata is strictly needed for de-functionalization.
assert isinstance(maybe_base_of_view, Node)
node.meta['view_of'] = maybe_base_of_view
if 'view_of' in node.meta:
# We're linking the current node with its first argument as views.
# Assert here that this is actually the case, and their storages are the same.
assert isinstance(node.meta['fake_result'], FakeTensor)
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
assert view_storage == base_storage
return result
def propagate(self, *args):
self.multi_output_view_nodes = {}
self.node_counter = -1
with FakeTensorMode() as mode:
fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
return super().run(*fake_args)
def _schemas_match(functional_schema, inplace_schema):
names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
# for the inplace op, its first argument should be mutable
assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
# and its remaining arguments shouldn't be.
assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
return names_match and arg_types_match
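# Illustrative sketch (not part of the original file): aten.add.Tensor and
# aten.add_.Tensor have matching schemas, so add() is re-inplaceable in principle.
def _example_schemas_match():
    functional = torch.ops.aten.add.Tensor._schema
    inplace = torch.ops.aten.add_.Tensor._schema
    return _schemas_match(functional, inplace)  # True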
# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
# TODO: we should also figure this info out using torchgen.
def _maybe_get_inplace_op(op):
# __module__ seems broken; it returns torch._ops.aten which doesn't exist
if not isinstance(op, torch._ops.OpOverload):
return None
# Some view ops have inplace variants (as_strided_, etc),
# but we do NOT want the reinplacing pass to directly add these into the program.
# (they'll require extra special handling, and aren't really useful for perf anyway)
if _is_view_op(op):
return None
op_namespace = op.__module__.split(".")[-1]
op_base_name = op.overloadpacket.__name__
maybe_namespace_module = getattr(torch.ops, op_namespace)
maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
if maybe_inplace_op is None:
return None
inplace_overloads = [
getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
]
inplace_overloads_with_matching_schemas = [
f
for f in inplace_overloads
if _schemas_match(op._schema, f._schema)
]
# Just because foo() and foo_() are both existing operators,
# they aren't guaranteed to have compatible schemas.
# For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
# even though several overloads of pow_ exist.
if len(inplace_overloads_with_matching_schemas) == 0:
return None
assert len(inplace_overloads_with_matching_schemas) == 1
inplace_op = inplace_overloads_with_matching_schemas[0]
return inplace_op
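# Illustrative sketch (not part of the original file): the lookup above finds
# add_ for add, and returns None for a view op like diagonal.
def _example_maybe_get_inplace_op():
    return (
        _maybe_get_inplace_op(torch.ops.aten.add.Tensor),        # aten.add_.Tensor
        _maybe_get_inplace_op(torch.ops.aten.diagonal.default),  # None (view op)
    )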
_VIEW_INVERSE_MAP = {
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
# Given a set of (aliased) tensor nodes, this function returns any nodes in the
# graph that *use* any of the aliases and occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x._typed_storage()))
nodes_used_after = set()
for t in tensor_aliases:
# get all nodes that use the current alias
usage_nodes = t.users
for n in usage_nodes:
# We only care about usages after the current node
if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
continue
# We also don't care about intermediate view ops.
# They only matter if their output is then used elsewhere
# (either in an out-of-place op, or as an output to the function).
if n in tensor_aliases:
if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
continue
nodes_used_after.add(n)
return nodes_used_after
# Given an op that we're trying to re-inplace, "b = foo(a)",
# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
# there are any aliases in the alias_set(a) that satisfy:
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
def matching_view_metadata(a, b):
return a.size() == b.size() and \
a.stride() == b.stride() and \
a.storage_offset() == b.storage_offset()
view_inverse_nodes = set()
# Go through them in node order, so we can see chains of view_scatter ops.
for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
if n.target not in _VIEW_INVERSE_MAP:
continue
base = n.args[0]
mutated_view = n.args[1]
assert isinstance(base, Node)
assert isinstance(base.meta['fake_result'], FakeTensor)
assert isinstance(mutated_view, Node)
assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
# Check that this view_inverse op actually corresponds to doing the inverse
# of one of our existing self_alias nodes.
original_view = _VIEW_INVERSE_MAP[n.target]
for self_alias in self_aliases:
# We're looking for some alias of the self arg, "alias",
# that was created from some op `alias = foo(base, args...)`
# such that the current _scatter op "inverts" that foo call.
# We can check that by running the original op again, and checking that the strides match.
if 'view_of' not in self_alias.meta:
continue
self_alias_base = self_alias.meta['view_of']
try:
# Here we're trying to re-use the args from the view_scatter call inside of the corresponding
# view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
# of the current alias we're looking at.
view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
expected_metadata = self_alias.meta['fake_result']
# If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
matching_view_metadata(view_replay_metadata, expected_metadata):
view_inverse_nodes.add(n)
except Exception:
continue
return view_inverse_nodes
@compatibility(is_backward_compatible=True)
def reinplace(gm, *sample_args):
"""
Given an fx.GraphModule, modifies it to perform "reinplacing",
mutating the nodes of the graph.
We look for out-of-place op call sites like `b = a.add(...)`,
and convert them to be inplace (`b = a.add_(...)`),
as long as the input to the current operator ("a") isn't re-used
anywhere later in the graph.
This pass currently expects to operate on a **functional, ATen** graph.
This can be obtained by running `make_fx(functionalize(f))`.
Sample inputs are needed to determine aliasing relationships of the inputs.
In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
inputs to the program.
Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
(1) Perform some initial checks on the metadata of "a" and "args..."
that can disqualify them from being reinplaced.
(1a) Check that the self argument we're attempting to reinplace
has acceptable dtype/size metadata to reinplace with.
For example, if we have:
a = torch.ones(1)
b = torch.ones(10)
out = torch.add(a, b)
We can't turn that into
a.add_(b)
Because that would require resizing "a".
Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
because that would require changing a's dtype (from e.g. float32 to bool).
Note that in this specific example, we could technically do better.
If we see the pattern:
a_1 = a.ge(b)
a_2 = aten._to_copy(a_1, a.dtype)
Then it should be valid to completely re-inplace
(this is exactly what functionalization will emit when it sees a.ge_(b)).
This optimization is only really important for user programs
that directly use inplace comparison ops though.
We also cannot re-inplace on tensors that have overlapping memory,
e.g. torch.ones(1).expand(4, 4).add_(1)
(1b) Check if "a" is an alias of any of the program inputs.
If it is, skip and move to the next node.
Inplace'ing an op that would cause it to mutate a program input is not sound,
because that would be a side effect visible to the user.
NOTE: there's a future optimization that we should make:
if "a" is a (alias of a) program input, but later in the program
there is a node that looks like "a.copy_(...)",
Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
which will later be overwritten by the copy_() call.
This will be an important optimization to have for programs that mutate
their inputs. It currently isn't implemented though.
(1c) Check if "a" and "args..." alias
For example, re-inplacing to create code like the below
isn't guaranteed to be sound:
aten.mul_(a, a)
(2) Check that "a" and all of its outstanding aliases are not used anywhere
later in the graph. If this is the case, then it's safe to re-inplace
to "b = foo_(a)".
There are a few caveats to this, explained in more detail below:
(a) If "a" is used later as an argument to a view op, that is okay.
It's only a problem if "a" (or that view) is later passed
into a normal operator, or if it is returned as the program output.
(b) If "a" is a repeat argument in `foo()`, then don't reinplace.
Most ATen kernels don't make any guarantees that this is sound,
e.g. if you do aten.mul_(a, a).
So we'll just ban re-inplacing in this case.
It's only a problem if "a" (or that view) is later passed
(c) If "a" is used as an input into a view "inverse" / "scatter"
operator, it is potentially fine to re-inplace
(and remove that scatter operator from the graph).
See below for a more detailed example.
NOTE: there is an optimization in this step that is crucial
to fully recovering performance from functionalization.
Given this program:
def f(x):
a = torch.ops.aten.add(x, x)
b = torch.ops.aten.diagonal(a)
torch.ops.aten.fill_(b, 0)
return a
Functionalization will emit the following:
def f(x):
a = torch.ops.aten.add(x, x)
b = torch.ops.aten.diagonal(a, 0, 1)
b_updated = torch.ops.aten.fill(b, 0)
a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
return a_updated
Ordinarily, we would not be able to reinplace the fill,
because "b" aliases with "a" which is used by the diagonal_scatter call.
"re-inplacing" is on the hook for figuring out that it is ok to
completely, the expensive diagonal_scatter call, if we re-inplace the add().
So, for every `alias in alias_set(a)`, instead of checking
that "alias" is not used anywhere later in the graph,
we check that
EITHER:
(a) alias is not used anywhere later in the graph
OR:
(b) alias is used exactly once later on in the graph,
in the following op:
out = foo_scatter(alias, x, args...)
where the following must hold:
(i) "foo_scatter" is the "inverse" operator for foo.
This only applies to "foo" ops that are view operators,
which view into a subset of the original tensor's memory.
In practice, there are ~4 operators where this applies:
diagonal -> diagonal_scatter
slice -> slice_scatter
select -> select_scatter
as_strided -> as_strided_scatter
(ii) "args..." are the same between the foo() and foo_scatter() calls.
(3) Perform the actual re-inplacing on foo!
(3b) is the common case, but special care is needed for {view}_scatter (3a)
(3a) {view}_scatter ops.
Consider this program:
a = torch.zeros(2, 2)
b = torch.ones(2)
a[0] = b
Post functionalization, that will look like:
a = torch.zeros(2, 2)
b = torch.ones(2)
a_updated = torch.select_scatter(a, b, 0, 0)
In this case though, there is no "functional" op to re-inplace!
Instead, we'd like to directly remove the select_scatter call.
We already know from (3) that this is valid,
because "a" has no later usages in the graph.
We perform the re-inplacing on the {view}_scatter op like so
Before:
a_updated = torch.select_scatter(a, b, args...)
After:
a_slice = a.select(args...)
a_slice.copy_(b)
(3b) Otherwise, replace the functional op with its inplace variant.
Before:
b = foo(a, args...)
After:
a.foo_(args...)
(4) Finally, after converting either:
Before:
b = foo(a)
After:
foo_(a)
or
Before:
b = {slice}_scatter(a, mutated_slice, args...)
After:
slice = {slice}(a, args...)
slice.copy_(mutated_slice)
We now need to find all later nodes that use "b" as an argument
and update them to take in "a" instead.
Note that for the majority of inplace ops, this isn't actually necessary
(because most inplace ops return "self" as their output).
This isn't generally true for all mutable ops though, which is why
we need to actually replace all of the arguments.
We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
That maps a given tensor storage to the set of all nodes that take in that storage
as an input.
Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
together.
(5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
during step (3) get manually deleted from the graph.
Their outputs are no longer used, so technically standard DCE would be able
to do this, but we can no longer run FX's DCE pass now that we have mutable
ops in the graph.
"""
_FunctionalizationMetadataProp(gm).propagate(*sample_args)
# Useful debug printing
# def _print(x):
# if isinstance(x, FakeTensor):
# print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
# for n in gm.graph.nodes:
# print(n.format_node())
# if hasattr(n, 'meta'):
# print(f'node_idx: {n.meta["node_idx"]}')
# if 'fake_result' in n.meta:
# tree_map(_print, n.meta['fake_result'])
# if 'view_of' in n.meta:
# print(f'view_of: {str(n.meta["view_of"])}')
# print()
# We need to know which nodes correspond to inputs (or their aliases)
# so we know not to re-inplace them.
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = {
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))}
# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if 'fake_result' in n.meta:
# Tree-mapping because some ops can return lists of tensors.
def _add_to_map(x):
if isinstance(x, FakeTensor):
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
pytree.tree_map_(_add_to_map, n.meta['fake_result'])
# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_nodes_to_delete = set()
for idx, node in enumerate(gm.graph.nodes):
if node.op == 'call_function':
# Today, the re-inplace pass directly acts on:
# - functional ops with an inplace variant
# - {view}_scatter ops that can be potentially removed from the graph.
# Both of these ops take in tensor first args, so filtering on this condition
# makes the later code simpler.
# We should revisit this at some point though, particularly when we also want
# the reinplacer to be able to handle out= and mutable operators
# and tensorlist first args (like `_foreach_` ops).
if not isinstance(node.target, torch._ops.OpOverload):
continue
if len(node.target._schema.arguments) < 1:
continue
if type(node.target._schema.arguments[0].type) != torch.TensorType:
continue
# Step 1a: Check that the self argument we're attempting to reinplace
# has the same size/stride as the output.
# For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
# As it would require resizing scalar_tensor.
# (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
# this is probably an optimization to revisit later).
self_arg = node.args[0]
self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
node_flattened = pytree.tree_leaves(node.meta['fake_result'])
self_has_wrong_metadata = False
if len(self_flattened) == len(node_flattened):
for self_meta, node_meta in zip(self_flattened, node_flattened):
if self_meta.numel() != node_meta.numel():
self_has_wrong_metadata = True
if self_meta.dtype != node_meta.dtype:
self_has_wrong_metadata = True
# We also cannot re-inplace on tensors that have internal memory overlap.
# e.g. torch.ones(1).expand(4, 4).add_(1)
if torch._debug_has_internal_overlap(self_meta) == 1:
self_has_wrong_metadata = True
# Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
# Since users should never really be calling the functional "torch.ops.aten.resize"
# op directly in their programs.
if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
continue
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
if len([x for x in node.args if x is self_arg]) > 1:
# Step 1c:
# Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
# so we prevent re-inplacing in this case.
continue
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_aliases = storage_to_nodes[self_arg_storage]
# First, we find all later usages of any of the aliases of self_arg.
later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
# Then, we check if any of those later usages are actually view_scatter ops
# that are safe to fully remove.
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
# Step 2: Check to see if the input to the op is re-used later in the graph.
# If not (same goes for its aliases), then this op is safe to re-inplace.
# This is a slightly roundabout way to check that there are no later usages of the current self argument.
# (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
if not can_reinplace:
continue
# Step 3a: Special handling for when we see *_scatter operators.
# When we see an operator like `b = torch.slice_scatter(a, ...)`,
# instead of trying to "inplace" it into a.slice_scatter_(...),
# we would prefer to remove it from the graph entirely,
# and instead copy_() the slice directly into the larger tensor.
# See the description of the algorithm for a full example.
if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
view_op = _VIEW_INVERSE_MAP[node.target]
# Before:
# base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
# After:
# slice = torch.ops.aten.slice.default(base, args...)
# slice.copy_(mutated_slice)
with gm.graph.inserting_before(node):
mutated_slice_node = node.args[1]
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
copy_node = gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)
else:
# Step 3b: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
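# (e.g. aten.add.Tensor maps to its in-place overload aten.add_.Tensor;
# ops with no trailing-underscore overload return None here.)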
if maybe_inplace_op is None:
continue
# And if so, replace it with its inplace variant.
node.target = maybe_inplace_op
# At this point, 'storage_to_nodes' will be stale.
# Now that we're inplacing `b = foo(a)`, we need to effectively
# union together the dict values for b and a's storage.
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
# Need to remember the view_scatter view nodes we found so we can remove them later.
all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
# Step 4:
# Now that we've replaced b = a.foo() with a.foo_(),
# We need to replace any later usages of "b" with "a"
for old in itertools.chain([node], later_view_inverse_node_usages):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
for node_to_update in nodes_to_update:
new_args = []
args = node_to_update.args
def replace_arg(a):
if a == old:
return new
return a
# First, replace usages of "b" with "a"
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
# Second, update our storage_to_nodes data structure.
old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
old_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in old_flattened_res if isinstance(x, FakeTensor)}
node_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in node_flattened_res if isinstance(x, FakeTensor)}
# This will happen if we're updating a view op,
# e.g. replacing
#     x = view(old)
# with
#     x = view(new)
# When that happens, we need to make sure to keep our
# storage mapping up to date.
#
# We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
# or multiple tensors that all share the same storage.
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
new_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor)}
assert len(new_res_storage) == 1
(old_ref,) = old_res_storage
(new_ref,) = new_res_storage
(node_ref,) = node_res_storage
# Technically, "old_ref" and all its aliases will remain
# in our mapping.
# That should be fine though, since we deleted "old"
# from the graph at this point.
storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
# Step 5: delete any _scatter nodes that we de-functionalized
# Need to take care not to delete any of these nodes until after *all* modifications
# to the graph are finished.
for to_delete in all_later_view_inverse_nodes_to_delete:
gm.graph.erase_node(to_delete)
gm.recompile()
return gm
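# Rough usage sketch (illustrative): given a functionalized GraphModule `gm`,
# e.g. produced by make_fx, and example inputs `sample_args`,
#     gm = reinplace(gm, *sample_args)
# rewrites eligible out-of-place ops into their in-place variants and strips the
# now-redundant {view}_scatter nodes.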

View File

@ -0,0 +1,605 @@
# mypy: allow-untyped-defs
import functools
import logging
import operator
import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
ShapeEnv = Any
import torch
import torch.utils._pytree as pytree
from torch import fx
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule
__all__ = ["insert_deferred_runtime_asserts"]
log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
def _get_example_value(node: fx.Node) -> Optional[str]:
"""
Get the example value for a node, since dynamo uses "example_value"
while non-strict export uses "val".
"""
if "example_value" in node.meta:
return node.meta["example_value"]
elif "val" in node.meta:
return node.meta["val"]
else:
return None
def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
val = _get_example_value(node)
if isinstance(val, py_sym_types):
return val.node.expr
return None
@compatibility(is_backward_compatible=True)
def insert_deferred_runtime_asserts(
gm: GraphModule,
shape_env: ShapeEnv,
name: str,
export: bool = False,
) -> None:
"""
During tracing, we may have discovered that some data-dependent values
had runtime asserts on them; e.g., torch.empty(x.item()) induces a runtime
assert that x.item() >= 0. These asserts can happen unpredictably during fake
tensor propagation, so we cannot conveniently insert them into the FX graph
when they occur. Instead, we accumulate them in the ShapeEnv, and in this
pass insert them into the graph as proper tests.
This pass also deduplicates size-related computation, CSE-ing ops that produce
symbolic values and/or are involved in runtime asserts. Additionally, shape calls
(size/stride/storage_offset) are turned into compute on input sizes if possible,
allowing intermediate tensors to be freed earlier. For example, here dynamo will
DCE the cat and repeat calls:
z = torch.cat([x, x], dim=0) # 2*s0
w = z.repeat(y.shape[0]) # 2*s0*s1
_w = w.shape[0]
# something with _w, but not w ...
# turns into ->
_w0 = 2 * s0
_w = _w0 * s1
# where s0, s1 are either SymInt graph inputs, or the result of added size calls
Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
the same expression, and redundant constrain_range calls are also deduplicated.
Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
and we delete all previous calls, adding bound checks at the end of this pass.
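A rough usage sketch (assuming `gm` carries symbolic shape metadata and
`shape_env` is the ShapeEnv that was used during tracing):
    insert_deferred_runtime_asserts(gm, shape_env, "my_graph")
    gm.recompile()
The pass mutates gm.graph in place and returns None.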
"""
# Import sympy locally
import sympy
from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
from torch.fx.experimental.symbolic_shapes import (
_has_uninterpretable_sympy_function,
CallMethodKey,
cast_symbool_to_symint_guardless,
ConvertIntKey,
DivideByKey,
free_symbols,
InnerTensorKey,
resolve_unbacked_bindings,
)
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.reference import PythonReferenceAnalysis
from torch.utils._sympy.value_ranges import ValueRanges
# TODO: Request simplification on runtime asserts before emitting them
ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
graph = gm.graph
graph_code_log.debug(
"%s",
lazy_format_graph_code(
f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
),
)
# We are going to mutate the dict
expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {}
placeholders = set()
first_non_placeholder = None
for node in graph.nodes:
if node.op != "placeholder":
first_non_placeholder = node
break
else:
placeholders.add(node)
def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
"""
If this is a size/stride/storage offset call on an intermediate tensor,
we can try to compute the value from input shapes instead.
"""
return (
(val := _get_sym_val(node)) is not None
and not isinstance(val, sympy.Number)
# this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
and not _has_uninterpretable_sympy_function(val)
and any(
isinstance(arg, fx.Node)
and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
and arg.op != "placeholder"
for arg in node.args
)
)
# Figure out what key to use, val or example_value
val_key = "val"
for node in graph.nodes:
if "example_value" in node.meta:
val_key = "example_value"
break
elif "val" in node.meta:
break
def _node_metadata_hook(
node: torch.fx.Node,
stack_trace: Optional[str] = None,
nn_module_stack: Optional[Dict[str, Any]] = None,
) -> None:
fake_args = [
_get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
for arg in node.args
]
try:
node.meta[val_key] = node.target(*fake_args) # type: ignore[operator]
except NotImplementedError:
# This can happen when attempting to reify a symbol with an unsupported call_function node,
# e.g. with NestedTensors + sym_size.int via match_symbol().
# This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
pass
if stack_trace is not None:
node.meta["stack_trace"] = stack_trace
if nn_module_stack is not None:
node.meta["nn_module_stack"] = nn_module_stack
# Track asserts/checks we've added
added_asserts: Set[sympy.Expr] = set()
constrained_unbacked_symbols: Set[sympy.Symbol] = set()
def _sympy_interp(expr_to_proxy, expr):
# sympy_interp() with hash consing
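# e.g. when reifying 2*s0 and later 2*s0*s1, the FX node already built for the
# subexpression 2*s0 is found in expr_to_proxy and reused instead of being rebuilt.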
from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom
from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
# hash cons
if expr in expr_to_proxy:
return expr_to_proxy[expr]
# base cases, don't cache
if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr)
# hash cons on arguments, run expr handler
expr_to_proxy[expr] = _run_sympy_handler(
PythonReferenceAnalysis,
[_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
expr,
)
return expr_to_proxy[expr]
def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
# This is probably unnecessary, but since torch._check() calls for single-symbol bounds
# like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
# and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
return False
lhs, rhs = expr.args
return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
)
def add_runtime_asserts(ras):
for ra in ras:
if (
# redundant
ra.expr in added_asserts
# if we've already added a constrain_range call for this symbol,
# then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
or (
len(ra.expr.free_symbols) == 1
and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
and _is_bound_expr_for_symbol(ra.expr)
)
# don't try to reify sympy functions we can't turn into FX nodes
or _has_uninterpretable_sympy_function(ra.expr)
):
continue
log.debug("inserting runtime assert %s", ra.expr)
# Need to process ALL free symbols, not just unbacked ones
fvs = free_symbols(ra.expr)
missing = fvs - expr_to_proxy.keys()
if missing:
i1 = min(missing, key=str)
# TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
# assert shape_env.is_unbacked_symint(i1), i1
ras_by_symbol.setdefault(i1, []).append(ra)
else:
# Convert the sympy expression into a sequence of FX
# nodes
with _set_node_metadata_hook(gm, _node_metadata_hook):
res = _sympy_interp(expr_to_proxy, ra.expr).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
# TODO: use ra.msg here, but it's pretty
# useless right now
(
res,
f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
),
)
added_asserts.add(ra.expr)
nodes = list(graph.nodes)
for i, node in enumerate(nodes[:-1]):
# Placeholders can match symbols, but when we destructure them
# with size we have to make sure we insert the nodes after all
# the placeholders
with graph.inserting_before(
nodes[i + 1] if node not in placeholders else first_non_placeholder
):
# Unfortunately, this logic still must remain because manual
# make_fx calls may not explicitly bind all symbolic ints as
# arguments to the function, so we must infer it from the other
# arguments
if (
node in placeholders
and (example_value := _get_example_value(node)) is not None
):
def match_symbol(symint, cb):
if (
isinstance(symint, torch.SymInt)
and isinstance(symint.node, SymNode)
and isinstance(s := symint.node.expr, sympy.Symbol)
and s not in expr_to_proxy
):
with _set_node_metadata_hook(gm, _node_metadata_hook):
expr_to_proxy[s] = fx.Proxy(cb())
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
match_symbol(example_value, lambda: node)
if isinstance(t := example_value, torch.Tensor):
for i, s in enumerate(t.size()):
match_symbol(
s,
lambda: graph.call_function(
torch.ops.aten.sym_size.int, (node, i)
),
)
if not is_sparse_any(t):
for i, s in enumerate(t.stride()):
match_symbol(
s,
lambda: graph.call_function(
torch.ops.aten.sym_stride.int, (node, i)
),
)
match_symbol(
t.storage_offset(),
lambda: graph.call_function(
torch.ops.aten.sym_storage_offset.default, (node,)
),
)
# Handle asserts that aren't associated with any symbol. This
# doesn't really have to be in the loop as it will only run once,
# it just needs to happen right after the placeholders.
# insert this after placeholders & added sym nodes, and before non-placeholders.
if node == first_non_placeholder:
add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload]
# deduplicate asserts already present in graph
if node.target in (
torch._check,
torch.ops.aten._assert_scalar.default,
):
if (
node.args[0] == True # noqa: E712
or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy
or (
assert_expr is not None
and _is_bound_expr_for_symbol(assert_expr)
)
):
arg = node.args[0]
gm.graph.erase_node(node)
if isinstance(arg, fx.Node) and not arg.users:
gm.graph.erase_node(arg)
else:
added_asserts.add(assert_expr) # type: ignore[arg-type]
# hash cons, replace function calls that return torch.SymInts with direct references to
# FX nodes built up to reify the sympy expression.
if (
node.op != "placeholder"
and (sym_expr := _get_sym_val(node)) is not None
):
# this guards against deleting calls like item() that produce new untracked symbols
new_untracked_symbols = sym_expr.free_symbols - expr_to_proxy.keys()
# this guards against deleting calls that produce unbacked bindings we haven't yet seen.
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
# (is backed), but produces an unbacked symbol. In this case keep the node alive.
new_unbacked_bindings = (
resolve_unbacked_bindings(
shape_env, node.meta.get("unbacked_bindings", {})
).keys()
- expr_to_proxy.keys()
)
# maybe re-reify expression, replace current node
if (
sym_expr in expr_to_proxy
or ( # example value is redundant
_is_intermediate_tensor_sym_call(node)
# shape call on intermediate tensor, turn into computation on input shapes
and not new_untracked_symbols
)
) and not new_unbacked_bindings:
if _is_intermediate_tensor_sym_call(
node
): # reify from input shapes
with _set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
node.replace_all_uses_with(hash_node)
gm.graph.erase_node(node)
log.debug(
"CSE node %s -> %s for expr %s", node, hash_node, sym_expr
)
# store node in hash cons, don't delete/replace
elif sym_expr not in expr_to_proxy and not isinstance(
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
): # don't hash cons primitives
expr_to_proxy[sym_expr] = fx.Proxy(node) # type: ignore[arg-type]
# We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
# so calls before that are redundant.
if node.target in (
torch.ops.aten.sym_constrain_range.default,
torch.ops.aten.sym_constrain_range_for_size.default,
):
gm.graph.erase_node(node)
defs = []
# AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
# equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
# More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
# information about the old symbol when we re-export, raising errors on data-dependent guards.
# Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
if unbacked_bindings := resolve_unbacked_bindings(
shape_env, node.meta.get("unbacked_bindings")
):
for s, keypath in unbacked_bindings.items():
defs.append(s)
# TODO: some CSE when generating these nodes can probably
# help reduce graph size and improve compile time
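# `go` walks the pytree keypath from `node` (which produced the unbacked
# binding) down to the bound symbol, emitting the matching FX calls
# (sym_size / sym_stride / getitem / ...) along the way.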
def go(node, keypath):
if keypath == ():
return node
if (
len(keypath) >= 2
and isinstance(keypath[0], CallMethodKey)
and isinstance(keypath[1], pytree.SequenceKey)
):
if keypath[0].name == "size":
return go(
graph.call_function(
torch.ops.aten.sym_size.int,
(node, keypath[1].idx),
),
keypath[2:],
)
if keypath[0].name == "stride":
return go(
graph.call_function(
torch.ops.aten.sym_stride.int,
(node, keypath[1].idx),
),
keypath[2:],
)
return go(
graph.call_method(
keypath[0].name, (node, keypath[1].idx)
),
keypath[2:],
)
elif isinstance(keypath[0], CallMethodKey):
return go(
graph.call_method(keypath[0].name, (node,)), keypath[1:]
)
elif isinstance(keypath[0], pytree.SequenceKey):
return go(
graph.call_function(
operator.getitem, (node, keypath[0].idx)
),
keypath[1:],
)
elif isinstance(keypath[0], ConvertIntKey):
return go(
graph.call_function(
cast_symbool_to_symint_guardless, (node,)
),
keypath[1:],
)
elif isinstance(keypath[0], DivideByKey):
# TODO: need to assert divisibility
return go(
graph.call_function(
operator.floordiv, (node, keypath[0].divisor)
),
keypath[1:],
)
elif isinstance(keypath[0], InnerTensorKey):
return go(
graph.call_function(
getattr, (node, keypath[0].inner_name)
),
keypath[1:],
)
else:
raise AssertionError(f"unrecognized keypath {keypath}")
if s not in expr_to_proxy:
with _set_node_metadata_hook(gm, _node_metadata_hook):
expr_to_proxy[s] = fx.Proxy(go(node, keypath))
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
for i0 in defs:
ras = ras_by_symbol.pop(i0, [])
# Before we perform any asserts, first apply range
# refinement. This is important, because if we are going
# to retrace the graph (and we typically are if we send
# the graph to AOTAutograd), we need to make sure we apply
# range refinement (ala _check_is_size) first, BEFORE we
# run any of the asserts. Otherwise, we may decide to
# perform substitutions based on the asserts which we then
# can't back out, because value ranges can only be applied
# to asserts.)
#
# A perhaps better long term plan is to avoid this order
# dependence by making it possible to refine ranges on
# arbitrary expressions, not just symbols. But it is not
# so easy to make use of this information, see
# https://twitter.com/ezyang/status/1745801370299482492
# We actually made an attempt at this in
# https://github.com/pytorch/pytorch/pull/119043
# which didn't work.
#
# Other ideas for how to do this:
# - Have bound_sympy be the source of truth of the ranges of any expression
# - Cache intermediate results for every subexpression of bound_sympy
# - This cache should be possible to edit to refine ranges
#
# One issue with this proposal is that if
# we have a bound on 2x, we are not going to be able to
# apply it for 4x. Similarly, we may have bounds for an
# equivalent expression that we are not applying because
# it's not a perfect match (e.g. x < y vs y > x).
#
# The first issue we already have it and it's impossible
# to solve in general, so any implementation on a best
# effort basis should do.
#
# The second issue is a preexisting one. It can be mitigated
# with a normalisation algorithm. In general, it may also
# be on a best effort basis, but since our grammar is not
# terribly difficult, chances are we could even fully
# normalise SymPy expressions... who knows.
if i0 in constrained_unbacked_symbols:
continue # constrain symbol just once
if i0 in shape_env.size_like:
if export:
graph.call_function(
torch.ops.aten.sym_constrain_range_for_size.default,
(expr_to_proxy[i0].node,),
)
else:
graph.call_function(
torch._check_is_size, (expr_to_proxy[i0].node,)
)
vr = shape_env.var_to_range[i0]
if vr.is_int and vr.upper == sys.maxsize - 1:
# treat upper bound == sys.maxsize - 1 for int symbols as +oo
# to avoid redundant runtime assert
vr = ValueRanges(vr.lower, int_oo)
if not shape_env._default_unspecified_value_range().issubset(vr):
# The runtime range is constrained, so add a runtime
# assert and also explicitly refine the range
# (refinement should not be necessary once runtime
# asserts cause refinement, but that's NYI)
def convert(s):
if s in (int_oo, -int_oo):
return None
try:
return int(s)
except TypeError:
return None
if (
expr_to_proxy[i0].node.target
!= cast_symbool_to_symint_guardless
):
# TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
# raises AOTAutograd errors on cast_symbool_to_symint_guardless
with _set_node_metadata_hook(
gm,
functools.partial(
_node_metadata_hook,
stack_trace=node.meta.get("stack_trace"),
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
if (min_val := convert(vr.lower)) is not None:
ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
(
ge,
f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
),
)
added_asserts.add(i0 >= min_val)
if (max_val := convert(vr.upper)) is not None:
le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
graph.call_function(
torch.ops.aten._assert_scalar.default,
(
le,
f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
),
)
added_asserts.add(i0 <= max_val)
constrained_unbacked_symbols.add(i0)
add_runtime_asserts(ras)
# delete unused reified symbols
for expr, proxy in expr_to_proxy.items():
if (
isinstance(expr, sympy.Symbol)
and proxy.node.op != "placeholder" # keep placeholders intact
and not proxy.node.users
):
log.debug("deleting unused reified symbol for %s", expr)
gm.graph.erase_node(proxy.node)

View File

@ -0,0 +1,196 @@
# mypy: ignore-errors
import torch
import torch.fx
import traceback
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
# TensorMetadata is a structure containing pertinent information
# about a tensor within a PyTorch program.
# General Tensor metadata
shape : torch.Size
dtype : torch.dtype
requires_grad : bool
stride : Tuple[int, ...]
memory_format : Optional[torch.memory_format]
# Quantization metadata
is_quantized : bool
qparams: Dict[str, Any]
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
shape = result.shape
dtype = result.dtype
requires_grad = result.requires_grad
stride = result.stride() if not is_sparse_any(result) else None
memory_format = None
if include_contiguity and not is_sparse_any(result):
memory_formats = {
torch.contiguous_format,
torch.channels_last,
torch.channels_last_3d,
}
for query_format in memory_formats:
if result.is_contiguous(memory_format=query_format):
memory_format = query_format
break
is_quantized = result.is_quantized
qparams: Dict[str, Any] = {}
if is_quantized:
qscheme = result.qscheme()
qparams["qscheme"] = qscheme
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
qparams["scale"] = result.q_scale() # type: ignore[assignment]
qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
# In this branch, scale and zero_point are expected to be tensors,
# we store the values as immutable_list in TensorMetadata for
# easier serialization downstream
qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment]
qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment]
qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
return TensorMetadata(
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
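# For example, _extract_tensor_metadata(torch.randn(2, 3)) yields roughly:
# TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False,
#                stride=(3, 1), memory_format=torch.contiguous_format,
#                is_quantized=False, qparams={})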
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
record the shape and type of the result
into the corresponding node.
Example:
In this example, we record the shape
and data type of a module given
an example input ``torch.randn(50, D_in)``.
We print the name, shape and dtype of each node.
class TwoLayerNet(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
h_relu = self.linear1(x).clamp(min=0)
y_pred = self.linear2(h_relu)
return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)
for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape)
The output of this code is:
x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp_1 torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])
Args:
module (GraphModule): The module to be executed
fake_mode (FakeTensorMode): A fake mode for copying the gm
"""
def __init__(self, gm, fake_mode=None):
super().__init__(gm)
if fake_mode is None:
fake_mode = detect_fake_mode()
if fake_mode is not None:
from torch._dynamo.utils import deepcopy_to_fake_tensor
# Note:
# We need fake execution because the inputs are fake; however, we cannot fakify the module
# - because we need to write to the tensor_meta of the real module. So we fakify to
# produce a result (L131 below), to extract tensor meta, and then keep going.
#
# If we were to fakify, we would write to the wrong node, and then downstream fusion
# would be missing the tensor_meta.
#
# See torch/_inductor/overrides.py for where this is called upstream of fusion.
self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
self.fake_mode = fake_mode
else:
self.fake_module = None
self.fake_mode = None
self.real_module = self.module
def run_node(self, n : Node) -> Any:
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
# call_module and get_attr.
self.module = self.fake_module
try:
if self.fake_mode is not None:
with self.fake_mode, enable_python_dispatcher():
result = super().run_node(n)
else:
result = super().run_node(n)
finally:
self.module = self.real_module
except Exception as e:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: node={n.format_node()} with "
f"meta={n.meta}"
) from e
found_tensor = False
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
return _extract_tensor_metadata(obj)
else:
return obj
meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
n.meta['type'] = type(result)
return result
def propagate(self, *args):
"""
Run `module` via interpretation, record the shape and type of
each node, and return the result.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
else:
fake_args = args
return super().run(*fake_args)

View File

@ -0,0 +1,575 @@
# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging
import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code
__all__ = ["Partition", "split_module"]
log = _LOGGER = logging.getLogger(__name__)
@compatibility(is_backward_compatible=True)
class Partition:
def __init__(self, name: str):
self.name: str = name
self.submod_name = f"submod_{name}"
self.node_names: List[str] = []
self.inputs: Dict[str, None] = {}
self.outputs: Dict[str, None] = {}
self.dependencies: Dict[str, None] = {}
self.dependents: Dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[Node, Node] = {}
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
return (
f"name: {self.name},\n"
f" nodes: {self.node_names},\n"
f" inputs: {self.inputs},\n"
f" outputs: {self.outputs},\n"
f" partitions depended on: {self.dependencies},\n"
f" partition dependents: {self.dependents}"
)
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[Node], int],
qualname_map: Optional[Dict[str, str]] = None,
keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False,
):
"""
Creates subgraphs out of main graph
Args:
m (GraphModule): Graph module to split
root_m (torch.nn.Module): root nn module. Not currently used. Included
because the root nn module is usually transformed via
torch.fx._symbolic_trace.symbolic_trace (see example below)
split_callback (Callable[[Node], int]): Callable function
that maps a given Node instance to a numeric partition identifier.
split_module will use this function as the policy for which operations
appear in which partitions in the output Module.
qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
mapping from new target names in the module after split to old target
names in the original module.
keep_original_order: Optional[bool]: keep the original order of the GraphModule
or use the topological order of the newly constructed GraphModule
Returns:
GraphModule: the module after split.
Example:
This is a sample setup:
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w
# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)
# random mod partitioning
partition_counter = 0
NPARTITIONS = 3
def mod_partition(node: Node):
global partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition
# split module in module with submodules
module_with_submodules = split_module(
my_module_traced, my_module, mod_partition
)
Output looks like this. Original graph is broken into partitions
> print(module_with_submodules)
GraphModule(
(submod_0): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_1): GraphModule(
(linear): Linear(in_features=4, out_features=5, bias=True)
)
(submod_2): GraphModule()
)
def forward(self, x, y):
param = self.param
submod_0 = self.submod_0(x, param, y); x = param = y = None
getitem = submod_0[0]
getitem_1 = submod_0[1]; submod_0 = None
submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
getitem_2 = submod_1[0]
getitem_3 = submod_1[1]; submod_1 = None
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
return submod_2
The output of the split module is the same as the output of the input traced module.
This is an example within a test setting:
> orig_out = my_module_traced(x, y)
> submodules_out = module_with_submodules(x, y)
> self.assertEqual(orig_out, submodules_out)
True
"""
log.debug(
"%s",
lazy_format_graph_code(
"pre split_module", m, colored=True
),
)
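# construct_graph copies a single placeholder or get_attr node from the original
# module into the base (outer) graph, recording the new node in base_mod_env and,
# for get_attr, stashing the resolved attribute in base_mod_attrs.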
def construct_graph(
node: Node,
base_mod_env: Dict[str, Node],
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
):
if node.op == "placeholder":
default_value = (
node.args[0] if len(node.args) > 0 else inspect.Signature.empty
)
if keep_original_node_name:
args = () if default_value is inspect.Signature.empty else (default_value,)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type]
else:
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type]
)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == "get_attr":
base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
base_mod_env[node.name].meta = node.meta.copy()
attr_val = m
for atom in node.target.split("."): # type: ignore[union-attr]
if not hasattr(attr_val, atom):
raise AttributeError(f"Node target {node.target} not found!")
attr_val = getattr(attr_val, atom)
base_mod_attrs[node.target] = attr_val # type: ignore[index]
return base_mod_env, base_mod_attrs
import sympy
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}
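# symbol_to_node maps each size symbol (e.g. s0) to the FX node that binds it, so
# that a partition consuming a value with symbolic sizes can also receive the
# symbol's defining node as an input (see record_cross_partition_use below).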
def record_cross_partition_use(
def_node: Node, use_node: Optional[Node]
): # noqa: B950
from torch.fx.experimental.symbolic_shapes import free_symbols
defined = getattr(def_node, "_fx_partition", None)
used = getattr(use_node, "_fx_partition", None)
log.debug(
"record_cross_partition_use %s (%s) %s (%s)",
def_node.name, defined, use_node.name if use_node is not None else "-", used
)
if defined != used:
if defined is not None:
def_partition = partitions[defined]
def_partition.outputs.setdefault(def_node.name)
if used is not None:
def_partition.dependents.setdefault(used)
if used is not None:
use_partition = partitions[used]
use_partition.inputs.setdefault(def_node.name)
# We have made def_node an input to the use_partition. If
# this input has symbolic symbols in its size, those also must
# be made as inputs to the partition
if (def_val := def_node.meta.get("example_value")) is not None:
for s in sorted(free_symbols(def_val), key=str):
s_node = symbol_to_node[s]
use_partition.inputs.setdefault(s_node.name)
if symbol_to_node[s].op != "placeholder":
# If the node that defines the symbol is not a
# placeholder, we must make it an output of the
# partition. Note that this may be in a different
# partition than defined! Although, this doesn't
# really make a difference for correctness, since
# defined is guaranteed to have the symbol in
# scope and can return it; you just get less
# optimal codegen in this case.
s_defined = getattr(s_node, "_fx_partition", None)
if s_defined is not None:
s_def_partition = partitions[s_defined]
s_def_partition.outputs.setdefault(s_node.name)
s_def_partition.dependents.setdefault(used)
if defined is not None:
use_partition.dependencies.setdefault(defined)
def instantiate_node_partition_mapping(node):
partition_name = str(split_callback(node))
log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)
# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
node._fx_partition = partition_name
# Global State Nodes are nodes which by their global state effects,
# "taint" all downstream nodes while they are active.
GLOBAL_STATE_NODES = [
torch.amp._enter_autocast,
torch.amp._exit_autocast,
torch._C._set_grad_enabled
]
# For grad regions:
# ------------------------
# 1. first region: we do nothing
# 2. subsequent regions: we insert the set_grad at the beginning
grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
# For autocast regions:
# ------------------------
# 1. first region: we will only insert the _exit at the end
# 2. intermediate regions: we will insert both the
# _enter at the beginning and _exit at the end
# 3. last region: we will only insert _enter at the beginning
# We will do so in the order in which the autocasts were instantiated.
autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
autocast_exits: Dict[Node, Optional[Node]] = {}
active_grad = None
active_autocasts = set()
for node in m.graph.nodes:
# This will prefer placeholder bindings, because those come first.
# This is a little dangerous though: it is possible that an unbacked
# symbol is used without any binding site for it, in which case we
# will get a KeyError not able to find it. I'd like to fix this by
# having passes.runtime_assert establish some invariants that I can
# rely on later, but this needs some extra work. Quick fix first.
# See https://github.com/pytorch/pytorch/issues/130534
if (
(val := node.meta.get("example_value")) is not None and
isinstance(val, torch.SymInt) and
isinstance(s0 := val.node.expr, sympy.Symbol) and
s0 not in symbol_to_node
):
symbol_to_node[val.node.expr] = node
if node.op in ["placeholder", "get_attr", "output"]:
continue
instantiate_node_partition_mapping(node)
if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
if node.target == torch._C._set_grad_enabled:
assert len(node.args) == 1
assert isinstance(node.args[0], bool)
active_grad = node
grad_regions[active_grad] = set({split_callback(node)})
elif node.target == torch.amp._enter_autocast:
# Should all be python constants
assert all(not isinstance(arg, Node) for arg in node.args)
active_autocasts.add(node)
autocast_regions[node] = set({split_callback(node)})
autocast_exits[node] = None
elif node.target == torch.amp._exit_autocast:
assert len(node.args) == 1
autocast_regions[node.args[0]].add(split_callback(node))
active_autocasts.remove(node.args[0])
autocast_exits[node.args[0]] = node
if active_grad is not None:
grad_regions[active_grad].add(split_callback(node))
for a in active_autocasts:
autocast_regions[a].add(split_callback(node))
assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("autocast_regions: %s", autocast_regions)
_LOGGER.debug("grad_regions: %s", grad_regions)
assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
# split nodes into partitions
highest_partition = -1
for node in m.graph.nodes:
orig_nodes[node.name] = node
# TODO currently placeholders/parameters aren't put into random partitions,
# rather they're added to the graphs where they are used down below
if node.op in ["placeholder", "get_attr"]:
continue
if node.op == "output":
torch.fx.graph.map_arg(
node.args[0], lambda n: record_cross_partition_use(n, None)
)
continue
if assert_monotonically_increasing:
pid = split_callback(node)
assert highest_partition <= pid, \
("autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}")
highest_partition = pid
# do not capture cross-partition dependencies for global state nodes as they will be
# self-contained - their setup and unwind will be isolated to each partition submodule.
if node.target not in GLOBAL_STATE_NODES:
torch.fx.graph.map_arg(
node.args, lambda def_node: record_cross_partition_use(def_node, node)
)
torch.fx.graph.map_arg(
node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
) # noqa: B950
original_partition_order = list(partitions.keys())
# find partitions with no dependencies
root_partitions: List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.dependencies):
root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering
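# (Kahn's algorithm: repeatedly peel off partitions with no remaining
# dependencies; if any partitions are left unsorted at the end, the dependency
# graph must contain a cycle.)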
sorted_partitions: List[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].dependents:
partitions[dependent].dependencies.pop(root_partition)
if not partitions[dependent].dependencies:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")
# Enter prelude
for regions_mapping in [autocast_regions, grad_regions]:
for node, regions in regions_mapping.items():
assert len(regions) > 0
partitions[str(regions[0])].environment[node] = node
for r in regions[1:]:
partition = partitions[str(r)]
new_node = partition.graph.create_node(
op=node.op,
target=node.target,
args=tuple(arg for arg in node.args),
kwargs={},
type_expr=node.type,
)
new_node.meta = node.meta.copy() # is it really a good idea to copy this?
partition.environment[node] = new_node
# add placeholders to partition inputs
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for inp in partition.inputs:
placeholder = partition.graph.placeholder(
inp,
type_expr=orig_nodes[inp].type,
)
placeholder.meta = orig_nodes[inp].meta.copy()
partition.environment[orig_nodes[inp]] = placeholder
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(
node.kwargs, lambda n: environment[n]
)
if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise AttributeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = "_".join(target_atoms)
partition.targets[target] = target_attr
# Fill in the passed-in mapping from new qualname to old qualname
if qualname_map is not None:
# When creating the split module later, the submodules will have
# path prefix matching the corresponding partition's submod_name
qualname = f"{partition.submod_name}.{target}"
qualname_map[qualname] = node.target
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
name = node.name if keep_original_node_name else None
new_node = partition.graph.create_node(
op=node.op,
target=target,
args=gathered_args,
kwargs=gathered_kwargs,
type_expr=node.type,
name=name,
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
# Exit epilogue
for regions_mapping in [autocast_regions]:
for node in reversed(regions_mapping):
regions = regions_mapping[node]
assert len(regions) > 0
for r in regions[:-1]:
partition = partitions[str(r)]
exit_node = autocast_exits[node]
assert exit_node is not None, "Missing exit node"
new_node = partition.graph.create_node(
op=exit_node.op,
target=exit_node.target,
args=(partition.environment[node],),
kwargs={},
type_expr=exit_node.type,
)
new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this?
# original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {}
# Set up values to construct base module
base_mod_env: Dict[str, Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
if not keep_original_order:
for node in m.graph.nodes:
base_mod_env, base_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
else:
# Go through the graph to construct the mapping dict
for node in m.graph.nodes:
orig_mod_env[node.name] = node
# Do some things iterating over the partitions in topological order again:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order or original order specified by keep_original_order
construct_order_partitions = (
sorted_partitions if not keep_original_order else original_partition_order
)
already_constructed_attr_nodes = set()
# We actually need to insert the placeholder nodes in the original order
# otherwise graph signature will be wrong.
original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
for partition_name in construct_order_partitions:
partition = partitions[partition_name]
# Set correct output values
output_vals = tuple(
partition.environment[orig_nodes[name]] for name in partition.outputs
)
# skip output node generation if there are no output values
num_output_vals = len(output_vals)
if num_output_vals == 1:
partition.graph.output(output_vals[0])
elif num_output_vals > 1:
partition.graph.output(output_vals)
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_env[key] for key in partition.inputs if key not in original_order
]
for node in original_order:
if node in already_constructed_attr_nodes:
continue # already added this attr to the base graph
base_mod_env, base_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
already_constructed_attr_nodes.add(node)
# Construct GraphModule for this partition
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
if node in already_constructed_attr_nodes:
continue
base_mod_env, base_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
already_constructed_attr_nodes.add(node)
base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
partition.targets, partition.graph
) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(
partition.submod_name,
tuple(base_mod_env[name] for name in partition.inputs),
)
num_outputs = len(partition.outputs)
if num_outputs > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
elif num_outputs == 1:
base_mod_env[next(iter(partition.outputs))] = output_val
for node in m.graph.nodes:
if node.op == "output":
base_mod_graph.output(
torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
) # noqa: B950
ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
log.debug(
"%s",
lazy_format_graph_code(
"post split_module", ret, colored=True
),
)
return ret

View File

@ -0,0 +1,303 @@
# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
from .tools_common import NodeList
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
for layer in name.split("."):
if hasattr(obj, layer):
obj = getattr(obj, layer)
else:
return None
return obj
@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
if "." not in attr:
setattr(obj, attr, value)
else:
layer = attr.split(".")
setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
"""
A component serves as a container for a subgraph we want to create afterwards.
"""
graph: torch.fx.Graph
order: int
name: str
# Stores the placeholder nodes in `graph`.
input_placeholders: List = field(default_factory=list)
# Store the nodes in original graph that are placeholder in `graph`.
orig_inputs: List = field(default_factory=list)
# Store the nodes in original graph that are outputs in `graph`.
orig_outputs: List = field(default_factory=list)
# Mapping from get_attr node in original graph to get_attr node in `graph`.
getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
constructor_args: List[str] = field(default_factory=list)
gm: Optional[torch.fx.GraphModule] = None
@compatibility(is_backward_compatible=False)
def split_by_tags(
gm: torch.fx.GraphModule,
tags: List[str],
return_fqn_mapping: bool = False,
return_tuple: bool = False,
GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, given tags = ["a", "b", "c"], the function will create
the initial submodules in the order of "a", "b", "c".
To set a tag:
gm.graph.nodes[idx].tag = "mytag"
This will result in all nodes with the same tag being extracted and placed in their
own submodule. For placeholder, output and get_attr nodes, the tag is ignored. Placeholder
and output nodes are created when needed, while get_attr nodes get copied to the submodules
where they are used.
Given the following module def:
class SimpleModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(...)
self.linear2 = torch.nn.Linear(...)
self.linear3 = torch.nn.Linear(...)
def forward(self, in1, in2):
r1 = self.linear1(in1)
r2 = self.linear2(in2)
r3 = torch.cat([r1, r2])
return self.linear3(r3)
Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
ro:
def forward(self, in1):
self = self.root
linear1 = self.linear1(in1)
return linear1
main:
def forward(self, in2, linear1):
self = self.root
linear2 = self.linear2(in2)
cat_1 = torch.cat([linear1, linear2])
linear3 = self.linear3(cat_1)
return linear3
The top-level module after split:
def forward(self, in1, in2):
self = self.root
ro_0 = self.ro_0(in1)
main_1 = self.main_1(in2, ro_0)
return main_1
Returns:
split_gm: torch fx graph after split
orig_to_split_fqn_mapping: a map between the original fqn and the fqn
after split for call_module and get_attr.
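A rough usage sketch (the tag names and the `belongs_to_a` predicate are
illustrative, not part of this API):
    for node in gm.graph.nodes:
        if node.op in ("call_function", "call_method", "call_module"):
            node.tag = "a" if belongs_to_a(node) else "b"
    split_gm = split_by_tags(gm, ["a", "b"])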
"""
def flatten(x: torch.fx.node.Argument) -> NodeList:
"""
Stores nodes in x to a list and returns the list.
"""
r: NodeList = []
map_arg(x, r.append)
return r
# Mapping from node in original module to node in created submodule.
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
# Mapping from node in original module or created submodules to
# corresponding component.
node_to_component: Dict[torch.fx.Node, Component] = {}
# Mapping from tag to the corresponding component.
tag_to_component: Dict[str, Component] = {}
# Stores all components.
all_components: List[Component] = []
# Stores nodes that will be used in main graph.
used_in_main: Dict[torch.fx.Node, None] = {}
# Main graph after split.
main_g = torch.fx.Graph()
# Mapping from node in original module to node in main graph after split.
main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
# Output node of original module.
output_node: Optional[torch.fx.Node] = None
# Create a component for each tag, we don't expect to create other components afterwards.
for tag in tags:
comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
all_components.append(comp)
tag_to_component[tag] = comp
# Traverse the nodes in original graph and take care of them.
for node in gm.graph.nodes:
if node.op == "output":
if output_node is not None:
raise RuntimeError("Multiple output nodes in graph!")
output_node = node
continue
# Placeholders in the original graph get copied to main graph.
if node.op == "placeholder":
main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
main_remapping[node].meta = copy.copy(node.meta)
continue
# Get_attr nodes are ignored because we are not tagging them.
# Instead, we copy them directly to the submodules that use them.
if node.op == "get_attr":
continue
# Now we process callable nodes which are nodes with op of call_module,
# call_function or call_method. Every callable node should be tagged.
assert hasattr(node, "tag")
upstream_components = [
node_to_component[x]
for x in flatten(node.args) + flatten(node.kwargs)
if x.op not in {"placeholder", "get_attr"}
]
comp = tag_to_component[node.tag]
node_to_component[node] = comp
# Max order of upstream components.
mx = max((c.order for c in upstream_components), default=0)
# Expect the component for `node` to have a higher order than its upstream components.
assert comp.order >= mx
# Map an input of `node` to nodes in the component's graph.
def remap_func(x):
# If input is a get_attr node, copy it to current component's graph.
# Returns the get_attr node in current component's graph.
if x.op == "get_attr":
if x not in comp.getattr_maps:
comp.getattr_maps[x] = comp.graph.get_attr(
x.target, type_expr=x.type
)
return comp.getattr_maps[x]
# If input is not a placeholder, it should have been put into a component
# already. If it's the current component then we return the corresponding
# node in the component.
if x.op != "placeholder" and node_to_component[x] == comp:
return node_remapping[x]
# If input is a placeholder or it's in other components, we want to make it
# a placeholder in the current component's graph.
if x not in comp.orig_inputs:
comp.orig_inputs.append(x)
placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
placeholder.meta = copy.copy(x.meta)
comp.input_placeholders.append(placeholder)
used_in_main[x] = None
return comp.input_placeholders[comp.orig_inputs.index(x)]
n = comp.graph.node_copy(node, remap_func)
n.tag = node.tag # type: ignore[attr-defined]
node_remapping[node] = n
node_to_component[n] = comp
if output_node is None:
raise RuntimeError("Graph had no output node!")
for x in flatten(output_node.args[0]):
if x.op == "get_attr":
# We don't need components mapping for nodes of type "get_attr"
# that are consumed by the output. Only need to make sure we create
# corresponding counterparts in the resulting graph.
main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
else:
# All component results consumed by the output node should be
# marked as "used in main".
used_in_main[x] = None
# If a node is used in main graph then we mark it as an output in the component
# it belongs to.
for n in used_in_main:
if n.op != "placeholder":
node_to_component[n].orig_outputs.append(n)
# Now we create a graphmodule for each component.
orig_to_split_fqn_mapping: Dict[str, str] = {}
for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
if return_tuple:
comp.graph.output(outs)
else:
# Take care of the args of the FX output node. If there's a single
# output, the output node args look like (output_single); if there are
# multiple outputs, they look like ((output_0, output_1, ...)).
comp.graph.output(outs[0] if len(outs) == 1 else outs)
comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
gm, subgraph=comp.graph, comp_name=comp.name
)
orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
# Create a call_module node in main graph.
main_node = main_g.call_module(
comp.name,
args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
kwargs=None,
)
if len(outs) == 1 and not return_tuple:
main_remapping[comp.orig_outputs[0]] = main_node
else:
for i, o in enumerate(comp.orig_outputs):
# Use Proxy to record getitem access.
main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index]
main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
main_root = HolderModule({comp.name: comp.gm for comp in all_components})
main_g._codegen = gm.graph._codegen
# If the output node consumes get_attr nodes directly in the original graph,
# then we need to make sure the corresponding attributes exist on the new root module.
for x in flatten(output_node.args[0]):
if x.op == "get_attr":
setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]
result_gm = GraphModuleCls(main_root, main_g)
if return_fqn_mapping:
return result_gm, orig_to_split_fqn_mapping
return result_gm
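# Hedged usage sketch (not part of the original module): how split_by_tags might
# be driven by hand. The toy module, the tag names "stage_0"/"stage_1", and the
# choice to route relu to the first stage are illustrative assumptions only.
def _example_split_by_tags():
    class _Toy(torch.nn.Module):
        def forward(self, x):
            return torch.sigmoid(torch.relu(x))

    toy_gm = torch.fx.symbolic_trace(_Toy())
    for n in toy_gm.graph.nodes:
        if n.op in {"call_function", "call_method", "call_module"}:
            # Every callable node must carry a tag before splitting.
            n.tag = "stage_0" if n.target is torch.relu else "stage_1"
    # Produces a root module that calls submodules "stage_0" and "stage_1".
    return split_by_tags(toy_gm, ["stage_0", "stage_1"])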

View File

@ -0,0 +1,898 @@
# mypy: allow-untyped-defs
import argparse
import copy
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging
import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility
from .operator_support import (
get_node_target,
OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
FxNetAccFusionsFinder,
CALLABLE_NODE_OPS,
Tensors,
NodeList,
NodeSet,
is_node_output_tensor,
)
__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
_LOGGER = logging.getLogger(__name__)
DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False
class _SplitterSettingBase:
def __init__(
self,
min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
skip_fusion=DEFAULT_SKIP_FUSION,
allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
max_acc_splits: int = -1,
):
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-acc-module-size",
"--min_acc_module_size",
required=False,
type=int,
help="Minimum size limit of an accelerator subgraph.",
)
parser.add_argument(
"--max-acc-splits",
"--max_acc_splits",
required=False,
type=int,
help="Enforce a maximum number of split subgraphs.",
)
parser.add_argument(
"--skip-fusion",
"--skip_fusion",
default=False,
action="store_true",
help="If true then no fusion groups. Fusion group is used to "
"enforce no non-tensor data flow between submodules. If we don't "
"have this constrain, setting this to false is recommended as it "
"can reduce overhead.",
)
parser.add_argument(
"--allow-non-tensor",
"--allow_non_tensor",
default=False,
action="store_true",
help="For some backends non-tensor data flow between cpu and them "
"are not allowed. Therefore, if a node supported by accelerator but "
"it has non-tensor inputs or outputs to a cpu node we would want to "
"consider it as a cpu node during splitting. However, for some backends "
"we might not care about non-tensor data flow and we can set this option "
"to true to disable the functionality that prevent non-tensor data flow.",
)
args, unknown = parser.parse_known_args()
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
self.max_acc_splits: int = max_acc_splits
@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
"""
Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
I.e. if we have a chain:
ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
This behavior can be turned off by passing allow_non_tensor=True.
"""
def __init__(
self,
module: torch.fx.GraphModule,
operator_support: OperatorSupportBase,
allow_non_tensor: bool,
):
self.module = module
self.operator_support = operator_support
self.allow_non_tensor = allow_non_tensor
self.acc_nodes: NodeSet = set()
def reduce_acc_nodes_non_tensor_input_helper(
self, cpu_worklist: NodeList
):
"""
Transitively excludes nodes from ACC supported set.
For every node in the worklist:
- removes its downstream ACC nodes from ACC supported set,
- if any downstream ACC node produces non-tensor output,
then it gets added into the worklist.
"""
while cpu_worklist:
node = cpu_worklist.pop(0)
for user in node.users:
if user in self.acc_nodes:
self.acc_nodes.remove(user)
if not is_node_output_tensor(user):
cpu_worklist.append(user)
def reduce_acc_nodes_non_tensor_input(self):
"""
Excludes nodes from ACC supported set that have direct
upstream CPU nodes that produce non-tensor outputs.
"""
non_tensor_cpu_nodes: NodeList = []
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
if node in self.acc_nodes:
continue
if is_node_output_tensor(node):
continue
non_tensor_cpu_nodes.append(node)
self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
def reduce_acc_nodes_non_tensor_output(self):
"""
Excludes nodes from ACC supported set that produce non-tensor
outputs and have downstream CPU nodes.
"""
while True:
new_cpu_nodes: NodeList = []
for acc_node in self.acc_nodes:
if is_node_output_tensor(acc_node):
continue
for user in acc_node.users:
if user not in self.acc_nodes:
new_cpu_nodes.append(acc_node)
break
if not new_cpu_nodes:
break
for new_cpu_node in new_cpu_nodes:
self.acc_nodes.remove(new_cpu_node)
self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
def __call__(self) -> NodeSet:
submodules = dict(self.module.named_modules())
self.acc_nodes = {
n
for n in self.module.graph.nodes
if n.op in CALLABLE_NODE_OPS
and self.operator_support.is_node_supported(submodules, n)
}
if not self.allow_non_tensor:
self.reduce_acc_nodes_non_tensor_input()
self.reduce_acc_nodes_non_tensor_output()
return self.acc_nodes
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
pass
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
is_acc: bool
nodes: NodeList
device_ordinal: Optional[int] = None
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
"""
Stores the results of the splitter.
Attributes:
split_module: root module after splitting.
submodule_inputs: a dict that maps submodule name to its inputs.
non_acc_submodule_prefix: the prefix for non-acc submodules. For
acc submodules the prefix is always "_run_on_acc_".
"""
split_module: torch.fx.GraphModule
submodule_inputs: Dict[str, Any]
non_acc_submodule_prefix: str
@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
model: torch.nn.Module,
inputs: Sequence[Any],
target_submodules: Iterable[str],
deepcopy: bool = False,
) -> Dict[str, Any]:
"""
Generate inputs for the target submodules in the given model. Note that if two submodules refer to the same object, this
function doesn't work.
Args:
model: root model.
inputs: inputs to the root model.
target_submodules: submodules that we want to generate inputs for.
Returns:
A dict that maps from submodule name to its inputs.
"""
handles = []
results = {}
submodule_to_names = {mod: name for name, mod in model.named_modules()}
def pre_forward(module, module_inputs):
results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs
for name, mod in model.named_modules():
if name in target_submodules:
handles.append(mod.register_forward_pre_hook(pre_forward))
def clean_up_handles():
for h in handles:
h.remove()
try:
with torch.no_grad():
model(*inputs)
except Exception as e:
clean_up_handles()
raise e
clean_up_handles()
return results
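# Hedged usage sketch (illustrative assumptions only): capture the inputs that
# flow into the two children of a small Sequential model. Sequential names its
# children "0" and "1", which is what we pass as the target submodule names.
def _example_generate_inputs_for_submodules():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    sample = (torch.randn(2, 4),)
    # Returns something like {"0": (tensor,), "1": (tensor,)}.
    return generate_inputs_for_submodules(model, sample, ["0", "1"])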
class _SplitterBase:
"""
Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
Given the following graph:
        ==> b ==>
       //         \\
    a               d
       \\         //
        ==> c ==>
class SimpleModule(torch.nn.Module):
def forward(self, a):
b = torch.sin(a)
c = torch.cos(a)
d = b + c
return d
and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
we will get the following split result:
main:
def forward(self, a):
run_on_acc_0_0 = self._run_on_acc_0_0(a)
getitem = run_on_acc_0_0[0]
getitem_1 = run_on_acc_0_0[1]
run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
return run_on_cpu_1_1
_run_on_acc_0_0:
def forward(self, a):
sin_1 = torch.sin(a)
cos_1 = torch.cos(a)
return (sin_1, cos_1)
_run_on_cpu_1_1:
def forward(self, sin_1, cos_1):
add_1 = sin_1 + cos_1
return add_1
"""
# PCIe bandwidth for the backend, default to 100 GB/s
PCIe_BW = 100 * 2 ** 30
def __init__(
self,
module: torch.fx.GraphModule,
sample_input: Sequence[Any],
operator_support: OperatorSupportBase,
settings: _SplitterSettingBase,
non_acc_submodule_name: str = "_run_on_cpu_",
return_tuple: bool = False,
):
"""
Preprocesses graph before splitting:
- finds nodes supported by ACC,
- finds fusion groups for ACC nodes having non-tensor IO,
- builds a graph of direct dependencies,
- builds a map of fused nodes to their fusions.
As a result we get self.acc_nodes, self.deps and self.fusions.
"""
assert isinstance(module, torch.fx.GraphModule)
self.module = module
ShapeProp(self.module).propagate(*sample_input)
self.settings = settings
self.operator_support = operator_support
self.sample_input = sample_input
self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
if self.settings.skip_fusion:
self.fusions = {}
else:
self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
# Modify deps to add more deps for fused nodes
self.deps = self.find_deps()
self.update_deps_for_fusions()
self.non_acc_submodule_name = non_acc_submodule_name
self._node_submodule_map: Dict[str, str] = {}
self._return_tuple = return_tuple
self.tags: List[str] = []
# ===============================================================
# Helpers for ctor and initial state
# ===============================================================
def get_node_submodule_map(self) -> Dict[str, str]:
""" Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""
return self._node_submodule_map
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
"""
Builds a graph of node dependencies. Leaf nodes don't have any
dependencies and the "output" node doesn't have nodes depending on it.
Resulting graph has only direct dependencies, i.e. there are no
transitive dependencies.
"""
deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
for user in node.users:
if user.op != "output":
deps[user].add(node)
return deps
def update_deps_for_fusions(self):
"""
Updates graph of dependencies so that:
- nodes from the same fusion depend on the same set of outer nodes,
- outer nodes depending on a fusion depend on all nodes in that fusion.
"""
for node in self.fusions:
fusion = self.fusions[node]
for fused_neighbor in fusion:
self.deps[node].update(self.deps[fused_neighbor] - fusion)
for user in fused_neighbor.users:
if user not in fusion:
self.deps[user].add(node)
# ===============================================================
# Helpers for preview
# ===============================================================
def _lower_model_to_backend(
self, mod: torch.fx.GraphModule, inputs: Tensors
) -> torch.nn.Module:
"""
Lower the model to a backend.
"""
return mod
def _find_culprit(
self, mod: torch.fx.GraphModule, inputs: Tensors
) -> str:
"""
When an error occurs during lowering or running the lowered mod, we use this
function to find the culprits in `mod` that cause the error.
"""
return "Unable to find a culprit because _find_culprit() function is not implemented."
def _draw_graph_based_on_node_support(
self, mod: torch.fx.GraphModule, supported_nodes: NodeList
):
color_map = {
"default": "AliceBlue",
"supported": "chartreuse1",
"unsupported": "crimson",
}
class CustomDrawer(FxGraphDrawer):
def _get_node_style(self, node):
template = super()._get_node_style(node)
if node in supported_nodes:
template["fillcolor"] = color_map["supported"]
elif node.op in CALLABLE_NODE_OPS:
template["fillcolor"] = color_map["unsupported"]
else:
template["fillcolor"] = color_map["default"]
return template
drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
dot_graph = drawer.get_main_dot_graph()
# pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
dot_graph.write_raw("node_support.dot")
def node_support_preview(self, dump_graph: bool = False):
submodules = dict(self.module.named_modules())
supported_nodes: NodeList = []
supported_node_types = defaultdict(set)
unsupported_node_types = defaultdict(set)
def get_dtype(arg):
tensor_meta = arg.meta.get("tensor_meta")
return getattr(tensor_meta, "dtype", None)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
target = get_node_target(submodules, node)
# Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
arg_dtypes = [
get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
for arg in node.args
]
# Find the index one past the last non-None element; it is 0 if all elements are None.
last_index = len(arg_dtypes) - next(
(
i
for i, dtype in enumerate(reversed(arg_dtypes))
if dtype is not None
),
len(arg_dtypes),
)
# Strip None elements at the end.
arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
kwarg_dtypes_tuple = tuple(
(k, get_dtype(arg))
for k, arg in node.kwargs.items()
if isinstance(arg, torch.fx.Node)
)
if self.operator_support.is_node_supported(submodules, node):
supported_nodes.append(node)
supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
else:
unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
if dump_graph:
self._draw_graph_based_on_node_support(self.module, supported_nodes)
reports = "\nSupported node types in the model:\n"
for t, dtypes in supported_node_types.items():
for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
reports += "\nUnsupported node types in the model:\n"
for t, dtypes in unsupported_node_types.items():
for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
print(reports)
# Return reports for testing purpose
return reports
def split_preview(self, dump_graph: bool = False):
reports = ""
subgraphs = self.put_nodes_into_subgraphs()
acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
subgraphs = self.remove_small_acc_subgraphs(subgraphs)
acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
for i, subgraph in enumerate(subgraphs):
reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
reports += f"{len(subgraph.nodes)} node(s)\n"
self.tag(subgraphs)
split_mod = self.split(remove_tag=True)
split_mod.eval()
if dump_graph:
drawer = FxGraphDrawer(
split_mod, "preview", ignore_getattr=True
)
dot_graphs = drawer.get_all_dot_graphs()
for name, dot_graph in dot_graphs.items():
# pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
dot_graph.write_raw(f"{name}.dot")
max_qps: float = self.PCIe_BW
bottleneck_module = ""
for node in split_mod.graph.nodes:
if node.op == "call_module" and "acc" in node.target:
reports += f"\nProcessing acc submodule {node.target}\n"
submod = getattr(split_mod, node.target)
def get_submod_inputs(main_mod, submod, example_inputs):
sub_inputs = None
def get_inputs(self, inputs):
nonlocal sub_inputs
sub_inputs = inputs
handle = submod.register_forward_pre_hook(get_inputs)
main_mod(*example_inputs)
handle.remove()
return sub_inputs
submod_inputs = get_submod_inputs(
split_mod, submod, self.sample_input
)
ShapeProp(submod).propagate(*submod_inputs)
total_input_bytes = 0
total_output_bytes = 0
reports += "Checking inputs...\n"
for n in submod.graph.nodes:
if n.op == "placeholder":
if not is_node_output_tensor(n):
reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
else:
total_input_bytes += get_size_of_node(submod, n)[0]
if n.op == "output":
output_node = n
reports += "Checking outputs...\n"
def get_bytes(node: torch.fx.Node):
nonlocal total_output_bytes
nonlocal reports
if not is_node_output_tensor(node):
reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
else:
total_output_bytes += get_size_of_node(submod, node)[0]
map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined]
qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
if qps < max_qps:
max_qps = qps
bottleneck_module = node.target
try:
lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
except RuntimeError:
reports += "Run into an error during lowering!\n"
reports += self._find_culprit(submod, submod_inputs)
continue
try:
lowered_submod(*submod_inputs)
except RuntimeError:
reports += "Run into an error during inference!\n"
reports += self._find_culprit(submod, submod_inputs)
else:
reports += "Lowering and running succeed!\n"
reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
reports += f" bottleneck is submodule {bottleneck_module}."
print(reports)
# return the reports for testing purposes
return reports
# ===============================================================
# Helpers for extend_acc_subgraph() method
# ===============================================================
def find_reverse_deps(
self, tag_id: Optional[int] = None
) -> Dict[torch.fx.Node, NodeSet]:
"""
Builds reversed topological node dependencies. If tag_id is specified,
we ignore nodes that are in a later subgraph, i.e. nodes with a greater tag_id.
"""
result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
for user in node.users:
if user.op not in CALLABLE_NODE_OPS:
continue
if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
result[node].add(user)
return result
def update_reverse_deps_for_fusions(
self, deps: Dict[torch.fx.Node, NodeSet]
):
processed_node = set()
for node, fusion in self.fusions.items():
if node in processed_node:
continue
new_dep = set()
# Create a new dependency set which includes all the
# dependencies of the nodes in the fusion group
for n in fusion:
new_dep.update(deps[n])
# Exclude nodes in the fusion
new_dep.difference_update(fusion)
# Update dependency
for n in fusion:
deps[n] = new_dep
for arg in n.all_input_nodes:
if arg not in fusion:
deps[arg].update(fusion)
processed_node.add(n)
def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
"""
Finds parent nodes of the `tag` subgraph.
Traverse the inputs of nodes in the subgraph; if an input doesn't belong to the subgraph
and is not a placeholder, we consider it a parent node of the subgraph.
"""
parent_nodes = set()
for node in self.module.graph.nodes:
if node.op in CALLABLE_NODE_OPS and node.tag == tag:
for arg in node.all_input_nodes:
if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
parent_nodes.add(arg)
return parent_nodes
def extend_acc_subgraph(self, tag: str):
"""
Extend the acc subgraph with `tag`, going in the reverse topological direction.
"""
# Dict that maps a node to its users, ignoring users that
# are in subgraphs with a greater tag
deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
self.update_reverse_deps_for_fusions(deps)
# Parent nodes of the subgraph
parent_nodes = self.find_parent_nodes_of_subgraph(tag)
visited_nodes: NodeSet = set()
while parent_nodes:
node = None
# Find an acc node that depends on visited nodes only
for n in parent_nodes:
if deps[n] <= visited_nodes and n in self.acc_nodes:
node = n
break
if node is None:
break
# Put the node into `tag` subgraph
node.tag = tag # type: ignore[attr-defined]
parent_nodes.remove(node)
visited_nodes.add(node)
# If node is in a fusion group, add all fusion buddies to parent nodes
if node in self.fusions:
for fusion_node in self.fusions[node]:
if fusion_node not in visited_nodes:
parent_nodes.add(fusion_node)
# Add inputs of the node to parent nodes
for arg in node.all_input_nodes:
if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
parent_nodes.add(arg)
# ===============================================================
# Helpers for split() method
# ===============================================================
def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
"""
Finds nodes that consume module inputs or get_attr nodes.
"""
starter_cpu_nodes: NodeSet = set()
starter_acc_nodes: NodeSet = set()
for node in self.module.graph.nodes:
if node.op not in {"placeholder", "get_attr"}:
continue
for user in node.users:
if user in self.acc_nodes:
starter_acc_nodes.add(user)
else:
starter_cpu_nodes.add(user)
return starter_cpu_nodes, starter_acc_nodes
def put_nodes_into_subgraphs(self) -> List[Subgraph]:
# We start graph traversal from leaf nodes
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
# Determine which subgraph to start from based on which subgraph has
# a 0-dep node
acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
current_subgraph_nodes: NodeList = []
# Result accumulator
subgraphs: List[Subgraph] = []
while current_cpu_nodes or current_acc_nodes:
# Find the first node that should belong to the current subgraph and has all dependencies resolved
current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
node = next(
(n for n in current_nodes if self.deps[n] <= visited_nodes),
None,
)
# If nothing was found, then it's time to flip the mode and start a new subgraph
if node is None:
if not current_subgraph_nodes:
raise FxNetSplitterInternalError("Subgraph can't be empty")
subgraphs.append(
Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
)
acc_subgraph = not acc_subgraph
current_subgraph_nodes = []
continue
current_nodes.remove(node)
visited_nodes.add(node)
current_subgraph_nodes.append(node)
# Add fusion buddies
if node in self.fusions:
if node in self.acc_nodes:
current_acc_nodes.update(self.fusions[node] - visited_nodes)
else:
current_cpu_nodes.update(self.fusions[node] - visited_nodes)
# Put depending nodes into the queue
for user in node.users:
if user.op not in CALLABLE_NODE_OPS:
continue
# Add downstream nodes
if user in self.acc_nodes:
current_acc_nodes.add(user)
else:
current_cpu_nodes.add(user)
# Add the last subgraph to the result if it hasn't been added yet
if current_subgraph_nodes:
subgraphs.append(
Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
)
if not subgraphs:
raise FxNetSplitterInternalError("Couldn't create subgraphs")
return subgraphs
def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
"""
This pass finds ACC submodules with fewer nodes than the specified minimum size and merges
them with adjacent CPU submodules.
"""
result: List[Subgraph] = []
for subgraph in subgraphs:
if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size:
result.append(subgraph)
else:
print(
"Eliminating acc subgraph because it's smaller than the threshold: "
f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
)
if result:
result[-1].nodes.extend(subgraph.nodes)
else:
subgraph.is_acc = False
result.append(subgraph)
else:
if result and not result[-1].is_acc:
result[-1].nodes.extend(subgraph.nodes)
else:
result.append(subgraph)
return result
def tag(self, subgraphs: List[Subgraph]):
self.tags = []
for subgraph in subgraphs:
tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
self.tags.append(tag)
for node in subgraph.nodes:
if hasattr(node, "tag"):
raise FxNetSplitterInternalError(f"Node {node} was already tagged")
node.tag = tag # type: ignore[attr-defined]
self._node_submodule_map[node.name] = tag
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
if remove_tag:
for node in self.module.graph.nodes:
if hasattr(node, "tag"):
del node.tag
return split_module # type: ignore[return-value]
def __call__(self) -> torch.fx.GraphModule:
subgraphs = self.put_nodes_into_subgraphs()
subgraphs = self.remove_small_acc_subgraphs(subgraphs)
acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
self.tag(subgraphs)
return self.split()
def generate_split_results(self) -> SplitResult:
split_module = self()
submodule_names = []
for name, mod in split_module.named_children():
submodule_names.append(name)
if (
self.settings.max_acc_splits > 0
and len(submodule_names) > self.settings.max_acc_splits
):
raise ValueError(
"Cannot fulfill max_acc_splits limit. "
"This may cause split fragmentation and "
"result in performance issues."
)
submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
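# Hedged sketch (illustrative assumptions only): _SplitterBase is normally
# subclassed by a backend-specific splitter, but the flow can be exercised
# directly. The "relu only" operator support policy and the toy module below
# are made up purely for this example.
def _example_splitter_flow():
    class _ReluOnlySupport(OperatorSupportBase):
        def is_node_supported(self, submodules, node):
            return node.op == "call_function" and node.target is torch.relu

    class _Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(_Toy())
    sample = (torch.randn(2, 3),)
    splitter = _SplitterBase(gm, sample, _ReluOnlySupport(), _SplitterSettingBase())
    # Splits into an acc submodule (the relu) and a cpu submodule (the add).
    return splitter.generate_split_results()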

View File

@ -0,0 +1,58 @@
import unittest
from ..pass_manager import (
inplace_wrapper,
PassManager,
these_before_those_pass_constraint,
this_before_that_pass_constraint,
)
class TestPassManager(unittest.TestCase):
def test_pass_manager_builder(self) -> None:
passes = [lambda x: 2 * x for _ in range(10)]
pm = PassManager(passes)
pm.validate()
def test_this_before_that_pass_constraint(self) -> None:
passes = [lambda x: 2 * x for _ in range(10)]
pm = PassManager(passes)
# add unfulfillable constraint
pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
self.assertRaises(RuntimeError, pm.validate)
def test_these_before_those_pass_constraint(self) -> None:
passes = [lambda x: 2 * x for _ in range(10)]
constraint = these_before_those_pass_constraint(passes[-1], passes[0])
pm = PassManager(
[inplace_wrapper(p) for p in passes]
)
# add unfulfillable constraint
pm.add_constraint(constraint)
self.assertRaises(RuntimeError, pm.validate)
def test_two_pass_managers(self) -> None:
"""Make sure we can construct the PassManager twice and not share any
state between them"""
passes = [lambda x: 2 * x for _ in range(3)]
constraint = these_before_those_pass_constraint(passes[0], passes[1])
pm1 = PassManager()
for p in passes:
pm1.add_pass(p)
pm1.add_constraint(constraint)
output1 = pm1(1)
self.assertEqual(output1, 2 ** 3)
passes = [lambda x: 3 * x for _ in range(3)]
constraint = these_before_those_pass_constraint(passes[0], passes[1])
pm2 = PassManager()
for p in passes:
pm2.add_pass(p)
pm2.add_constraint(constraint)
output2 = pm2(1)
self.assertEqual(output2, 3 ** 3)

View File

@ -0,0 +1,303 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import operator
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility
__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
if isinstance(k, str):
return k
elif k.__module__ and "acc_ops" in k.__module__:
return f"acc_ops.{k.__name__}"
else:
module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
return f"{module if module else ''}.{k.__name__}"
@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
"""
Given a `node` returns its target typename.
For "call_method" node, return node.target which is the name of that method being called.
This could potential lead to conflict but should be okay because normally it's on a tensor.
For "call_function" node, return typename of node.target.
For "call_module" node, return typename of the module that node.target point to.
If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
"torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
"""
assert node.op in CALLABLE_NODE_OPS, (
"Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
)
if node.op == "call_module":
assert isinstance(node.target, str)
submod = submodules[node.target]
submod_type = getattr(submod, "_base_class_origin", type(submod))
return get_acc_ops_name(submod_type)
elif node.op == "call_function":
target: Any = node.target
return (
f"acc_ops.{target.__name__}"
if target.__module__ is not None and "acc_ops" in target.__module__
else _get_qualified_name(target)
)
else:
assert isinstance(node.target, str)
return node.target
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
"""Checks if the node output produces a Tensor or not.
NOTE: This requires running `ShapeProp` on the containing fx graph before
calling this function, because it works by checking the `type` metadata on
the node, which is produced by `ShapeProp`.
"""
type_ = node.meta.get("type", None)
return type_ is not None and issubclass(type_, torch.Tensor)
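# Hedged sketch (illustrative assumptions only): is_node_output_tensor relies on
# the "type" metadata produced by ShapeProp, so the graph must be propagated
# with a sample input first.
def _example_is_node_output_tensor():
    from torch.fx.passes.shape_prop import ShapeProp

    gm = torch.fx.symbolic_trace(torch.nn.Linear(3, 3))
    ShapeProp(gm).propagate(torch.randn(1, 3))
    return [(n.name, is_node_output_tensor(n)) for n in gm.graph.nodes]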
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
"""
Finds groups of connected ACC nodes that pass non-tensor data between each other.
Such groups are called fusion groups.
"""
def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
self.module = module
self.nodes = list(module.graph.nodes)
self.acc_nodes = acc_nodes
@dataclass
class FusionGroup:
# The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
top_node_idx: int
# Nodes in this fusion group.
nodes: NodeSet
# Inputs to this fusion group.
inputs: NodeSet
# Nodes in this fusion group that haven't been processed yet.
nodes_need_process: NodeSet
def add_node(self, node):
"""
Add a node to fusion group.
"""
if node in self.nodes:
return
self.nodes_need_process.add(node)
self.nodes.add(node)
self.inputs.discard(node)
self.inputs.update(
{
n
for n in node.all_input_nodes
if n.op in CALLABLE_NODE_OPS and n not in self.nodes
}
)
def recursive_add_node(
self,
fusion_group: "FxNetAccFusionsFinder.FusionGroup",
inputs: Union[NodeSet, NodeList],
visited: Optional[NodeSet] = None,
):
"""
Start from the inputs and go in reverse topological order. If any upstream node
is in the fusion group, add all the nodes on this path to the fusion group.
"""
for arg in inputs:
# skip the node if already seen
if visited is not None:
if arg in visited:
continue
visited.add(arg)
# Skip placeholder and get_attr because they won't be in the fusion group.
if arg.op not in CALLABLE_NODE_OPS:
continue
# If the node has smaller idx, it's already an upstream node of the fusion
# group. We don't need to check it anymore.
if self.nodes.index(arg) < fusion_group.top_node_idx:
continue
# If the node is in the fusion group, return True.
if arg in fusion_group.nodes:
return True
# Check the upstream nodes of the node, if any of them is in the fusion group
# we'll add this node to fusion group and return True.
if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
fusion_group.add_node(arg)
return True
return False
def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
result: Dict[torch.fx.Node, NodeSet] = {}
acc_nodes = list(self.acc_nodes)
for node in acc_nodes:
if node in result:
continue
if node.op not in CALLABLE_NODE_OPS:
continue
if "tensor_meta" in node.meta:
continue
if node not in self.acc_nodes:
continue
fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
top_node_idx=self.nodes.index(node),
nodes={node},
inputs=set(node.all_input_nodes),
nodes_need_process={node},
)
while fusion_group.nodes_need_process:
node = fusion_group.nodes_need_process.pop()
self.recursive_add_node(
fusion_group,
fusion_group.inputs,
visited=set(),
)
# Optionally add downstream nodes
if "tensor_meta" not in node.meta:
for user in node.users:
if user.op not in CALLABLE_NODE_OPS:
continue
if user in fusion_group.nodes:
continue
fusion_group.add_node(user)
self.recursive_add_node(
fusion_group,
fusion_group.inputs,
visited=set(),
)
# Add some upstream nodes
for arg in node.all_input_nodes:
if arg.op not in CALLABLE_NODE_OPS:
continue
if "tensor_meta" in arg.meta:
continue
if arg in fusion_group.nodes:
continue
fusion_group.add_node(arg)
fusion_group.top_node_idx = min(
fusion_group.top_node_idx, self.nodes.index(arg)
)
self.recursive_add_node(
fusion_group,
fusion_group.inputs,
visited=set(),
)
if not (set(fusion_group.nodes) <= self.acc_nodes):
self.acc_nodes -= fusion_group.nodes
else:
for n in fusion_group.nodes:
result[n] = fusion_group.nodes
return result
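# Hedged sketch (illustrative assumptions only): the finder expects ShapeProp
# metadata ("tensor_meta") on the graph and a precomputed set of ACC nodes.
# Here we naively treat every callable node as an ACC node; the call to .size()
# produces a non-tensor value, so it seeds a fusion group.
def _example_fusions_finder():
    from torch.fx.passes.shape_prop import ShapeProp

    class _Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x).size()

    gm = torch.fx.symbolic_trace(_Toy())
    ShapeProp(gm).propagate(torch.randn(2, 2))
    acc_nodes = {n for n in gm.graph.nodes if n.op in CALLABLE_NODE_OPS}
    return FxNetAccFusionsFinder(gm, acc_nodes)()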
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Replace the graph of the given GraphModule with one that contains the same nodes as the
original, but in topologically sorted order.
This is used by transformations (such as merge_matmul) that disturb the topologically sorted
order of their input GraphModule, so that this order is restored before further transformation.
Arguments:
gm: The graph module to topologically sort. It is modified in-place.
Returns:
The graph module in-place sorted
"""
# These operators are used for making runtime assertions before any
# data-dependent operators occur. We want to prioritize sorting these to
# ensure that these assertions appear before any data-dependent operations
# in the graph.
PRIORITIZED_OPS = [
operator.add,
operator.mul,
operator.sub,
operator.floordiv,
operator.truediv,
operator.mod,
operator.le,
operator.lt,
operator.ge,
operator.gt,
operator.eq,
operator.ne,
torch.ops.aten.sym_constrain_range.default,
torch.ops.aten.sym_constrain_range_for_size.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_scalar.default,
]
indeg = dict.fromkeys(gm.graph.nodes, 0)
new_graph = torch.fx.Graph()
# Track how many unfulfilled dependencies each node has
for node in gm.graph.nodes:
for user in node.users:
indeg[user] += 1
queue: collections.deque = collections.deque()
# Add all nodes with no dependencies to the queue
for node in gm.graph.nodes:
if indeg[node] == 0:
queue.append(node)
env: Dict[torch.fx.Node, torch.fx.Node] = {}
# Pop nodes from the queue, and add nodes that have had all their
# dependencies fulfilled
while len(queue) > 0:
cur = queue.popleft()
env[cur] = new_graph.node_copy(cur, lambda x: env[x])
for user in cur.users:
indeg[user] -= 1
if indeg[user] == 0:
if user.op == "call_function" and user.target in PRIORITIZED_OPS:
queue.appendleft(user)
else:
queue.append(user)
# If the new graph's size is not as large as the old one, then there must be
# a cycle (i.e. some node's dependencies were not satisfied.)
if len(new_graph.nodes) < len(gm.graph.nodes):
raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm
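# Hedged sketch (illustrative assumptions only): after a transform inserts a
# node out of order, legalize_graph rebuilds the graph in topological order.
def _example_legalize_graph():
    gm = torch.fx.symbolic_trace(torch.nn.ReLU())
    relu_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    # Deliberately insert a consumer of the relu output at the very front of
    # the graph, breaking topological order.
    with gm.graph.inserting_before(next(iter(gm.graph.nodes))):
        gm.graph.call_function(operator.add, (relu_node, 1))
    return legalize_graph(gm)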

View File

@ -0,0 +1 @@
from .common import lift_subgraph_as_module, HolderModule, compare_graphs

View File

@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
from typing import Dict, Tuple
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module
__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
"""
HolderModule is used to copy attributes from the original module to the
submodules that use them.
"""
def __init__(self, d):
super().__init__()
for k, v in d.items():
self.add_module(k, v)
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
gm: GraphModule,
subgraph: Graph,
comp_name: str = "",
class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
"""
Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
Args:
gm (GraphModule): parent graph module
subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
comp_name (str): name for the new component
class_name (str): name for the submodule
"""
# Loop through all module calls (call_module) and param fetches (get_attr)
# in this component, creating HolderModules as necessary to match the path.
# e.g. if in the original module there's a get_attr node fetches "conv.weight".
# We create a HolderModule as root -> add a HolderModule named "conv" ->
# make "weight" a attribute of "conv" HolderModule and point to conv.weight in
# the original module.
submodule = HolderModule({})
orig_to_split_fqn_mapping: Dict[str, str] = {}
for n in subgraph.nodes:
if n.op not in ("call_module", "get_attr"):
continue
target = n.target
assert isinstance(target, str)
target_name_parts = target.split(".")
curr = submodule
orig_gm = gm
for name in target_name_parts[:-1]:
if not hasattr(curr, name):
curr.add_module(name, HolderModule({}))
curr = getattr(curr, name)
orig_gm = getattr(orig_gm, name)
leaf_node_name = target_name_parts[-1]
leaf_node = getattr(orig_gm, leaf_node_name)
orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
# Relies on custom __setattr__ magic.
setattr(curr, leaf_node_name, leaf_node)
return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
"""
Return True if two graphs are identical, i.e. they
- have the same number of outputs in the same order
- have the same number of inputs in the same order
- have the same set of nodes, and identical connectivity
"""
matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
matches = matcher.match(right)
return len(matches) > 0
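# Hedged sketch (illustrative assumptions only): two traces of the same module
# compare as identical, while structurally different graphs do not.
def _example_compare_graphs():
    import torch.fx

    relu_a = torch.fx.symbolic_trace(torch.nn.ReLU()).graph
    relu_b = torch.fx.symbolic_trace(torch.nn.ReLU()).graph
    sigmoid = torch.fx.symbolic_trace(torch.nn.Sigmoid()).graph
    return compare_graphs(relu_a, relu_b), compare_graphs(relu_a, sigmoid)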

View File

@ -0,0 +1,236 @@
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
from typing import List, Dict, Tuple
import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
from torch.fx._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
# sort nodes according to the topological order
indegree_map = dict.fromkeys(nodes, 0)
candidates: SimpleQueue = SimpleQueue()
for node in nodes:
for n in node.all_input_nodes:
if n in indegree_map:
indegree_map[node] += 1
if indegree_map[node] == 0:
candidates.put(node)
sorted_nodes: NodeList = []
while not candidates.empty():
node = candidates.get()
sorted_nodes.append(node)
for n in node.users:
if n in indegree_map:
indegree_map[n] -= 1
if indegree_map[n] == 0:
candidates.put(n)
assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes"
return sorted_nodes
@compatibility(is_backward_compatible=False)
def validate_partition(partition: NodeList) -> bool:
# verify the partition doesn't form a dependency cycle in the original graph
# returns True for valid partition, False for invalid
partition_set = set(partition)
outputs: NodeList = []
for node in partition_set:
for user_node in node.users:
if user_node not in partition_set:
# external user node, need to expose as an output
outputs.append(user_node)
# Perform BFS on the partition outputs.
# If it reaches a node within the partition, then it found a cycle.
# This function takes the ownership of `root_nodes` and may modify it.
def bfs_find_cycle(root_nodes: NodeList) -> bool:
# Set used to exclude nodes that have already been visited.
# If a node has been visited, that node and all its children have
# been checked for cycles.
visited: NodeSet = set()
# Start with `root_nodes` and traverse through (toward child nodes)
# their connected sub-graph. Nodes in `visited` won't be added
# to `queue` again.
queue: NodeList = root_nodes
while queue:
current = queue.pop()
visited.add(current)
if current in partition_set:
# Started from partition's `output` nodes, and reached
# another node in partition. Cycle!
return True
for user_node in current.users:
if user_node in visited:
continue
queue.append(user_node)
# `root_nodes` don't cause cycle.
return False
# Use all output nodes as roots to traverse
# the graph to check cycles.
if bfs_find_cycle(outputs):
return False
return True
@compatibility(is_backward_compatible=False)
def fuse_as_graphmodule(gm: GraphModule,
nodes: NodeList,
module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
"""
Fuse nodes in graph_module into a GraphModule.
Args:
gm (GraphModule): target graph_module
nodes (List[Node]): list of nodes in `gm` to fuse; the nodes must be topologically sorted
module_name: class name for the fused GraphModule
Returns:
fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm`
original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm`
original_outputs (Tuple[Node, ...]): nodes in `nodes` whose outputs are consumed outside the partition in the original `gm`
"""
# assumption: nodes are already sorted in topo order
for node in nodes:
assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert not node._erased, f"{node} has been removed from owning graph"
assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}"
# validates that the partition doesn't introduce dependency cycles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles"
subgraph = Graph()
node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph
# handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x):
if x.op == "get_attr":
# TODO: do we really need copy the get_attr node into the graph?
# do something here
pass
if x in nodes:
# x is inside subgraph, return the copied node
# the node should have been copied already, as we are copying the graph in topological order
return node_map[x]
if x not in node_to_placeholder:
# x is not in subgraph, create a new placeholder for subgraph
placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
# copy all meta fields, even if some fields might be irrelevant for the placeholder node
placeholder_node.meta = copy.copy(x.meta)
node_to_placeholder[x] = placeholder_node
return node_to_placeholder[x]
# copy nodes in topological order
for node in nodes:
new_node = subgraph.node_copy(node, remap_inputs)
node_map[node] = new_node
# handles outputs
output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs
for node in nodes:
for user_node in node.users:
if user_node not in nodes:
# external user node, need to expose as an output
output_mapping[node] = node_map[node]
# outs contain nodes in the new subgraph
outs = tuple(output_mapping.values())
# Take care of the args of the FX output node. If there's a single
# output, the output node args look like (output_single); if there are
# multiple outputs, they look like ((output_0, output_1, ...)).
subgraph.output(outs[0] if len(outs) == 1 else outs)
# lint to ensure correctness
subgraph.lint()
fused_gm: GraphModule
fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)
# sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
# sub_gm's outputs node in the original module
original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())
return fused_gm, original_inputs, original_outputs
@compatibility(is_backward_compatible=False)
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
gm.add_submodule(submodule_name, sub_gm)
# Create a call_module node in main graph.
module_node = gm.graph.call_module(
submodule_name,
args=orig_inputs,
kwargs=None)
if len(orig_outputs) == 1:
# main_remapping[comp.orig_outputs[0]] = module_node
orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
else:
for i, orig_output in enumerate(orig_outputs):
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs)
return gm
@compatibility(is_backward_compatible=False)
def erase_nodes(gm: GraphModule, nodes: NodeList):
# erase the original nodes in reverse topological order
for node in reversed(nodes):
gm.graph.erase_node(node)
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule:
for partition_id, nodes in enumerate(partitions):
sorted_nodes = topo_sort(nodes)
submodule_name = prefix + str(partition_id)
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)
insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
erase_nodes(gm, sorted_nodes)
# topologically sort the original gm together with the newly created sub_gm
legalize_graph(gm)
return gm
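# Hedged sketch (illustrative assumptions only): fuse the two call_function
# nodes of a toy module into a single sub-GraphModule named "fused_0".
def _example_fuse_by_partitions():
    class _Toy(torch.nn.Module):
        def forward(self, x):
            return torch.sigmoid(torch.relu(x))

    gm = torch.fx.symbolic_trace(_Toy())
    to_fuse = [n for n in gm.graph.nodes if n.op == "call_function"]
    # The root graph now calls self.fused_0(x) in place of the fused nodes.
    return fuse_by_partitions(gm, [to_fuse], prefix="fused_")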

View File

@ -0,0 +1,401 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from collections import defaultdict
import copy
import torch
from torch.fx import (
Node,
Graph,
)
from torch.fx._compatibility import compatibility
from typing import Dict, List, Set, Any, Union, Tuple
import logging
import os
__all__ = ['SubgraphMatcher', 'InternalMatch']
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
logger = logging.getLogger(__name__)
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger
logger = _init_logger()
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch:
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
# map from a string name to a node in the target graph
# only available if the matcher is `SubgraphMatcherWithNameNodesMap`
name_node_map: Dict[str, Node] = field(default_factory=dict)
def __copy__(self):
return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
placeholder_nodes=self.placeholder_nodes.copy(),
returning_nodes=self.returning_nodes.copy())
@compatibility(is_backward_compatible=False)
class SubgraphMatcher:
def __init__(self, pattern: Graph,
match_output: bool = False,
match_placeholder: bool = False,
remove_overlapping_matches: bool = True,
ignore_literals: bool = False) -> None:
"""
Args:
pattern: the targeted matching pattern, represented in fx.Graph.
match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
If False, output node is ignored during match.
match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
the targeted pattern. If False, placeholder nodes will be used as wildcards.
remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
will be returned.
ignore_literals: If True, will not check if literals are equal and
will instead treat them as wildcards.
"""
self.pattern = pattern
self.match_output = match_output
self.match_placeholder = match_placeholder
self.remove_overlapping_matches = remove_overlapping_matches
self.ignore_literals = ignore_literals
if len(pattern.nodes) == 0:
raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")
for node in pattern.nodes:
if node.op != "output":
assert len(node.users) > 0, \
"SubgraphMatcher cannot be initialized with an pattern with dead code"
# TODO: assert pattern is a connected graph
self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
output_node = next(iter(reversed(pattern.nodes)))
# nodes returned by outputs
self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
self.pattern_anchors: List[Node] = []
if match_output:
self.pattern_anchors = [output_node]
else:
# If a node has output_node as the ONLY user, then this node is a graph sink,
# and should be matched against as an anchor
self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]
def _match_attributes(self, pn: Node, gn: Node) -> bool:
# Attributes matching is complicated. Right now we only support matching constant tensor
assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."
# TODO(tmanlaibaatar) should probably make this actual API
def _getattr(model: torch.fx.GraphModule, attr_name: str):
*prefix, field = attr_name.split(".")
t = model
for item in prefix:
t = getattr(t, item, None) # type: ignore[assignment]
assert t is not None
return getattr(t, field)
pn_value = _getattr(pn.graph.owning_module, pn.target)
gn_value = _getattr(gn.graph.owning_module, gn.target)
if type(pn_value) != type(gn_value):
return False
# Don't require exact match on tensor values.
if isinstance(pn_value, torch.Tensor):
return isinstance(gn_value, torch.Tensor)
else:
raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
return False
def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
# if exact match for placeholder is not required, then use placeholder as a wildcard
if not self.match_placeholder and pn.op == "placeholder":
return True
if pn.op == gn.op:
if pn.op == "placeholder" or pn.op == "output":
return True
elif pn.op == "get_attr":
return self._match_attributes(pn, gn)
return pn.target == gn.target
return False
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
# Placeholders can be used by other nodes in the graphs
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
for gn, pn in lookup.items():
# nodes returned by output are allowed to be used in other areas of the graph
if pn in self.pattern_returning_nodes:
continue
for user in gn.users:
# If this node has users that were not in `lookup`, then it must leak out of the
# pattern subgraph
if user not in lookup:
return False
return True
def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
non_overlapping_matches: List[InternalMatch] = []
nodes_matched: Set[Node] = set()
for match in matches:
found_overlap = False
for pn, gn in match.nodes_map.items():
if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
found_overlap = True
break
if not found_overlap:
non_overlapping_matches.append(match)
for pn, gn in match.nodes_map.items():
if pn.op not in {"placeholder", "output"}:
nodes_matched.add(gn)
return non_overlapping_matches
def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder":
# Check if we've already matched these nodes in the current
# traversal
if pn in match.nodes_map:
return match.nodes_map[pn] == gn
match.nodes_map[pn] = gn
return True
else:
return False
elif not isinstance(pn, Node) and isinstance(gn, Node):
return False
else:
return type(gn) == type(pn) and gn == pn
def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
logger.info(" matching %s to %s", pn, gn)
assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}")
# Check if we've already matched these nodes in the current
# traversal
if pn in match.nodes_map:
return match.nodes_map[pn] == gn
# TODO: use a more efficient way to check if gn is matched before: two-way dict
if gn in match.nodes_map.values():
return False
if not self._nodes_are_equal(pn, gn):
return False
# Optimistically mark `pn` as a match for `gn`, and save a local copy of match
saved_match = copy.copy(match)
match.nodes_map[pn] = gn
# Placeholder is a wildcard and can be matched with any python object
# (including list/tuple)
if pn.op == "placeholder":
return True
# Recursively traverse upwards to check if `pn` is a true
# match for `gn`
match_found = True
def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
if len(args1) != len(args2):
return False
for a1, a2 in zip(args1, args2):
if isinstance(a1, Node) and isinstance(a2, Node):
matched = self._match_nodes(a1, a2, match)
elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
matched = _match_args(a1, a2)
else:
matched = self._match_literals(a1, a2, match) or self.ignore_literals
if not matched:
return False
return True
# Flatten all args/kwargs into 1 list of args
pn_args, gn_args = None, None
if (
(len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and
pn.op == "call_function" and
isinstance(pn.target, torch._ops.OpOverload)
):
args_schema = pn.target._schema.arguments
def get_all_arguments(orig_args, orig_kwargs):
all_args = []
for i, schema in enumerate(args_schema):
if schema.name in orig_kwargs:
all_args.append(orig_kwargs[schema.name])
elif not schema.kwarg_only and i < len(orig_args):
all_args.append(orig_args[i])
else:
all_args.append(schema.default_value)
return all_args
pn_args = get_all_arguments(pn.args, pn.kwargs)
gn_args = get_all_arguments(gn.args, gn.kwargs)
elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()):
pn_args = list(pn.args)
gn_args = list(gn.args)
pn_args.extend(list(pn.kwargs.values()))
gn_args.extend(list(gn.kwargs.values()))
else:
match_found = False
match_found = (
match_found and
pn_args is not None and
gn_args is not None and
_match_args(pn_args, gn_args)
)
if not match_found:
# revert to saved_match before matching with current node
match = copy.copy(saved_match)
return False
return True
def match(self, graph: Graph) -> List[InternalMatch]:
"""
Returns:
The matched subgraphs.
The returned subgraph is fully self-contained, meaning its nodes (except placeholders
and nodes returned by output) can only be consumed by nodes within the matched subgraph.
The subgraph pattern matcher is implemented in a backtracking style, with the following steps:
1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
are the "sinks" (nodes with no user other than the output node) of the pattern graph.
One pattern graph could have multiple anchors if it has multiple return values.
2. In the target graph, we identify the potential candidate nodes that can be matched
with each anchor. These anchor-candidate pairs are the starting points for
pairwise per-node matching.
3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
pattern and target graphs. Every pattern node along the traversal path is compared against the
corresponding target node. If any comparison fails, the match for this anchor-candidate
pair fails. A match is found when the DFS finishes traversing the pattern graph. See `self._match_nodes`
for more details.
4. In the case of multiple anchors, every anchor will need to find a match using step 3.
In addition, the matches found between anchors need to have a common intersection node
in order for the match to be valid. This is implemented with backtracking. See `backtracking`
for more details.
Notice: graph traversal must be done in reverse order because a tensor can have multiple
consumers but only a single producer. Only with reverse-order traversal can we jointly
traverse the pattern and target graphs along a deterministic path.
Warning: In theory, this backtracking algorithm has an **exponential** time complexity.
In practice, however, it is unlikely to blow up.
An illustrative usage sketch is provided at the end of this file.
"""
from torch.fx.passes.utils.fuser_utils import validate_partition
# find candidate nodes to match with pattern anchors
match_candidates: Dict[Node, List[Node]] = defaultdict(list)
for pattern_anchor in self.pattern_anchors:
for node in graph.nodes:
if self._nodes_are_equal(pattern_anchor, node):
match_candidates[pattern_anchor].append(node)
match_candidates_list = list(match_candidates.items())
logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
matches: List[InternalMatch] = []
def backtracking(anchor_index, match):
if anchor_index == len(match_candidates_list):
match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
matches.append(match)
logger.info("Found a match: %s\n", match)
return
pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
saved_match = copy.copy(match)
for node in candidate_nodes:
logger.info("Trying to match anchor %s to %s", pattern_anchor, node)
match_found = self._match_nodes(pattern_anchor, node, match)
if match_found:
# match next anchor
backtracking(anchor_index + 1, match)
else:
logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)
# revert to saved_match before matching with current anchor
match = copy.copy(saved_match)
match = InternalMatch(anchors=self.pattern_anchors)
if match_candidates_list:
backtracking(0, match)
# filter out the matches where the subgraph is not fully_contained
before = len(matches)
matches = [match for match in matches if self._is_contained(match.nodes_map)]
after = len(matches)
if before != after:
logger.info("Filtered out %s matches because they are not fully contained", before - after)
# filter out the matches that form a cycle if the subgraph is fused
valid_matches = []
for match in matches:
matched_compute_nodes = \
[gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
if validate_partition(matched_compute_nodes):
valid_matches.append(match)
if len(valid_matches) != len(matches):
logger.info("Filtered out %s matches because \
matched subgraph would form a cycle if fused", len(matches) - len(valid_matches))
if self.remove_overlapping_matches:
before = len(valid_matches)
matches = self._remove_overlapping_matches(valid_matches)
after = len(matches)
if before != after:
logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)
logger.info("Matches returned: %s", matches)
return matches
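# Illustrative usage sketch, not part of the upstream matcher API: a minimal,
# hedged example of how the backtracking matcher documented in
# `SubgraphMatcher.match` is typically driven. The pattern/target callables and
# the helper name `_example_subgraph_matcher_usage` are assumptions made purely
# for illustration.
def _example_subgraph_matcher_usage():
    import torch.nn.functional as F
    from torch.fx import symbolic_trace

    def pattern(x, w):
        # the subgraph we want to find: linear followed by relu
        return F.relu(F.linear(x, w))

    def target(x, w):
        # contains the pattern once; the extra `* 2` only consumes the
        # pattern's returning node, which is allowed
        return F.relu(F.linear(x, w)) * 2

    matcher = SubgraphMatcher(symbolic_trace(pattern).graph)
    matches = matcher.match(symbolic_trace(target).graph)
    for m in matches:
        # m.nodes_map maps pattern nodes to the matched target nodes
        print(m.placeholder_nodes, m.returning_nodes)
    return matches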

View File

@ -0,0 +1,114 @@
from typing import Dict, List, Tuple
from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility
from .matcher_utils import InternalMatch, SubgraphMatcher
__all__ = ["SubgraphMatcherWithNameNodeMap"]
def _split_to_graph_and_name_node_map(
gm: GraphModule,
) -> Tuple[GraphModule, Dict[str, Node]]:
from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import tree_flatten, tree_unflatten
name_node_map = {}
for n in gm.graph.nodes:
if n.op == "output":
assert gm._out_spec is not None
output = tree_unflatten(n.args[0], gm._out_spec)
assert isinstance(
output, tuple
), "Expecting the pattern graph to return a tuple"
assert (
len(output) >= 2
), "Expecting the pattern graph to have at least two outputs"
*out, name_node_map = output
flattened, out_spec = tree_flatten(out)
assert isinstance(
name_node_map, Dict
), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
)
gm.recompile()
return gm, name_node_map
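# Illustrative note (a hedged sketch, not upstream documentation): given a
# pattern GraphModule whose Python-level return is
#     return relu, {"conv": conv, "relu": relu}
# the helper above rewrites the graph's output node so the module only returns
# `relu`, and separately hands back {"conv": <conv node>, "relu": <relu node>}
# so the matcher can later translate those names into nodes of the target graph.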
@compatibility(is_backward_compatible=False)
class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
"""Extends SubgraphMatcher to support querying the matched subgraph nodes through node name,
this requires pattern to have specific format (returning and additional dictionary at the output,
that has node name as key, and the node in the pattern graph as value, see Example for more details)
Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during
initialization since we need to modify the graph (which requires `recompile` the GraphModule)
Example::
def pattern(x, weight):
conv = F.conv2d(x, weight)
relu = F.relu(conv)
return relu, {"conv": conv, "relu": relu}
def target_graph(x, weight):
conv = F.conv2d(x, weight)
relu = F.relu(conv)
relu *= 2
return relu
pattern_gm = capture_pre_autograd_graph(pattern, example_inputs)
target_gm = capture_pre_autograd_graph(target_graph, example_inputs)
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
matches = matcher.match(target_gm.graph)
for match in matches:
match.name_node_map["conv"].meta["annotation"] = ...
"""
def __init__(
self,
pattern_gm: GraphModule,
match_output: bool = False,
match_placeholder: bool = False,
remove_overlapping_matches: bool = True,
ignore_literals: bool = False,
) -> None:
pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
self.name_node_map = name_node_map
super().__init__(
pattern_gm.graph,
match_output,
match_placeholder,
remove_overlapping_matches,
ignore_literals,
)
def match(self, graph: Graph) -> List[InternalMatch]:
"""The returned InternalMatch will have name_node_map populated with a map
from node name (str) to the target node, e.g.
{"conv": target_conv_ndoe, "relu": target_relu_node}
this requires the pattern graph returns an additional
output of node name to node, e.g. instead of:
```
def pattern(...):
...
return relu
```
we should do:
```
def pattern(...):
...
return relu, {"conv": conv, "relu": relu}
``` instead
"""
internal_matches = super().match(graph)
for internal_match in internal_matches:
for k, n in self.name_node_map.items():
internal_match.name_node_map[k] = internal_match.nodes_map[n]
return internal_matches

View File

@ -0,0 +1,154 @@
from dataclasses import dataclass, field
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Any, Type, Optional, Callable
import logging
import os
__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger() -> logging.Logger:
logger = logging.getLogger(__name__)
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger
logger = _init_logger()
@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition:
# Nodes in a particular partition
nodes: List[Node]
# The source these nodes decomposed from
source: Any
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the
# partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[Node] = field(default_factory=list)
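# Illustrative sketch (hypothetical node names, for documentation only): a
# partition recovered for a single decomposed `torch.nn.Linear` might look
# roughly like
#     SourcePartition(
#         nodes=[weight_node, bias_node, linear_compute_node],
#         source=torch.nn.Linear,
#         input_nodes=[activation_node],
#         output_nodes=[linear_compute_node],
#         params=[weight_node, bias_node],
#     )
# where the exact node set depends on how the graph was captured and which
# decompositions were applied.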
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def get_source_partitions(
graph: Graph,
wanted_sources: List[Any],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]:
"""
Args:
graph: The graph we want to partition
wanted_sources: List of sources that the nodes in the graph may have been
decomposed from. A source can be a function (ex. torch.nn.functional.linear)
or a leaf module type (ex. torch.nn.Linear).
filter_fn: Optional predicate applied to every node of a candidate partition;
partitions containing any node that fails the predicate are dropped.
Returns:
Dictionary mapping sources that were given to a list of SourcePartitions
that correspond to the list of nodes that were decomposed from the given
source.
An illustrative usage sketch follows this function definition.
"""
modules: Dict[Type, Dict[str, List[Node]]] = {}
for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
# source, and the source function if the node is decomposed from a
# function, or the type of module if the node is decomposed from a leaf
# module
# TODO: Bypass "torch_fn" when "source_fn_stack" is present, because currently
# "torch_fn" can differ from "source_fn_stack", for example for the add_ node
# decomposed from batch norm. We should remove the check on "source_fn_stack"
# after we fix "torch_fn". T199561090
if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
(torch_fn := node.meta.get("torch_fn", None)) is not None):
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:
diff_modules = modules.setdefault(source_fn_name, {})
partition = diff_modules.setdefault(node_fqn, [])
partition.append(node)
if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
source_fn = source_fn_st[-1]
if source_fn[1] in wanted_sources:
diff_modules = modules.setdefault(source_fn[1], {})
partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node)
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
for node in nodes:
for arg in node.args:
if isinstance(arg, Node) and arg not in nodes:
input_nodes.add(arg)
if node.op == "get_attr":
params.add(node)
for user in node.users.keys():
if user not in nodes:
output_nodes.add(node)
return SourcePartition(
nodes,
module_type,
list(input_nodes),
list(output_nodes),
list(params), # type: ignore[arg-type]
)
ret: Dict[Type[Any], List[SourcePartition]] = {}
if filter_fn:
# apply filter_fn to every node of each partition and drop the partitions
# that do not satisfy the filter condition
filtered_modules = {}
for tp, name_to_partition in modules.items():
filtered_name_to_partition = {
name: partition
for name, partition in name_to_partition.items()
if all(map(filter_fn, partition))
}
filtered_modules[tp] = filtered_name_to_partition
modules = filtered_modules
for k, v in modules.items():
ret[k] = [make_partition(partition, k) for partition in v.values()]
return ret
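# Illustrative usage sketch, not part of the upstream module: a hedged example
# of calling `get_source_partitions` on a graph captured with `torch.export`,
# which (under the stated assumption) populates the "source_fn_stack" /
# "torch_fn" metadata this helper relies on; plain `symbolic_trace` generally
# does not record that metadata. The model and helper name below are
# assumptions made purely for illustration.
def _example_get_source_partitions():
    import torch

    class LinearRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    exported = torch.export.export(LinearRelu(), (torch.randn(2, 4),))
    partitions = get_source_partitions(
        exported.graph_module.graph, [torch.nn.Linear, torch.nn.ReLU]
    )
    # e.g. {torch.nn.Linear: [SourcePartition(...)], torch.nn.ReLU: [SourcePartition(...)]}
    return partitions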
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
"""
Given two subgraphs A and B (each in the form of a list of nodes), checks whether
A has nodes connecting to at least one node in B, i.e. there exists a node
in B that uses a node in A (not the other way around).
A short usage sketch follows this function.
"""
for node in reversed(subgraph1.nodes):
for user in node.users.keys():
if user in subgraph2.nodes:
return True
return False
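# Illustrative follow-up sketch (same assumptions as above): given the
# partitions produced by the hypothetical `_example_get_source_partitions`,
# connectivity between the linear and relu partitions can be checked like so.
def _example_check_subgraphs_connected():
    import torch

    partitions = _example_get_source_partitions()
    linear_partition = partitions[torch.nn.Linear][0]
    relu_partition = partitions[torch.nn.ReLU][0]
    # Expected to be True for the example model above, since the linear's
    # output feeds the relu.
    return check_subgraphs_connected(linear_partition, relu_partition)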