169 lines
6.7 KiB
Python
169 lines
6.7 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
|
|
from logging import getLogger
|
|
from typing import Dict, List
|
|
|
|
from fusion_base import Fusion
|
|
from fusion_utils import FusionUtils
|
|
from onnx import NodeProto, TensorProto, helper
|
|
from onnx_model import OnnxModel
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
class FusionTranspose(Fusion):
|
|
def __init__(self, model: OnnxModel):
|
|
super().__init__(model, "Transpose", "Transpose")
|
|
|
|
def fuse(
|
|
self,
|
|
transpose_node: NodeProto,
|
|
input_name_to_nodes: Dict[str, List[NodeProto]],
|
|
output_name_to_node: Dict[str, NodeProto],
|
|
):
|
|
"""
|
|
Note that onnxruntime will do comprehensive transpose optimization after loading model.
|
|
The purpose of this fusion is to make graph clean before running onnxruntime.
|
|
|
|
Case 1:
|
|
(input)-->Transpose(perm=a)-->Transpose(perm=b)-->
|
|
After:
|
|
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
|
|
|
|
+----->Transpose(perm=a*b)-->
|
|
|
|
Case 2 (Cast has only one child):
|
|
(input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
|
|
After:
|
|
(input)-->Transpose(perm=a)--> (this path can be removed if the output is not used anymore)
|
|
|
|
|
+----->Cast --> Transpose(perm=a*b)-->
|
|
"""
|
|
transpose_b = transpose_node
|
|
if transpose_b.input[0] not in output_name_to_node:
|
|
return
|
|
|
|
transpose_a = output_name_to_node[transpose_b.input[0]]
|
|
if transpose_a.op_type != "Cast":
|
|
cast_node = None
|
|
else:
|
|
cast_node = transpose_a
|
|
|
|
cast_children = self.model.get_children(cast_node, input_name_to_nodes)
|
|
if cast_children and len(cast_children) > 1:
|
|
return
|
|
|
|
if cast_node.input[0] not in output_name_to_node:
|
|
return
|
|
|
|
transpose_a = output_name_to_node[cast_node.input[0]]
|
|
|
|
if transpose_a.op_type != "Transpose":
|
|
return
|
|
|
|
permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
|
|
assert isinstance(permutation, list)
|
|
|
|
parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
|
|
assert isinstance(parent_permutation, list)
|
|
|
|
assert len(parent_permutation) == len(permutation)
|
|
|
|
output_permutation = []
|
|
for _j, index in enumerate(permutation):
|
|
output_permutation.append(parent_permutation[index])
|
|
|
|
if cast_node is None:
|
|
if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
|
|
self.nodes_to_remove.append(transpose_a)
|
|
else:
|
|
if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
|
|
self.nodes_to_remove.append(transpose_a)
|
|
transpose_b.ClearField("attribute")
|
|
transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
|
|
|
|
|
|
class FusionInsertTranspose(Fusion):
|
|
def __init__(self, model: OnnxModel):
|
|
super().__init__(model, "", "GroupNorm")
|
|
|
|
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
|
|
"""Append a Transpose node after an input"""
|
|
node_name = self.model.create_node_name("Transpose")
|
|
if output_name is None:
|
|
output_name = node_name + "_out" + "-" + input_name
|
|
transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
|
|
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
|
|
return transpose_node
|
|
|
|
def fuse(
|
|
self,
|
|
group_norm_node: NodeProto,
|
|
input_name_to_nodes: Dict[str, List[NodeProto]],
|
|
output_name_to_node: Dict[str, NodeProto],
|
|
):
|
|
"""
|
|
This optimization will insert an Transpose, and onnxruntime transpose optimizer will remove it together with
|
|
another Transpose so that we can get effect of reducing one Transpose after onnxruntime optimization.
|
|
Before:
|
|
--> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
|
After:
|
|
--> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) -->Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
|
|
"""
|
|
gemm_path = self.model.match_parent_path(
|
|
group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
|
|
)
|
|
if gemm_path is None:
|
|
return
|
|
transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
|
|
if self.model.find_graph_output(unsqueeze_3.output[0]):
|
|
return
|
|
|
|
permutation = OnnxModel.get_node_attribute(transpose, "perm")
|
|
assert isinstance(permutation, list)
|
|
if permutation != [0, 2, 3, 1]:
|
|
return
|
|
|
|
if not (
|
|
len(unsqueeze_3.input) == 2
|
|
and self.model.get_constant_value(unsqueeze_3.input[1]) == 3
|
|
and len(unsqueeze_2.input) == 2
|
|
and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
|
|
and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
|
|
and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
|
|
and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
|
|
):
|
|
return
|
|
|
|
# Here we use hard-coded name so that it could be shared for the whole model.
|
|
axes_1 = "ort_const_unsqueeze_axes_1"
|
|
if self.model.get_initializer(axes_1) is None:
|
|
self.add_initializer(
|
|
name=axes_1,
|
|
data_type=TensorProto.INT64,
|
|
dims=[1],
|
|
vals=[1],
|
|
raw=False,
|
|
)
|
|
|
|
axes_2 = "ort_const_unsqueeze_axes_2"
|
|
if self.model.get_initializer(axes_2) is None:
|
|
self.add_initializer(
|
|
name=axes_2,
|
|
data_type=TensorProto.INT64,
|
|
dims=[1],
|
|
vals=[2],
|
|
raw=False,
|
|
)
|
|
|
|
unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
|
|
unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
|
|
transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
|
|
self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
|
|
new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
|
|
self.model.add_node(new_transpose, self.this_graph_name)
|
|
self.increase_counter("Insert Transpose")
|