# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
|
|
import onnx
|
|
|
|
from ..onnx_model import ONNXModel
|
|
|
|
|
|
class Fusion:
    """
    Base class for fusions.

    Subclasses implement :meth:`fuse`. :meth:`apply` calls it once for every graph node whose
    op_type equals ``search_op_type``. The nodes that ``fuse`` appends to ``self.nodes_to_remove``
    and ``self.nodes_to_add`` are then removed from / added to the model in a single batch.
    """

    def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
        """
        Args:
            model: model whose graph this fusion inspects and rewrites.
            fused_op_type: name of the fused op; used as the first part of generated node names.
            search_op_type: op_type of the nodes that :meth:`apply` hands to :meth:`fuse`.
        """
        self.search_op_type: str = search_op_type
        self.fused_op_type: str = fused_op_type
        self.model: ONNXModel = model
        # Pending graph edits: filled by fuse(), committed by apply().
        self.nodes_to_remove: list = []
        self.nodes_to_add: list = []

        # Generated names have the form "<fused_op_type>_fused_<search_op_type>_<N>".
        self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
        self._new_node_name_suffix = None  # int|None used to create unique node names for the fused ops.

    def fuse(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function for derived fusion classes. Tries to fuse a node sequence containing
        the specified node.

        Args:
            node: a node whose op_type equals ``self.search_op_type``.
            input_name_to_nodes: tensor name -> nodes that take it as an input.
            output_name_to_node: tensor name -> node that produces it.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def apply(self) -> bool:
        """
        Apply graph fusion on the entire model graph.

        The name-to-node maps are built once, up front, and passed unchanged to every fuse() call.
        The collected removals/additions are applied only after the whole scan. If anything
        changed, ``model.remove_unused_constant()`` is called.

        Returns:
            True if any node was scheduled for removal or addition, False otherwise.
        """
        input_name_to_nodes = self.model.input_name_to_nodes()
        output_name_to_node = self.model.output_name_to_node()

        for node in self.model.nodes():
            if node.op_type == self.search_op_type:
                self.fuse(node, input_name_to_nodes, output_name_to_node)

        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add)

        graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)

        if graph_updated:
            self.model.remove_unused_constant()

        # NOTE(review): the pending lists are not cleared here. A second apply() on the same
        # instance therefore re-submits the first run's edits; confirm instances are single-use.
        return graph_updated

    def create_unique_node_name(self) -> str:
        """
        Return a new node name of the form ``<prefix><N>``.

        On the first call, N starts one past the largest suffix the model already uses for this
        prefix (as reported by ``model.get_largest_node_name_suffix``). Each later call
        increments N by one.
        """
        prefix = self._new_node_name_prefix

        if self._new_node_name_suffix is None:
            largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
            self._new_node_name_suffix = largest_suffix + 1

        new_name = f"{prefix}{self._new_node_name_suffix!s}"
        self._new_node_name_suffix += 1

        return new_name

    @staticmethod
    def is_safe_to_fuse_nodes(
        nodes_to_remove: list[onnx.NodeProto],
        keep_outputs: list[str],
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        Check that removing ``nodes_to_remove`` does not orphan any consumer.

        An output named in ``keep_outputs`` is ignored. Every other output of a node being
        removed must be consumed only by nodes that are also being removed.

        Args:
            nodes_to_remove: candidate nodes to delete.
            keep_outputs: output names that will survive the fusion (e.g. re-produced by the
                fused node), so their consumers do not matter.
            input_name_to_nodes: tensor name -> nodes that take it as an input.
            output_name_to_node: unused; kept for interface compatibility.

        Returns:
            False if some removed output feeds a node outside ``nodes_to_remove``, else True.
        """
        for node_to_remove in nodes_to_remove:
            for output_to_remove in node_to_remove.output:
                if output_to_remove in keep_outputs:
                    continue

                if output_to_remove in input_name_to_nodes:
                    for impacted_node in input_name_to_nodes[output_to_remove]:
                        if impacted_node not in nodes_to_remove:
                            # Not safe to remove nodes since output is used by impacted_node
                            return False
        return True

    @staticmethod
    def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
        """Return the decoded value of ``node``'s attribute ``attribute_name``, or None if absent."""
        for attr in node.attribute:
            if attr.name == attribute_name:
                value = onnx.helper.get_attribute_value(attr)
                return value
        return None

    @staticmethod
    def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
        """Return the first position of ``node_output`` in ``child_node.input``, or -1 if absent."""
        for index, input_name in enumerate(child_node.input):
            if input_name == node_output:
                return index
        return -1

    @staticmethod
    def tensor_shape_to_list(tensor_type) -> list[int | str]:
        """
        Convert a tensor type's shape to a Python list.

        ``tensor_type`` must expose ``shape.dim`` (presumably ``onnx.TypeProto.Tensor`` -- TODO
        confirm at call sites).

        Each dimension maps as follows:
        - a known size becomes its int ``dim_value``;
        - a symbolic dimension becomes its ``dim_param`` string;
        - a dimension with neither field becomes ``"?"``.
        """
        shape_list = []
        for d in tensor_type.shape.dim:
            if d.HasField("dim_value"):
                shape_list.append(d.dim_value)  # known dimension
            elif d.HasField("dim_param"):
                shape_list.append(d.dim_param)  # unknown dimension with symbolic name
            else:
                shape_list.append("?")  # shall not happen
        return shape_list

    def get_constant_input(self, node: onnx.NodeProto):
        """
        Return ``(index, value)`` for the first input of ``node`` that
        ``model.get_constant_value`` resolves to a non-None value, or ``(None, None)``.
        """
        for i, inp in enumerate(node.input):
            value = self.model.get_constant_value(inp)
            if value is not None:
                return i, value

        return None, None

    def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
        """
        Return the input index of ``node``'s constant if it is a single value within ``delta``
        of ``expected_value``, else -1.

        Only the first constant input (see :meth:`get_constant_input`) is examined. The value is
        presumably a numpy array, since ``.size`` and element-wise ``abs`` are used.
        """
        i, value = self.get_constant_input(node)
        if value is not None and value.size == 1 and abs(value - expected_value) < delta:
            return i

        return -1

    def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
        """Return True if :meth:`find_constant_input` finds a matching constant input."""
        return self.find_constant_input(node, expected_value, delta) >= 0

    def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
        """Return True if ``output_name`` resolves to a constant whose shape has exactly ``rank`` dims."""
        value = self.model.get_constant_value(output_name)
        if value is None:
            return False  # Not an initializer

        return len(value.shape) == rank  # False means wrong dimensions

    def match_first_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] | None = None,
    ) -> tuple[onnx.NodeProto | None, int | None]:
        """
        Find parent node based on constraints on op_type.

        Args:
            node: current node.
            parent_op_type (str): constraint of parent node op_type.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
                Built from the model when None.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
                None means no exclusions.

        Returns:
            parent: The matched parent node. None if not found.
            index: The input index of matched parent node. None if not found.
        """
        if exclude is None:
            exclude = []
        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        for i, inp in enumerate(node.input):
            if inp in output_name_to_node:
                parent = output_name_to_node[inp]
                if parent.op_type == parent_op_type and parent not in exclude:
                    return parent, i

        return None, None

    def match_parent(
        self,
        node: onnx.NodeProto,
        parent_op_type: str,
        input_index: int | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        exclude: list[onnx.NodeProto] | None = None,
        return_indice: list[int] | None = None,
    ) -> onnx.NodeProto | None:
        """
        Find parent node based on constraints on op_type and index.

        When input_index is None, the first parent node that satisfies the constraints is
        returned, and its input index is appended to return_indice.

        Args:
            node (onnx.NodeProto): current node.
            parent_op_type (str): constraint of parent node op_type.
            input_index (int or None): only check the parent given input index of current node.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
                Built from the model when None.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
                None means no exclusions.
            return_indice (list): a list to append the input index when input_index is None.
                When nothing matches, None is appended.

        Returns:
            parent: The matched parent node, or None.
        """
        assert node is not None
        assert input_index is None or input_index >= 0

        if exclude is None:
            exclude = []
        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        if input_index is None:
            parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
            if return_indice is not None:
                return_indice.append(index)
            return parent

        if input_index >= len(node.input):
            # Input index out of bounds.
            return None

        parent = self.model.get_parent(node, input_index, output_name_to_node)
        if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
            return parent

        return None

    def match_parent_path(
        self,
        node: onnx.NodeProto,
        parent_op_types: list[str],
        parent_input_index: list[int] | None = None,
        output_name_to_node: dict[str, onnx.NodeProto] | None = None,
        return_indice: list[int] | None = None,
    ) -> list[onnx.NodeProto] | None:
        """
        Find a sequence of input edges based on constraints on parent op_type and index.

        Starting at ``node``, step i moves to a parent whose op_type is ``parent_op_types[i]``.
        When the index constraint for an edge is None, the first matching parent is used, and its
        input index is appended to return_indice.

        Args:
            node (onnx.NodeProto): current node.
            parent_op_types (list[str]): constraint of parent node op_type of each input edge.
            parent_input_index (list): constraint of input index of each input edge. None means no constraint.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            return_indice (list): a list to append the input index
                When there is no constraint on input index of an edge.

        Returns:
            parents: a list of matched parent nodes (nearest first), or None if any step fails.
                On failure, return_indice may already hold entries from earlier steps.
        """
        if parent_input_index is not None:
            assert len(parent_input_index) == len(parent_op_types)

        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()

        current_node = node
        matched_parents = []
        for i, op_type in enumerate(parent_op_types):
            matched_parent = self.match_parent(
                current_node,
                op_type,
                parent_input_index[i] if parent_input_index is not None else None,
                output_name_to_node,
                exclude=[],
                return_indice=return_indice,
            )
            if matched_parent is None:
                return None

            matched_parents.append(matched_parent)
            current_node = matched_parent

        return matched_parents

    def match_parent_paths(
        self,
        node: onnx.NodeProto,
        paths: list[tuple[list[str], list[int]]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
        """
        Find a matching parent path to the given node.

        Each path is an ``(op_types, input_indices)`` pair passed to :meth:`match_parent_path`.
        The paths are tried in order.

        Returns:
            ``(path_index, matched_parents, return_indice)`` for the first path that matches, or
            ``(-1, None, None)`` if none does.
        """
        for i, path in enumerate(paths):
            return_indice = []
            matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
            if matched:
                return i, matched, return_indice
        return -1, None, None

    def find_first_child_by_type(
        self,
        node: onnx.NodeProto,
        child_type: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
        recursive: bool = True,
    ) -> onnx.NodeProto | None:
        """
        Search the consumers of ``node`` for the first node whose op_type is ``child_type``.

        Args:
            node: node whose children (and, if ``recursive``, further descendants) are searched.
            child_type: op_type to look for.
            input_name_to_nodes: optional tensor-name -> consumers map, forwarded to
                ``model.get_children``.
            recursive: when False, only the direct children of ``node`` are checked.

        Returns:
            The first matching node in level-by-level search order, or None.
        """
        children = self.model.get_children(node, input_name_to_nodes)
        dq = deque(children)
        # A node reachable by several paths (e.g. both branches of a diamond/residual pattern)
        # would otherwise be expanded once per path, which grows exponentially with the number
        # of such patterns. A node seen before has already failed the match and had its children
        # queued earlier in this FIFO, so skipping it does not change which node is returned.
        # Protobuf messages are unhashable, hence tracking by id().
        visited: set[int] = set()
        while dq:
            current_node = dq.pop()
            if id(current_node) in visited:
                continue
            visited.add(id(current_node))

            if current_node.op_type == child_type:
                return current_node

            if recursive:
                for child in self.model.get_children(current_node, input_name_to_nodes):
                    dq.appendleft(child)

        return None