Files
Reinforced-Learning-Godot/rl/Lib/site-packages/onnxruntime/transformers/fusion_quickgelu.py
2024-10-30 22:14:35 +01:00

75 lines
2.7 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class FusionQuickGelu(Fusion):
    """Fuse the ``x * Sigmoid(alpha * x)`` subgraph into one ``QuickGelu`` node.

    The pass is anchored on ``Mul`` nodes. When the pattern matches, the two
    ``Mul`` nodes and the ``Sigmoid`` between them are scheduled for removal
    and replaced by a single ``com.microsoft`` ``QuickGelu`` node whose
    ``alpha`` attribute carries the matched scale constant.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): arguments are presumably (model, fused op name,
        # op types to search for) -- confirm against the Fusion base class.
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the QuickGelu pattern that ends at ``node``.

        Args:
            node: Candidate for the final ``Mul`` of the pattern.
            input_name_to_nodes: Unused here; part of the ``fuse`` signature.
            output_name_to_node: Unused here; part of the ``fuse`` signature.

        Returns:
            None. On a match the replacement is recorded in
            ``nodes_to_remove`` / ``nodes_to_add``; on any mismatch the
            method logs a debug message and returns without side effects.
        """
        # Fuse the following subgraph to `QuickGelu`
        #
        #           root_input
        #           /        \
        #          |         Mul         ----+
        #          |      (B = ~1.702)       |
        #           \         |              |
        #            \     Sigmoid           |----  `QuickGelu`
        #             \      /               |
        #              \    /                |
        #               Mul              ----+
        #                |
        #            root_output
        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        # The Sigmoid must feed input 1 of the final Mul.
        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        # The scaling Mul must feed input 0 of the Sigmoid.
        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        # The scale operand may be a non-constant tensor (e.g. a graph input
        # or a computed value), in which case no constant value is available,
        # or a non-scalar constant, for which `.item()` would raise. Treat
        # both as "pattern not matched" instead of crashing the optimizer.
        scale_constant = self.model.get_constant_value(first_mul_node.input[1])
        if scale_constant is None or scale_constant.size != 1:
            logger.debug("fuse_quickgelu: first Mul node's scale is not a scalar constant")
            return

        approximation_value = scale_constant.item()
        # Accept values within 1e-3 of 1.7021484375 (the ~1.702 QuickGelu scale).
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # Both branches of the pattern must originate from the same tensor.
        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        # Reuse the final Mul's output name so downstream consumers stay wired.
        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")