I am done

commit 40e2a747cf (parent 720dc28c09)
2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions

torch/fx/passes/infra/__init__.py
@@ -0,0 +1,2 @@
from . import pass_manager

torch/fx/passes/infra/partitioner.py
@@ -0,0 +1,335 @@
# mypy: allow-untyped-defs
import collections
import itertools
import logging
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set

from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class Partition:
    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        self.nodes = dict.fromkeys(nodes) if nodes is not None else {}

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node):
        self.nodes.update({node: None})

    def remove_node(self, node: Node):
        del self.nodes[node]

    def size(self):
        return len(self.nodes)


class _DependencyViewer:
    def __init__(self, graph_module: GraphModule):
        self.upstreams = collections.defaultdict(set)
        self.downstreams = collections.defaultdict(set)

        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                # add input_node and input_node's upstream dependencies
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # add output_node and output_node's downstream dependencies
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

    def downstreams_of(self, node: Node) -> Set[Node]:
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> Set[Node]:
        return self.upstreams[node]
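
# Illustration (hypothetical node names, not part of the original source): for a
# linear graph x -> a -> b -> out, the viewer records transitive dependencies in
# both directions:
#
#   upstreams_of(b)   == {a, x}     (everything b transitively consumes)
#   downstreams_of(a) == {b, out}   (everything that transitively consumes a)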


class CapabilityBasedPartitioner:
    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False,
                 non_compute_ops: Optional[Sequence[str]] = None,
                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
                 ) -> None:
        self.graph_module = graph_module
        self.operator_support = operator_support
        self.allows_single_node_partition = allows_single_node_partition
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        self.dependency_viewer = _DependencyViewer(graph_module)

    def __is_node_supported(self, node: Node) -> bool:
        return self.operator_support.is_node_supported(
            dict(self.graph_module.named_modules()), node
        )

    def propose_partitions(self) -> List[Partition]:
        # partition_map is a mapping from partition id to a set of partition ids.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map: Dict[int, Set] = collections.defaultdict(set)

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        # try to merge partition other_id into partition self_id
        # merge only happens if the resulting graph doesn't contain a cyclic dependency;
        # returns `True` when the merge happens, `False` otherwise.
        def maybe_merge_partition(self_id: int, other_id: int):
            # merged_nodes is the union of the nodes in the two partitions to be merged
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)

            def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # If any node on the DFS path from this node is in the
                        # merged_nodes list, then there is a cycle in the graph.
                        if path_node in merged_nodes:
                            return True

                        # If any node on the DFS path is in the assignment map, we have
                        # to make sure that the partitions those nodes belong to do not
                        # form a cycle with the partitions currently being merged. This
                        # means iterating through all the nodes in all the partitions
                        # traversed on the DFS path and checking whether they are in the
                        # merged_nodes list.
                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # If the partition id has already been visited, we know it
                            # doesn't form a cycle with the partitions being merged.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # check if the merge would create a cyclic dependency.
            all_user_nodes = set()
            for node in merged_nodes:
                for user_node in node.users:
                    if user_node not in merged_nodes:
                        all_user_nodes.add(user_node)

            if dfs_iter_find_cycle(all_user_nodes):
                # return False, indicating that a cyclic dependency was found and
                # the merge is aborted
                return False

            # no cyclic dependency found, move forward with the merge
            # update partition nodes
            partitions_by_id[self_id].nodes = merged_nodes
            # update the assignment map
            for node in partitions_by_id[other_id].nodes:
                assignment[node] = self_id
            # delete the other partition
            del partitions_by_id[other_id]

            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
            del partition_map[other_id]

            return True

        def merge_single_node(node: Node, id: Optional[int]):
            def _update_partition_map(node: Node, id: int):
                # Iterate through all the downstream nodes of this node and update the
                # partition map to indicate that there is a path from this node's
                # partition id to the target partition id.
                downstream_nodes = self.dependency_viewer.downstreams_of(node)
                for curr_node in downstream_nodes:
                    target_id = assignment.get(curr_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)

                # Iterate through all the upstream nodes of this node and update the
                # partition map to indicate that there is a path from the upstream
                # node's partition id to this node's partition id.
                upstream_nodes = self.dependency_viewer.upstreams_of(node)
                for curr_node in upstream_nodes:
                    source_id = assignment.get(curr_node, None)
                    if source_id is not None:
                        partition_map[source_id].add(id)

            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)
                _update_partition_map(node, id)
logger.debug("Proposing partitions...")
for node in reversed(self.graph_module.graph.nodes):
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
merge_candidates: Dict[int, None] = {}
# Note a limited horizontal fusion is enabled:
# when `node` is not supported, the code below attempts to fuse consumer of `node`.
#
# I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
# the fusion by adding an `else` block here to skip horizontal fusion.
if self.__is_node_supported(node) and node not in assignment:
partition_id = next(new_partition_id)
merge_single_node(node, partition_id)
merge_candidates[partition_id] = None
# merge all possible partitions
for node in assignment:
merge_candidates[assignment[node]] = None
merge_candidates_list = list(merge_candidates.keys())
if len(merge_candidates_list) > 1:
self_id = merge_candidates_list[0]
for other_id in merge_candidates_list[1:]:
# note: merge partition `other_id` into partition `self_id` if
# it doesn't create cyclic dependency in the graph, otherwise,
# this is a no-op
maybe_merge_partition(self_id, other_id)

        # post-processing to re-assign "getitem" nodes to their upstream (producer) partition
        logger.debug("Reassigning getitem nodes to their producer node's partition...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                        _get_qualified_name(user.target) != "_operator.getitem":  # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs; re-assign all following getitem nodes to node's partition
            if is_tuple_output:
                id = assignment.get(node, None)  # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != id:  # type: ignore[arg-type]
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)

        # filter out single-node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])

        return [partition for partition in partitions_by_id.values() if partition.size() > 0]

    def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[List[Node]]: [[node0, node1], [node2, node3]]
        return fuse_by_partitions(
            self.graph_module,
            [list(partition.nodes) for partition in partitions],
            prefix=prefix,
        )

    # remove non-compute ops that sit at the boundary of a partition.
    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node):
            return node.op == "call_function" and \
                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]

        # cache transparent nodes
        transparent_input_nodes: Dict[Node, bool] = {}
        transparent_output_nodes: Dict[Node, bool] = {}

        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(output_n, partition, removed_nodes):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False

        for partition in partitions:
            # Note: it's fine to use a `set` here, since we only query whether a node
            # has been removed; we NEVER iterate over the nodes inside the set.
            remove_node: Set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and \
                        (is_transparent_input_node(node, set(partition.nodes), remove_node) or
                         is_transparent_output_node(node, set(partition.nodes), remove_node)):
                    remove_node.add(node)

            if len(remove_node) != 0:
                for node in remove_node:
                    partition.nodes.pop(node, None)

    def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions, prefix=prefix)
        return fused_gm
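
Usage sketch: the operator-support policy and the module below are illustrative
assumptions, not part of the file above. The flow is: decide per-node support
via OperatorSupportBase, let propose_partitions() grow and merge cycle-free
partitions, then fuse each surviving partition into a submodule.

    import operator

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import OperatorSupportBase


    class AddMulSupport(OperatorSupportBase):
        # hypothetical policy: only operator.add and operator.mul are "supported"
        def is_node_supported(self, submodules, node):
            return node.op == "call_function" and node.target in (operator.add, operator.mul)


    class M(torch.nn.Module):
        def forward(self, x):
            return (x + x) * x  # add and mul are adjacent, so they can fuse


    gm = symbolic_trace(M())
    partitioner = CapabilityBasedPartitioner(gm, AddMulSupport())
    partitions = partitioner.propose_partitions()       # one Partition holding {add, mul}
    fused_gm = partitioner.fuse_partitions(partitions)  # add/mul moved into a fused submodule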

torch/fx/passes/infra/pass_base.py
@@ -0,0 +1,73 @@
# mypy: allow-untyped-defs
import abc
from collections import namedtuple
from typing import Optional

from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule

__all__ = ['PassResult', 'PassBase']


@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
    """
    Result of a pass:

        graph_module: The modified graph module
        modified: A flag indicating whether the pass modified the graph module
    """

    def __new__(cls, graph_module, modified):
        return super().__new__(cls, graph_module, modified)


@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
    """
    Base interface for implementing passes.

    It is required to implement the `call` function; an instance of a class
    implementing this interface can then be passed directly into the
    PassManager's `passes` attribute and called as a function.
    """

    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        Runs the precondition check, the pass itself, and the postcondition check.
        """
        self.requires(graph_module)
        res = self.call(graph_module)
        self.ensures(graph_module)
        return res

    @abc.abstractmethod
    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        """
        The pass that is run on the given graph module. To implement a
        pass, it is required to implement this function.

        Args:
            graph_module: The graph module we will run a pass on
        """

    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called before the pass is run and will check that
        the given graph module contains the preconditions needed to run the
        pass. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on
        """

    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
        """
        This function will be called after the pass is run and will check that
        the given graph module satisfies the postconditions expected after
        running the pass. It is not required to implement this function.

        Args:
            graph_module: The graph module we will run checks on
        """

torch/fx/passes/infra/pass_manager.py
@@ -0,0 +1,302 @@
# mypy: allow-untyped-defs
import inspect
import logging
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List

import torch.nn as nn
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']


@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.
    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    Args:
        fn (Callable[Module, Any])

    Returns:
        wrapped_fn (Callable[Module, PassResult])
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        elif isinstance(res, nn.Module):
            return PassResult(res, True)

    if not inspect.isfunction(fn):
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn


def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
    for i, a in enumerate(passes):
        for j, b in enumerate(passes[i + 1:]):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {i + j + 1} in"
                f" the pass list."
            )


def _topological_sort_passes(
    passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
    """
    Args:
        passes: Passes that we are ordering
        constraints: Constraints applied on these passes

    Returns:
        A topologically sorted list of callables. Raises a RuntimeError if a
        circular dependency exists between the constraints.
    """
    if len(constraints) == 0:
        return passes

    # Construct a graph mapping nodes to a list of their users
    graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
    candidates: Queue = Queue()
    for a in passes:
        for b in passes:
            if a == b:
                continue

            for constraint in constraints:
                if not constraint(a, b):
                    graph[b].append(a)
                    indegree_map[a] += 1

        if indegree_map[a] == 0:
            candidates.put(a)

    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
    sorted_passes: List[Callable] = []

    while not candidates.empty():
        p = candidates.get()
        sorted_passes.append(p)
        visited[p] = True

        for n in graph[p]:
            if not visited[n]:
                indegree_map[n] -= 1
                if indegree_map[n] == 0:
                    candidates.put(n)

    # Check if there are unvisited nodes (i.e. cycles in the graph)
    cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
    if len(cycle_passes) != 0:
        error = f"Circular dependency detected within the following passes: {cycle_passes}"
        raise RuntimeError(error)

    return sorted_passes


@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Defines a partial order ("depends on" function) where `this` must occur
    before `that`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [pass_b, pass_a]

    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Callable, Callable], bool])
    """

    def depends_on(a: Callable, b: Callable):
        return a != that or b != this

    return depends_on


@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            True if A may run before B and False otherwise. See the
            implementation of `this_before_that_pass_constraint` for an
            example.
        steps (int): Max number of times we run the passes (default = 1).
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): Whether to suppress failures raised by
            the checks
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        self.passes = passes or []
        self.constraints = constraints or []
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that the current pass schedule defined by `self.passes` is
        valid according to all constraints in `self.constraints`
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        If a circular dependency exists between the constraints, a RuntimeError
        will be raised.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.
        This function is run before and after each pass if the
        `run_checks_after_each_pass` flag is enabled.
        """
        sig = inspect.signature(check)

        if len(list(sig.parameters.values())) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        pass

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs the list of passes, in the order given by `self.passes`, on the
        given graph module. Each time a pass is run, checks and linting will be
        run on the graph module if `run_checks_after_each_pass` is set.

        If the module is a graph module, we will run the list of passes until
        the graph stops changing, or at most `steps` times.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult. "
                            + "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    prev_pass_names = [
                        p.__name__ if inspect.isfunction(p) else type(p).__name__
                        for p in self.passes[:i]
                    ]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e  # noqa: TRY002

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)
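
Usage sketch: scheduling two legacy passes (functions that mutate the module in
place and return nothing) with a constraint. The pass functions are
hypothetical; only PassManager, pass_result_wrapper and
this_before_that_pass_constraint come from the module above.

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.infra.pass_manager import (
        PassManager,
        pass_result_wrapper,
        this_before_that_pass_constraint,
    )


    def replace_add_with_mul(gm):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is torch.add:
                node.target = torch.mul
        gm.recompile()


    def replace_mul_with_div(gm):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is torch.mul:
                node.target = torch.div
        gm.recompile()


    # wrap first, then constrain the wrapped callables so the identity checks
    # inside the constraint see the objects actually stored in pm.passes
    p1 = pass_result_wrapper(replace_add_with_mul)
    p2 = pass_result_wrapper(replace_mul_with_div)

    pm = PassManager(passes=[p2, p1])  # deliberately out of order
    pm.add_constraint(this_before_that_pass_constraint(p1, p2))


    def f(x):
        return torch.add(x, x)


    result = pm(symbolic_trace(f))  # solve_constraints() reorders to [p1, p2]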