I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,3 @@
from .fusion import Fusion # noqa: F401
from .fusion_gelu import FusionGelu # noqa: F401
from .fusion_layernorm import FusionLayerNormalization # noqa: F401

@@ -0,0 +1,311 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
from collections import deque
import onnx
from ..onnx_model import ONNXModel
class Fusion:
"""
Base class for fusions.
"""
def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
self.search_op_type: str = search_op_type
self.fused_op_type: str = fused_op_type
self.model: ONNXModel = model
self.nodes_to_remove: list = []
self.nodes_to_add: list = []
self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.
def fuse(
self,
node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function for derived fusion classes. Tries to fuse a node sequence containing
the specified node.
"""
raise NotImplementedError
def apply(self) -> bool:
"""
Apply graph fusion on the entire model graph.
"""
input_name_to_nodes = self.model.input_name_to_nodes()
output_name_to_node = self.model.output_name_to_node()
for node in self.model.nodes():
if node.op_type == self.search_op_type:
self.fuse(node, input_name_to_nodes, output_name_to_node)
self.model.remove_nodes(self.nodes_to_remove)
self.model.add_nodes(self.nodes_to_add)
graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
if graph_updated:
self.model.remove_unused_constant()
return graph_updated
def create_unique_node_name(self):
prefix = self._new_node_name_prefix
if self._new_node_name_suffix is None:
largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
self._new_node_name_suffix = largest_suffix + 1
new_name = f"{prefix}{self._new_node_name_suffix!s}"
self._new_node_name_suffix += 1
return new_name
@staticmethod
def is_safe_to_fuse_nodes(
nodes_to_remove: list[onnx.NodeProto],
keep_outputs: list[str],
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
for node_to_remove in nodes_to_remove:
for output_to_remove in node_to_remove.output:
if output_to_remove in keep_outputs:
continue
if output_to_remove in input_name_to_nodes:
for impacted_node in input_name_to_nodes[output_to_remove]:
if impacted_node not in nodes_to_remove:
# Not safe to remove nodes since output is used by impacted_node
return False
return True
@staticmethod
def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
for attr in node.attribute:
if attr.name == attribute_name:
value = onnx.helper.get_attribute_value(attr)
return value
return None
@staticmethod
def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
for index, input_name in enumerate(child_node.input):
if input_name == node_output:
return index
return -1
@staticmethod
def tensor_shape_to_list(tensor_type) -> list[int | str]:
shape_list = []
for d in tensor_type.shape.dim:
if d.HasField("dim_value"):
shape_list.append(d.dim_value) # known dimension
elif d.HasField("dim_param"):
shape_list.append(d.dim_param) # unknown dimension with symbolic name
else:
shape_list.append("?") # shall not happen
return shape_list
def get_constant_input(self, node: onnx.NodeProto):
for i, inp in enumerate(node.input):
value = self.model.get_constant_value(inp)
if value is not None:
return i, value
return None, None
def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
i, value = self.get_constant_input(node)
if value is not None and value.size == 1 and abs(value - expected_value) < delta:
return i
return -1
def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
return self.find_constant_input(node, expected_value, delta) >= 0
def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
value = self.model.get_constant_value(output_name)
if value is None:
return False # Not an initializer
if len(value.shape) != rank:
return False # Wrong dimensions
return True
def match_first_parent(
self,
node: onnx.NodeProto,
parent_op_type: str,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
exclude: list[onnx.NodeProto] = [], # noqa: B006
) -> tuple[onnx.NodeProto | None, int | None]:
"""
Find parent node based on constraints on op_type.
Args:
node: current node.
parent_op_type (str): constraint of parent node op_type.
output_name_to_node (dict): dictionary with output name as key, and node as value.
exclude (list): list of nodes that are excluded (not allowed to match as parent).
Returns:
parent: The matched parent node. None if not found.
index: The input index of matched parent node. None if not found.
"""
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
for i, inp in enumerate(node.input):
if inp in output_name_to_node:
parent = output_name_to_node[inp]
if parent.op_type == parent_op_type and parent not in exclude:
return parent, i
return None, None
def match_parent(
self,
node: onnx.NodeProto,
parent_op_type: str,
input_index: int | None = None,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
exclude: list[onnx.NodeProto] = [], # noqa: B006
return_indice: list[int] | None = None,
) -> onnx.NodeProto | None:
"""
Find parent node based on constraints on op_type and index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_type (str): constraint of parent node op_type.
input_index (int or None): only check the parent given input index of current node.
output_name_to_node (dict): dictionary with output name as key, and node as value.
exclude (list): list of nodes that are excluded (not allowed to match as parent).
return_indice (list): a list to append the input index when input_index is None.
Returns:
parent: The matched parent node.
"""
assert node is not None
assert input_index is None or input_index >= 0
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
if input_index is None:
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
if return_indice is not None:
return_indice.append(index)
return parent
if input_index >= len(node.input):
# Input index out of bounds.
return None
parent = self.model.get_parent(node, input_index, output_name_to_node)
if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
return parent
return None
def match_parent_path(
self,
node: onnx.NodeProto,
parent_op_types: list[str],
parent_input_index: list[int] | None = None,
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
return_indice: list[int] | None = None,
) -> list[onnx.NodeProto] | None:
"""
Find a sequence of input edges based on constraints on parent op_type and index.
When input_index is None, we will find the first parent node based on constraints,
and return_indice will be appended the corresponding input index.
Args:
node (str): current node name.
parent_op_types (str): constraint of parent node op_type of each input edge.
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
output_name_to_node (dict): dictionary with output name as key, and node as value.
return_indice (list): a list to append the input index
When there is no constraint on input index of an edge.
Returns:
parents: a list of matched parent node.
"""
if parent_input_index is not None:
assert len(parent_input_index) == len(parent_op_types)
if output_name_to_node is None:
output_name_to_node = self.model.output_name_to_node()
current_node = node
matched_parents = []
for i, op_type in enumerate(parent_op_types):
matched_parent = self.match_parent(
current_node,
op_type,
parent_input_index[i] if parent_input_index is not None else None,
output_name_to_node,
exclude=[],
return_indice=return_indice,
)
if matched_parent is None:
return None
matched_parents.append(matched_parent)
current_node = matched_parent
return matched_parents
def match_parent_paths(
self,
node: onnx.NodeProto,
paths: list[tuple[list[str], list[int]]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
"""
Find a matching parent path to the given node.
"""
for i, path in enumerate(paths):
return_indice = []
matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
if matched:
return i, matched, return_indice
return -1, None, None
def find_first_child_by_type(
self,
node: onnx.NodeProto,
child_type: str,
input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
recursive: bool = True,
) -> onnx.NodeProto | None:
children = self.model.get_children(node, input_name_to_nodes)
dq = deque(children)
while len(dq) > 0:
current_node = dq.pop()
if current_node.op_type == child_type:
return current_node
if recursive:
children = self.model.get_children(current_node, input_name_to_nodes)
for child in children:
dq.appendleft(child)
return None
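A minimal driver sketch for fusions built on this base class (illustrative only, not part of the diff): it assumes the ONNXModel wrapper imported above accepts a loaded onnx.ModelProto and exposes it as a .model attribute; the package path and file paths are placeholders.

import onnx

# Placeholder package path; in this commit the classes live next to this file and are
# re-exported by the package __init__ shown at the top of the diff.
from my_quant_tools.fusions import FusionGelu, FusionLayerNormalization  # hypothetical import path
from my_quant_tools.onnx_model import ONNXModel  # hypothetical wrapper over onnx.ModelProto

model_proto = onnx.load("model.onnx")  # placeholder input path
onnx_model = ONNXModel(model_proto)    # wrap the proto so fusions can query the graph

changed = False
for fusion_cls in (FusionGelu, FusionLayerNormalization):
    # apply() visits every node of search_op_type, calls fuse(), and rewrites the graph
    changed = fusion_cls(onnx_model).apply() or changed

if changed:
    onnx.save(onnx_model.model, "model_fused.onnx")  # .model attribute assumed on the wrapper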

@@ -0,0 +1,272 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import onnx
from ..onnx_model import ONNXModel
from .fusion import Fusion
class FusionGelu(Fusion):
def __init__(self, model: ONNXModel):
super().__init__(model, "Gelu", "Erf")
def fuse(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function that tries to fuse a node sequence containing an Erf node into a single
Gelu node.
"""
if (
self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
):
self.model.set_opset_import("com.microsoft", 1)
def fuse_1(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from PyTorch model
Fuse Gelu with Erf into one node:
Pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (1)
Pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_after_erf = children[0]
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
if div is None:
return False
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
return False
subgraph_input = div.input[0]
another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0  # index of mul_after_erf's input that is not the Add output
if subgraph_input == mul_after_erf.input[another]: # pattern 2
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
return False
subgraph_output = mul_half.output[0]
else: # pattern 1
mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
if mul_half is None:
return False
if not self.has_constant_input(mul_half, 0.5):
return False
if subgraph_input not in mul_half.input:
return False
subgraph_output = mul_after_erf.output[0]
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
def fuse_2(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from Keras model
Fuse Gelu with Erf into one node:
+------------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul
(B=1.4142...) (A=1) (A=0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_after_erf = children[0]
if not self.has_constant_input(mul_after_erf, 0.5):
return False
if mul_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul = children[0]
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
if div is None:
return False
sqrt_node = None
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
if sqrt_node is None:
return False
if not self.has_constant_input(sqrt_node, 2.0):
return False
subgraph_input = div.input[0]
if subgraph_input not in mul.input:
return False
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
if sqrt_node:
subgraph_nodes.append(sqrt_node)
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
def fuse_3(
self,
erf_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
) -> bool:
"""
This pattern is from TensorFlow model
Fuse Gelu with Erf into one node:
+----------------------------------------------+
| |
| v
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
(A=0.7071067690849304) (B=1) (B=0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
if erf_node.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[erf_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return False
add_after_erf = children[0]
if not self.has_constant_input(add_after_erf, 1):
return False
if add_after_erf.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[add_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
return False
first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
if first_mul is None:
return False
i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
if i < 0:
return False
root_input_index = 1 - i  # the non-constant input of first_mul is the subgraph root
subgraph_input = first_mul.input[root_input_index]
if mul_half.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[mul_half.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return False
last_mul = children[0]
if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
return False
subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
if not self.is_safe_to_fuse_nodes(
subgraph_nodes,
[last_mul.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return False
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
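All three patterns implement the erf-based Gelu, Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))): the 1.4142... constant checked in fuse_1 and fuse_2 is sqrt(2), used as a divisor, while 0.7071067690849304 in fuse_3 is 1/sqrt(2), used as a multiplier. A small, illustrative sanity check of that equivalence using only the standard library:

import math

def gelu_erf(x: float) -> float:
    # erf-based Gelu, the function computed by the subgraphs these fusions replace
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    # dividing by 1.4142... (fuse_1/fuse_2) and multiplying by 0.7071... (fuse_3) are the same computation
    assert abs(gelu_erf(x) - 0.5 * x * (1.0 + math.erf(x * 0.7071067690849304))) < 1e-6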

@@ -0,0 +1,135 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import onnx
from ..onnx_model import ONNXModel
from .fusion import Fusion
class FusionLayerNormalization(Fusion):
def __init__(self, model: ONNXModel):
super().__init__(model, "LayerNormalization", "ReduceMean")
def fuse(
self,
reduce_mean_node: onnx.NodeProto,
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
output_name_to_node: dict[str, onnx.NodeProto],
):
"""
Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
LayerNormalization node.
+----------------------+
| |
| v
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
| |
+-------------------------------------------------+
It also handles cases of duplicated sub nodes exported from older version of PyTorch:
+----------------------+
| v
| +-------> Sub-----------------------------------------------+
| | |
| | v
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
| ^
| |
+----------------------+
"""
children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
if len(children) == 0 or len(children) > 2:
return
root_input = reduce_mean_node.input[0]
if children[0].op_type != "Sub" or children[0].input[0] != root_input:
return
if len(children) == 2:
if children[1].op_type != "Sub" or children[1].input[0] != root_input:
return
div_node = None
for child in children:
div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
if div_node is not None:
break
if div_node is None:
return
path_id, parent_nodes, _ = self.match_parent_paths(
div_node,
[
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
(
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
[1, 0, 0, 0, 0, 0],
),
],
output_name_to_node,
)
if path_id < 0:
return
sub_node = parent_nodes[-1]
if sub_node not in children:
return
second_add_node = parent_nodes[1]
i, add_weight = self.get_constant_input(second_add_node)
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
# Skip fusion since epsilon value is not expected.
return
pow_node = parent_nodes[3]
if self.find_constant_input(pow_node, 2.0) != 1:
return
mul_node = input_name_to_nodes[div_node.output[0]][0]
if mul_node.op_type != "Mul":
return
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return
subgraph_nodes = [reduce_mean_node]
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])
subgraph_nodes.extend([last_add_node, mul_node, div_node])
if not self.is_safe_to_fuse_nodes(
subgraph_nodes,
last_add_node.output,
input_name_to_nodes,
output_name_to_node,
):
return
weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
if not self.is_constant_with_specified_rank(weight_input, 1):
return
bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
if not self.is_constant_with_specified_rank(bias_input, 1):
return
self.nodes_to_remove.extend(subgraph_nodes)
normalize_node = onnx.helper.make_node(
"LayerNormalization",
name=self.create_unique_node_name(),
inputs=[reduce_mean_node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]],
)
normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
self.nodes_to_add.append(normalize_node)
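For reference, the subgraph collapsed here computes layer normalization over the reduced axis, y = (x - mean) / sqrt(variance + epsilon) * scale + bias, where scale comes from the Mul initializer (weight_input) and bias from the final Add initializer (bias_input). An illustrative NumPy rendering of that reference computation (shapes are placeholders):

import numpy as np

def layer_norm_reference(x, scale, bias, epsilon=1e-5):
    # Mirrors ReduceMean -> Sub -> Pow(2) -> ReduceMean -> Add(eps) -> Sqrt -> Div -> Mul(scale) -> Add(bias)
    mean = x.mean(axis=-1, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon) * scale + bias

x = np.ones((2, 4, 8), dtype=np.float32)
y = layer_norm_reference(x, scale=np.ones(8, dtype=np.float32), bias=np.zeros(8, dtype=np.float32))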