# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional, Tuple, Union

import numpy as np
from fusion_base import Fusion
from fusion_options import AttentionMaskFormat
from fusion_utils import FusionUtils, NumpyHelper
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class AttentionMask:
    """
    Helper class to process the attention mask for Attention fusion.
    """

    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd
        self.opset_version = model.get_opset_version()

def set_mask_format(self, mask_format: AttentionMaskFormat):
|
|
self.mask_format = mask_format
|
|
|
|
def set_mask_indice(self, mask, mask_index):
|
|
if mask in self.mask_indice:
|
|
assert mask_index == self.mask_indice[mask]
|
|
self.mask_indice[mask] = mask_index
|
|
|
|
def get_first_mask(self):
|
|
assert len(self.mask_indice) > 0
|
|
return next(iter(self.mask_indice))
|
|
|
|
def process_mask(self, input: str) -> str:
|
|
if self.mask_format == AttentionMaskFormat.NoMask:
|
|
return None
|
|
|
|
if input in self.mask_indice:
|
|
return self.mask_indice[input]
|
|
|
|
# Add cast to convert int64 to int32
|
|
if self.model.find_graph_input(input):
|
|
casted, input_name = self.utils.cast_graph_input_to_int32(input)
|
|
else:
|
|
input_name, cast_node = self.utils.cast_input_to_int32(input)
|
|
casted = True
|
|
|
|
if casted:
|
|
self.mask_casted[input] = input_name
|
|
|
|
# Attention supports int32 attention mask (2D) since 1.4.0
|
|
if self.mask_format == AttentionMaskFormat.AttentionMask:
|
|
self.mask_indice[input] = input_name
|
|
return input_name
|
|
|
|
# Add a mask processing node to convert attention mask to mask index (1D)
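        # For example, an int32 attention mask [[1, 1, 1, 0, 0]] of shape (batch_size, sequence_length)
        # is reduced along axis 1 (keepdims=0) to the 1D mask index [3], the number of valid tokens per row.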
|
|
output_name = self.model.create_node_name("mask_index")
|
|
if self.opset_version < 13:
|
|
mask_index_node = helper.make_node(
|
|
"ReduceSum",
|
|
inputs=[input_name],
|
|
outputs=[output_name],
|
|
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
|
|
)
|
|
mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
|
|
else:
|
|
# ReduceSum-13: axes is moved from attribute to input
|
|
axes_name = "ort_const_1_reduce_sum_axes"
|
|
if self.model.get_initializer(axes_name) is None:
|
|
self.model.add_initializer(
|
|
helper.make_tensor(
|
|
name=axes_name,
|
|
data_type=TensorProto.INT64,
|
|
dims=[1],
|
|
vals=[1],
|
|
raw=False,
|
|
)
|
|
)
|
|
mask_index_node = helper.make_node(
|
|
"ReduceSum",
|
|
inputs=[input_name, axes_name],
|
|
outputs=[output_name],
|
|
name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
|
|
)
|
|
mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
|
|
|
|
self.model.add_node(mask_index_node)
|
|
|
|
self.mask_indice[input] = output_name
|
|
return output_name
|
|
|
|
|
|
class FusionAttention(Fusion):
|
|
"""
|
|
Fuse Attention subgraph into one Attention node.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model: OnnxModel,
|
|
hidden_size: int,
|
|
num_heads: int,
|
|
attention_mask: Optional[AttentionMask] = None,
|
|
use_multi_head_attention: bool = False,
|
|
disable_multi_head_attention_bias: bool = False,
|
|
search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
|
|
):
|
|
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
|
|
super().__init__(model, attention_op_name, search_op_types)
|
|
self.hidden_size = hidden_size
|
|
self.num_heads = num_heads
|
|
self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
|
|
self.use_multi_head_attention = use_multi_head_attention
|
|
self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
|
|
self.mask_filter_value = None
|
|
|
|
# Flags to show warning only once
|
|
self.num_heads_warning = True
|
|
self.hidden_size_warning = True
|
|
|
|
self.shape_infer = None
|
|
        self.shape_infer_done = False  # Shape inference runs lazily in get_add_qk_str
|
|
|
|
def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
|
|
"""
|
|
Detect num_heads and hidden_size from Concat node in the following subgraph:
|
|
|
|
SkipLayerNormalization or EmbedLayerNormalization
|
|
/ |
|
|
MatMul Shape
|
|
| |
|
|
Add Gather(indices=0)
|
|
| |
|
|
| Unsqueeze
|
|
| |
|
|
| Concat (*, -1, 12, 64)
|
|
| /
|
|
Reshape
|
|
|
|
|
Transpose
|
|
"""
|
|
if len(concat.input) == 4:
|
|
num_heads = self.model.get_constant_value(concat.input[2])
|
|
head_size = self.model.get_constant_value(concat.input[3])
|
|
if (
|
|
isinstance(num_heads, np.ndarray)
|
|
and num_heads.size == 1
|
|
and isinstance(head_size, np.ndarray)
|
|
and head_size.size == 1
|
|
):
|
|
return num_heads[0], num_heads[0] * head_size[0]
|
|
|
|
return self.num_heads, self.hidden_size
|
|
|
|
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
|
|
"""Detect num_heads and hidden_size from a reshape node.
|
|
|
|
Args:
|
|
reshape_q (NodeProto): reshape node for Q
|
|
|
|
Returns:
|
|
Tuple[int, int]: num_heads and hidden_size
|
|
"""
|
|
        # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
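        # For example, a shape initializer [0, 0, 12, 64] yields num_heads = 12, head_size = 64,
        # and hidden_size = 12 * 64 = 768.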
|
|
q_shape = self.model.get_initializer(reshape_q.input[1])
|
|
if q_shape is None:
|
|
concat = self.model.get_parent(reshape_q, 1)
|
|
if concat is not None and concat.op_type == "Concat":
|
|
return self.get_num_heads_and_hidden_size_from_concat(concat)
|
|
            logger.debug(f"{reshape_q.input[1]} is not an initializer.")
|
|
return self.num_heads, self.hidden_size # Fall back to user specified value
|
|
|
|
q_shape_value = NumpyHelper.to_array(q_shape)
|
|
if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
|
|
            logger.debug(f"q_shape_value={q_shape_value}. Expected value is like [0, 0, num_heads, head_size].")
|
|
return self.num_heads, self.hidden_size # Fall back to user specified value
|
|
|
|
num_heads = q_shape_value[2]
|
|
head_size = q_shape_value[3]
|
|
hidden_size = num_heads * head_size
|
|
|
|
if self.num_heads > 0 and num_heads != self.num_heads:
|
|
if self.num_heads_warning:
|
|
logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
|
|
self.num_heads_warning = False # Do not show the warning more than once
|
|
|
|
if self.hidden_size > 0 and hidden_size != self.hidden_size:
|
|
if self.hidden_size_warning:
|
|
logger.warning(
|
|
f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
|
|
)
|
|
self.hidden_size_warning = False # Do not show the warning more than once
|
|
|
|
return num_heads, hidden_size
|
|
|
|
def get_add_qk_str(self, add_qk: NodeProto):
|
|
if not self.shape_infer_done:
|
|
self.shape_infer = self.model.infer_runtime_shape(update=True)
|
|
self.shape_infer_done = True
|
|
|
|
if self.shape_infer is None:
|
|
return None
|
|
|
|
input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
|
|
input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
|
|
|
|
if input_0_shape is None or input_1_shape is None:
|
|
logger.debug(f"one of the inputs of {add_qk} is None")
|
|
return None
|
|
|
|
if input_0_shape != input_1_shape:
|
|
            logger.debug(f"the shapes of the two inputs of {add_qk} are not the same")
|
|
return None
|
|
|
|
return add_qk.input[1]
|
|
|
|
def reshape_add_qk(self, add_qk: str):
|
|
# Convert 4D mask from (B,1,S,T) to (B,N,S,T)
|
|
# B = batch size, N = num heads, S = source sequence length, T = target sequence length
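        # For example, with num_heads = 12 the (B,1,S,T) mask is concatenated with itself 12 times
        # along axis 1 below, producing a (B,12,S,T) tensor that can be applied to the Q x K' scores per head.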
|
|
mask_output_name = add_qk + "_mask"
|
|
|
|
# Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
|
|
concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
|
|
if len(concat_node) == 1:
|
|
return mask_output_name
|
|
|
|
assert len(concat_node) == 0
|
|
concat_node_name = self.model.create_node_name("Concat")
|
|
concat_add_qk_fp32 = helper.make_node(
|
|
"Concat",
|
|
inputs=[add_qk for _ in range(self.num_heads)],
|
|
outputs=[mask_output_name],
|
|
name=concat_node_name,
|
|
axis=1,
|
|
)
|
|
# Add new node to graph
|
|
self.nodes_to_add.append(concat_add_qk_fp32)
|
|
self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
|
|
|
|
return mask_output_name
|
|
|
|
def concat_kv(self, past_k: str, past_v: str) -> str:
|
|
"""Concatenate past_k and past_v inputs to create past_kv input.
|
|
|
|
Args:
|
|
past_k (str): name of past K value
|
|
past_v (str): name of past V value
|
|
|
|
Returns:
|
|
kv_output_name (str): name of past KV value
|
|
"""
|
|
# Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
|
|
# B = batch size, N = num heads, P = past sequence length, H = head size
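        # For example, past_k and past_v of shape (B,N,P,H) are each unsqueezed to (1,B,N,P,H) below,
        # then concatenated along axis 0 into a single (2,B,N,P,H) past KV input.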
|
|
unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
|
|
unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
|
|
k_5d_name = (past_k + "_5d").replace(".", "_")
|
|
v_5d_name = (past_v + "_5d").replace(".", "_")
|
|
|
|
k_5d = helper.make_node(
|
|
"Unsqueeze",
|
|
inputs=[past_k],
|
|
outputs=[k_5d_name],
|
|
name=unsqueeze_k_name,
|
|
axes=[0],
|
|
)
|
|
v_5d = helper.make_node(
|
|
"Unsqueeze",
|
|
inputs=[past_v],
|
|
outputs=[v_5d_name],
|
|
name=unsqueeze_v_name,
|
|
axes=[0],
|
|
)
|
|
|
|
# Add unsqueeze nodes to graph
|
|
self.nodes_to_add.append(k_5d)
|
|
self.nodes_to_add.append(v_5d)
|
|
self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
|
|
self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name
|
|
|
|
# Concat K and V to get one node of size (2,B,N,P,H)
|
|
concat_node_name = self.model.create_node_name("Concat")
|
|
kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
|
|
concat_kv = helper.make_node(
|
|
"Concat",
|
|
inputs=[k_5d_name, v_5d_name],
|
|
outputs=[kv_output_name],
|
|
name=concat_node_name,
|
|
axis=0,
|
|
)
|
|
|
|
# Add concat node to graph
|
|
self.nodes_to_add.append(concat_kv)
|
|
self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
|
|
|
|
return kv_output_name
|
|
|
|
    def reshape_kv(self, past_k: str, past_v: str) -> Tuple[str, str]:
|
|
"""Reshape past_k and past_v from 4D to 3D to use as inputs for multihead attention node.
|
|
|
|
Args:
|
|
past_k (str): name of past K value of shape 4D
|
|
past_v (str): name of past V value of shape 4D
|
|
|
|
Returns:
|
|
k_3d (str): name of past K value of shape 3D
|
|
v_3d (str): name of past V value of shape 3D
|
|
"""
|
|
# Reshape past_k and past_v from (B,N,P,H) to (B,P,N*H)
|
|
# B = batch size, N = num heads, P = past seq len, H = head size
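        # For example, a (2, 12, 128, 64) tensor reshaped with dims [0, -1, 768] becomes (2, 128, 768),
        # assuming hidden_size = 12 * 64 = 768 (illustrative values, not read from the model).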
|
|
|
|
# Create initializer for reshaping past_k and past_v
|
|
new_dims_name = "kv_4d_to_3d"
|
|
new_dims = self.model.get_initializer(new_dims_name)
|
|
if new_dims is None:
|
|
new_dims = numpy_helper.from_array(
|
|
np.array([0, -1, self.model.hidden_size], dtype="int64"), name=new_dims_name
|
|
)
|
|
self.model.add_initializer(new_dims, self.this_graph_name)
|
|
|
|
reshape_k_name = self.model.create_node_name("Reshape")
|
|
reshape_v_name = self.model.create_node_name("Reshape")
|
|
k_3d_name = (past_k + "_3d").replace(".", "_")
|
|
v_3d_name = (past_v + "_3d").replace(".", "_")
|
|
|
|
k_3d = helper.make_node(
|
|
"Reshape",
|
|
inputs=[past_k, new_dims_name],
|
|
outputs=[k_3d_name],
|
|
name=reshape_k_name,
|
|
)
|
|
v_3d = helper.make_node(
|
|
"Reshape",
|
|
inputs=[past_v, new_dims_name],
|
|
outputs=[v_3d_name],
|
|
name=reshape_v_name,
|
|
)
|
|
|
|
# Add reshape nodes to graph
|
|
self.nodes_to_add.append(k_3d)
|
|
self.nodes_to_add.append(v_3d)
|
|
self.node_name_to_graph_name[reshape_k_name] = self.this_graph_name
|
|
self.node_name_to_graph_name[reshape_v_name] = self.this_graph_name
|
|
|
|
return k_3d_name, v_3d_name
|
|
|
|
def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
|
|
"""Split kv_node containing present KV values into separate present K and present V values.
|
|
|
|
Args:
|
|
present_k_name (str): name of output to store present K value in
|
|
present_v_name (str): name of output to store present V value in
|
|
kv_node (str): name of present KV values
|
|
"""
|
|
# Split kv_node into present_k and present_v nodes
|
|
|
|
# Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
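        # Gather with index 0 on axis 0 selects present K of shape (B,N,P,H); index 1 selects present V.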
|
|
k_index, v_index = "index_0", "index_1"
|
|
k_dim = self.model.get_initializer(k_index)
|
|
v_dim = self.model.get_initializer(v_index)
|
|
if k_dim is None:
|
|
k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
|
|
self.model.add_initializer(k_dim, self.this_graph_name)
|
|
if v_dim is None:
|
|
v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
|
|
self.model.add_initializer(v_dim, self.this_graph_name)
|
|
|
|
# Create nodes to index kv_node
|
|
gather_k_name = self.model.create_node_name("Gather")
|
|
gather_v_name = self.model.create_node_name("Gather")
|
|
present_k = helper.make_node(
|
|
"Gather",
|
|
inputs=[kv_node, k_index],
|
|
outputs=[present_k_name],
|
|
name=gather_k_name,
|
|
axis=0,
|
|
)
|
|
present_v = helper.make_node(
|
|
"Gather",
|
|
inputs=[kv_node, v_index],
|
|
outputs=[present_v_name],
|
|
name=gather_v_name,
|
|
axis=0,
|
|
)
|
|
|
|
# Add gather nodes to graph
|
|
self.nodes_to_add.append(present_k)
|
|
self.nodes_to_add.append(present_v)
|
|
self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
|
|
self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
|
|
|
|
def transpose_kv(self, past_k: str, past_v: str):
|
|
"""Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)
|
|
|
|
Args:
|
|
past_k (str): name of past K value of shape (B,N,P,H)
|
|
past_v (str): name of past V value of shape (B,N,P,H)
|
|
|
|
Returns:
|
|
past_k_transpose (str): name of past K value of shape (B,P,N,H)
|
|
past_v_transpose (str): name of past V value of shape (B,P,N,H)
|
|
"""
|
|
past_k_transpose = (past_k + "_transposed").replace(".", "_")
|
|
past_v_transpose = (past_v + "_transposed").replace(".", "_")
|
|
transpose_k_name = self.model.create_node_name("Transpose")
|
|
transpose_v_name = self.model.create_node_name("Transpose")
|
|
|
|
transpose_k = helper.make_node(
|
|
"Transpose",
|
|
inputs=[past_k],
|
|
outputs=[past_k_transpose],
|
|
name=transpose_k_name,
|
|
perm=[0, 2, 1, 3],
|
|
)
|
|
transpose_v = helper.make_node(
|
|
"Transpose",
|
|
inputs=[past_v],
|
|
outputs=[past_v_transpose],
|
|
name=transpose_v_name,
|
|
perm=[0, 2, 1, 3],
|
|
)
|
|
|
|
# Add reshape nodes to graph
|
|
self.nodes_to_add.append(transpose_k)
|
|
self.nodes_to_add.append(transpose_v)
|
|
self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
|
|
self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name
|
|
|
|
return past_k_transpose, past_v_transpose
|
|
|
|
def create_combined_qkv_bias(
|
|
self,
|
|
q_add: NodeProto,
|
|
k_add: Union[NodeProto, None],
|
|
v_add: Union[NodeProto, None],
|
|
name_prefix: str,
|
|
    ) -> str:
|
|
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
|
|
qb = NumpyHelper.to_array(q_bias)
|
|
kb = np.zeros_like(qb)
|
|
vb = np.zeros_like(qb)
|
|
if k_add is not None:
|
|
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
|
|
kb = NumpyHelper.to_array(k_bias)
|
|
if v_add is not None:
|
|
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
|
|
vb = NumpyHelper.to_array(v_bias)
|
|
|
|
qkv_bias = np.stack((qb, kb, vb), axis=0)
|
|
qkv_bias_dim = 3 * np.prod(qb.shape)
|
|
|
|
bias_name = name_prefix + "_qkv_bias"
|
|
self.add_initializer(
|
|
name=bias_name,
|
|
data_type=q_bias.data_type,
|
|
dims=[qkv_bias_dim],
|
|
vals=qkv_bias,
|
|
)
|
|
return bias_name
|
|
|
|
def create_packed_qkv_matmul_node(
|
|
self,
|
|
q_matmul: NodeProto,
|
|
k_matmul: NodeProto,
|
|
v_matmul: NodeProto,
|
|
q_add: NodeProto,
|
|
k_add: Union[NodeProto, None],
|
|
v_add: Union[NodeProto, None],
|
|
num_heads: int,
|
|
    ) -> Tuple[NodeProto, NodeProto, NodeProto]:
|
|
"""Create packed QKV MatMul node before MultiHeadAttention node.
|
|
This is for the scenario where an Attention node should be created but cannot be created
|
|
because past_key and past_value are separate inputs and not one concatenated input.
|
|
|
|
Args:
|
|
q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
|
|
k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size)
|
|
v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size)
|
|
q_add (NodeProto): name of Add from Q path
|
|
k_add (NodeProto): name of Add from K path
|
|
v_add (NodeProto): name of Add from V path
|
|
num_heads (int): number of heads
|
|
|
|
Returns:
|
|
            Tuple[NodeProto, NodeProto, NodeProto]: the Q, K and V output nodes (Slice or Add nodes).
|
|
"""
|
|
matmul_node_name = self.model.create_node_name("MatMul")
|
|
|
|
# Check that input for Q, K, V is the same
|
|
assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
|
|
|
|
        # Create packed QKV weight
|
|
q_weight = self.model.get_initializer(q_matmul.input[1])
|
|
k_weight = self.model.get_initializer(k_matmul.input[1])
|
|
v_weight = self.model.get_initializer(v_matmul.input[1])
|
|
|
|
qw = NumpyHelper.to_array(q_weight)
|
|
kw = NumpyHelper.to_array(k_weight)
|
|
vw = NumpyHelper.to_array(v_weight)
|
|
|
|
assert qw.shape == kw.shape and kw.shape == vw.shape
|
|
d = qw.shape[0]
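        # For example (illustrative size, not read from the model): with d = 768 the packed weight below
        # has shape (768, 2304), so a single MatMul produces Q, K and V in one (B, S, 3*d) output.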
|
|
|
|
qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
|
|
qkv_weight_name = matmul_node_name + "_qkv_weight"
|
|
|
|
self.add_initializer(
|
|
name=qkv_weight_name,
|
|
data_type=q_weight.data_type,
|
|
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
|
|
vals=qkv_weight,
|
|
)
|
|
|
|
        # Create packed QKV MatMul with output (B, S, 3*D)
        # Output is of the form:
        #
        # [[[Q Q ... Q Q K K ... K K V V ... V V]]]
        #   [Q Q ... Q Q K K ... K K V V ... V V]
        #                     .
        #                     .
        #                     .
        #  [[Q Q ... Q Q K K ... K K V V ... V V]
        #   [Q Q ... Q Q K K ... K K V V ... V V]]]
|
|
qkv_matmul_output = matmul_node_name + "_qkv_out"
|
|
qkv_matmul = helper.make_node(
|
|
"MatMul",
|
|
inputs=[q_matmul.input[0], qkv_weight_name],
|
|
outputs=[qkv_matmul_output],
|
|
name=matmul_node_name,
|
|
)
|
|
self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name
|
|
|
|
qkv_nodes = [qkv_matmul]
|
|
|
|
# Create Slice nodes to access Q, K, V
|
|
q_slice_name = matmul_node_name + "_q_start_index"
|
|
self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
|
|
k_slice_name = matmul_node_name + "_k_start_index"
|
|
self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
|
|
v_slice_name = matmul_node_name + "_v_start_index"
|
|
self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
|
|
end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
|
|
self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
|
|
qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
|
|
self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
|
|
|
|
q_slice_output = matmul_node_name + "_q_out"
|
|
q_slice = helper.make_node(
|
|
"Slice",
|
|
inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
|
|
outputs=[q_slice_output],
|
|
name=self.model.create_node_name("Slice"),
|
|
)
|
|
self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
|
|
k_slice_output = matmul_node_name + "_k_out"
|
|
k_slice = helper.make_node(
|
|
"Slice",
|
|
inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
|
|
outputs=[k_slice_output],
|
|
name=self.model.create_node_name("Slice"),
|
|
)
|
|
self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
|
|
v_slice_output = matmul_node_name + "_v_out"
|
|
v_slice = helper.make_node(
|
|
"Slice",
|
|
inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
|
|
outputs=[v_slice_output],
|
|
name=self.model.create_node_name("Slice"),
|
|
)
|
|
self.node_name_to_graph_name[v_slice.name] = self.this_graph_name
|
|
|
|
q_output = q_slice
|
|
k_output = k_slice
|
|
v_output = v_slice
|
|
qkv_nodes.extend([q_slice, k_slice, v_slice])
|
|
|
|
if self.disable_multi_head_attention_bias:
|
|
if q_add is not None:
|
|
initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
|
|
if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
|
|
q_add.input[1 - initializer_input] = q_slice_output
|
|
q_output = q_add
|
|
qkv_nodes.append(q_add)
|
|
self.node_name_to_graph_name[q_add.name] = self.this_graph_name
|
|
if k_add is not None:
|
|
initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
|
|
if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
|
|
k_add.input[1 - initializer_input] = k_slice_output
|
|
k_output = k_add
|
|
qkv_nodes.append(k_add)
|
|
self.node_name_to_graph_name[k_add.name] = self.this_graph_name
|
|
if v_add is not None:
|
|
initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
|
|
if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
|
|
v_add.input[1 - initializer_input] = v_slice_output
|
|
v_output = v_add
|
|
qkv_nodes.append(v_add)
|
|
self.node_name_to_graph_name[v_add.name] = self.this_graph_name
|
|
|
|
# Add nodes to graph
|
|
self.nodes_to_add.extend(qkv_nodes)
|
|
return q_output, k_output, v_output
|
|
|
|
def create_multihead_attention_node(
|
|
self,
|
|
q_matmul: NodeProto,
|
|
k_matmul: Union[NodeProto, str, None],
|
|
v_matmul: Union[NodeProto, str, None],
|
|
q_add: NodeProto,
|
|
k_add: Union[NodeProto, None],
|
|
v_add: Union[NodeProto, None],
|
|
num_heads: int,
|
|
hidden_size: int,
|
|
output: str,
|
|
key_padding_mask: str = "",
|
|
add_qk: str = "",
|
|
past_k: str = "",
|
|
past_v: str = "",
|
|
present_k: str = "",
|
|
present_v: str = "",
|
|
packed_qkv: bool = False,
|
|
) -> Union[NodeProto, None]:
|
|
"""Create a MultiHeadAttention node.
|
|
|
|
Args:
|
|
q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
|
|
k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
|
|
v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
|
|
q_add (NodeProto): name of Add from Q path
|
|
k_add (NodeProto): name of Add from K path
|
|
v_add (NodeProto): name of Add from V path
|
|
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
|
|
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
|
|
output (str): output name of MHA
|
|
key_padding_mask (str): name of key padding mask
|
|
add_qk (str): name of add after Q x K'
|
|
past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
|
|
past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
|
|
present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
|
|
present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
|
|
packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
|
|
Note: This is for the scenario where an Attention node should be created but cannot be created
|
|
because past_key and past_value are separate inputs and not one concatenated input.
|
|
|
|
Returns:
|
|
Union[NodeProto, None]: the node created or None if failed.
|
|
"""
|
|
# B = batch size, N = num heads, P = past seq len, H = head size
|
|
assert num_heads > 0
|
|
|
|
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
|
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
|
return None
|
|
|
|
graph_input_names = set([node.name for node in self.model.graph().input])
|
|
mha_node_name = self.model.create_node_name("Attention")
|
|
|
|
# Add initial Q/K/V inputs for MHA
|
|
mha_inputs = []
|
|
if packed_qkv:
|
|
q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
|
|
q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads
|
|
)
|
|
mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
|
|
elif type(k_matmul) is NodeProto and type(v_matmul) is NodeProto:
|
|
if self.disable_multi_head_attention_bias:
|
|
mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
|
|
else:
|
|
mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
|
|
elif (
|
|
type(k_matmul) == str # noqa: E721
|
|
and type(v_matmul) == str # noqa: E721
|
|
and k_matmul in graph_input_names
|
|
and v_matmul in graph_input_names
|
|
):
|
|
if self.disable_multi_head_attention_bias:
|
|
mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
|
|
else:
|
|
mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
|
|
else:
|
|
return None
|
|
|
|
# Add bias to inputs for MHA
|
|
# Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
|
|
# bias has been added to key and value when they are in BNSH format, so only bias for query is used.
|
|
        # We need to add checks in case this assumption does not hold.
|
|
if not self.disable_multi_head_attention_bias:
|
|
bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
|
|
mha_inputs.append(bias_name)
|
|
else:
|
|
mha_inputs.append("")
|
|
|
|
# Add optional inputs for MHA
|
|
|
|
if past_k and past_v:
|
|
mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
|
|
elif key_padding_mask or add_qk:
|
|
mha_inputs.extend([key_padding_mask, add_qk])
|
|
|
|
# Add outputs for MHA
|
|
mha_outputs = [output]
|
|
if present_k and present_v:
|
|
mha_outputs.extend([present_k, present_v])
|
|
|
|
mha_node = helper.make_node(
|
|
"MultiHeadAttention",
|
|
inputs=mha_inputs,
|
|
outputs=mha_outputs,
|
|
name=mha_node_name,
|
|
)
|
|
mha_node.domain = "com.microsoft"
|
|
mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
|
return mha_node
|
|
|
|
def create_attention_node(
|
|
self,
|
|
mask_index: str,
|
|
q_matmul: NodeProto,
|
|
k_matmul: NodeProto,
|
|
v_matmul: NodeProto,
|
|
q_add: NodeProto,
|
|
k_add: NodeProto,
|
|
v_add: NodeProto,
|
|
num_heads: int,
|
|
hidden_size: int,
|
|
input: str,
|
|
output: str,
|
|
add_qk_str: str = "",
|
|
past_k: str = "",
|
|
past_v: str = "",
|
|
present_k: str = "",
|
|
present_v: str = "",
|
|
scale: Optional[float] = None,
|
|
causal: bool = False,
|
|
) -> Union[NodeProto, None]:
|
|
"""Create an Attention node.
|
|
|
|
Args:
|
|
mask_index (str): mask input
|
|
q_matmul (NodeProto): MatMul node in fully connection for Q
|
|
k_matmul (NodeProto): MatMul node in fully connection for K
|
|
v_matmul (NodeProto): MatMul node in fully connection for V
|
|
q_add (NodeProto): Add bias node in fully connection for Q
|
|
k_add (NodeProto): Add bias node in fully connection for K
|
|
v_add (NodeProto): Add bias node in fully connection for V
|
|
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
|
|
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
|
|
input (str): input name
|
|
output (str): output name
|
|
add_qk_str (str): name of Add node after Q x K'
|
|
past_k (str): name of input for past K value
|
|
past_v (str): name of input for past V value
|
|
present_k (str): name of output to store present K value
|
|
present_v (str): name of output to store present V value
|
|
scale: scale before softmax
|
|
causal: whether it is uni-directional mask.
|
|
|
|
Returns:
|
|
Union[NodeProto, None]: the node created or None if failed.
|
|
"""
|
|
assert num_heads > 0
|
|
|
|
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
|
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
|
return None
|
|
|
|
has_bias = True
|
|
if q_add is None and k_add is None and v_add is None:
|
|
has_bias = False
|
|
|
|
q_weight = self.model.get_initializer(q_matmul.input[1])
|
|
k_weight = self.model.get_initializer(k_matmul.input[1])
|
|
v_weight = self.model.get_initializer(v_matmul.input[1])
|
|
|
|
q_bias, k_bias, v_bias = None, None, None
|
|
if has_bias:
|
|
q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
|
|
k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
|
|
v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
|
|
|
|
if not (k_weight and v_weight and q_bias and k_bias):
|
|
return None
|
|
|
|
if q_weight is None:
|
|
print(
|
|
f"{q_matmul.input[1]} is not an initializer. "
|
|
"Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
|
|
)
|
|
return None
|
|
|
|
qw = NumpyHelper.to_array(q_weight)
|
|
kw = NumpyHelper.to_array(k_weight)
|
|
vw = NumpyHelper.to_array(v_weight)
|
|
|
|
# assert q and k have same shape as expected
|
|
assert qw.shape == kw.shape
|
|
|
|
qw_in_size = qw.shape[0]
|
|
kw_in_size = kw.shape[0]
|
|
vw_in_size = vw.shape[0]
|
|
|
|
assert qw_in_size == kw_in_size == vw_in_size
|
|
|
|
if hidden_size > 0 and hidden_size != qw_in_size:
|
|
logger.warning(
|
|
f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
|
|
"Please provide a correct input hidden size or pass in 0"
|
|
)
|
|
|
|
is_qkv_diff_dims = False
|
|
if qw.shape != vw.shape:
|
|
is_qkv_diff_dims = True
|
|
|
|
# All the matrices can have the same shape or q, k matrices can have the same shape with v being different
|
|
# For 2d weights, the shapes would be [in_size, out_size].
|
|
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
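        # For example, 3D weights of shape [768, 12, 64] have out_size = 12 * 64 = 768 (illustrative values).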
|
|
qw_out_size = np.prod(qw.shape[1:])
|
|
kw_out_size = np.prod(kw.shape[1:])
|
|
vw_out_size = np.prod(vw.shape[1:])
|
|
|
|
qkv_weight_dim = 0
|
|
if is_qkv_diff_dims:
|
|
qkv_weight = np.concatenate((qw, kw, vw), axis=1)
|
|
qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
|
|
else:
|
|
qkv_weight = np.stack((qw, kw, vw), axis=1)
|
|
qkv_weight_dim = 3 * qw_out_size
|
|
|
|
if has_bias:
|
|
qb = NumpyHelper.to_array(q_bias)
|
|
kb = NumpyHelper.to_array(k_bias)
|
|
vb = NumpyHelper.to_array(v_bias)
|
|
|
|
q_bias_shape = np.prod(qb.shape)
|
|
k_bias_shape = np.prod(kb.shape)
|
|
v_bias_shape = np.prod(vb.shape)
|
|
|
|
assert q_bias_shape == k_bias_shape == qw_out_size
|
|
assert v_bias_shape == vw_out_size
|
|
|
|
if is_qkv_diff_dims:
|
|
qkv_bias = np.concatenate((qb, kb, vb), axis=0)
|
|
qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
|
|
else:
|
|
qkv_bias = np.stack((qb, kb, vb), axis=0)
|
|
qkv_bias_dim = 3 * q_bias_shape
|
|
|
|
attention_node_name = self.model.create_node_name("Attention")
|
|
|
|
if not self.use_multi_head_attention:
|
|
self.add_initializer(
|
|
name=attention_node_name + "_qkv_weight",
|
|
data_type=q_weight.data_type,
|
|
dims=[qw_in_size, qkv_weight_dim],
|
|
vals=qkv_weight,
|
|
)
|
|
|
|
if has_bias:
|
|
self.add_initializer(
|
|
name=attention_node_name + "_qkv_bias",
|
|
data_type=q_bias.data_type,
|
|
dims=[qkv_bias_dim],
|
|
vals=qkv_bias,
|
|
)
|
|
|
|
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
|
|
if self.use_multi_head_attention:
|
|
if add_qk_str:
|
|
logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
|
|
return None
|
|
|
|
attention_inputs = [
|
|
q_matmul.output[0],
|
|
k_matmul.output[0],
|
|
v_matmul.output[0],
|
|
attention_node_name + "_qkv_bias",
|
|
]
|
|
|
|
if mask_index is not None:
|
|
attention_inputs.append(mask_index)
|
|
|
|
attention_node = helper.make_node(
|
|
"MultiHeadAttention",
|
|
inputs=attention_inputs,
|
|
outputs=[output],
|
|
name=attention_node_name,
|
|
)
|
|
else:
|
|
attention_inputs = [
|
|
input,
|
|
attention_node_name + "_qkv_weight",
|
|
attention_node_name + "_qkv_bias" if has_bias else "",
|
|
]
|
|
if mask_index is not None:
|
|
attention_inputs.append(mask_index)
|
|
else:
|
|
attention_inputs.append("")
|
|
|
|
past_exists = past_k and past_v
|
|
if past_exists:
|
|
past_kv = self.concat_kv(past_k, past_v)
|
|
attention_inputs.append(past_kv)
|
|
|
|
            if add_qk_str:
|
|
mask_output_name = self.reshape_add_qk(add_qk_str)
|
|
|
|
# Add attention mask to attention node
|
|
if not past_exists:
|
|
attention_inputs.append("")
|
|
attention_inputs.append(mask_output_name)
|
|
|
|
attention_outputs = [output]
|
|
if present_k and present_v:
|
|
present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
|
|
attention_outputs.append(present_kv)
|
|
self.split_kv(present_k, present_v, present_kv)
|
|
|
|
attention_node = helper.make_node(
|
|
"Attention",
|
|
inputs=attention_inputs,
|
|
outputs=attention_outputs,
|
|
name=attention_node_name,
|
|
)
|
|
|
|
attention_node.domain = "com.microsoft"
|
|
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
|
|
|
if causal:
|
|
attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
|
|
|
|
if scale is not None:
|
|
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
|
|
|
|
if is_qkv_diff_dims:
|
|
attention_node.attribute.extend(
|
|
[helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
|
|
)
|
|
|
|
if self.mask_filter_value is not None:
|
|
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
|
|
|
|
return attention_node
|
|
|
|
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
|
        # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNormalization has an output
        # that is used by nodes outside of the SkipLayerNormalization.
        # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node since they share the same pattern.
|
|
start_node = normalize_node
|
|
if normalize_node.op_type == "LayerNormalization":
|
|
add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
|
|
if add_before_layernorm is not None:
|
|
start_node = add_before_layernorm
|
|
else:
|
|
return
|
|
|
|
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
|
|
qkv_nodes = self.model.match_parent_path(
|
|
start_node,
|
|
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
|
|
[None, None, 0, 0, 0],
|
|
)
|
|
einsum_node = None
|
|
if qkv_nodes is not None:
|
|
(_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
|
|
else:
|
|
# Match Albert
|
|
qkv_nodes = self.model.match_parent_path(
|
|
start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
|
|
)
|
|
if qkv_nodes is not None:
|
|
(_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
|
|
else:
|
|
return
|
|
|
|
other_inputs = []
|
|
for _i, input in enumerate(start_node.input):
|
|
if input not in output_name_to_node:
|
|
continue
|
|
|
|
if input == qkv_nodes[0].output[0]:
|
|
continue
|
|
other_inputs.append(input)
|
|
if len(other_inputs) != 1:
|
|
return
|
|
|
|
root_input = other_inputs[0]
|
|
"""
|
|
Match flaubert Mask
|
|
|
|
|
Mul --> LayerNormalization --> Attention --> MatMul --> Add
|
|
| |
|
|
| |
|
|
+---------------------------------------------------------
|
|
"""
|
|
mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
|
|
if mul_before_layernorm is not None:
|
|
mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
|
|
if mul_children is not None and len(mul_children) == 2:
|
|
layernorm_node = mul_children[1]
|
|
if layernorm_node.op_type == "LayerNormalization":
|
|
root_input = layernorm_node.output[0]
|
|
else:
|
|
return
|
|
elif mul_children is not None and len(mul_children) == 5:
|
|
root_input = mul_before_layernorm.output[0]
|
|
else:
|
|
return
|
|
elif normalize_node.op_type == "LayerNormalization":
|
|
children = input_name_to_nodes[root_input]
|
|
for child in children:
|
|
if child.op_type == "LayerNormalization":
|
|
root_input = child.output[0]
|
|
|
|
"""
|
|
When Add before the LayerNormalization produces an output
|
|
that is consumed by some other nodes other than the LayerNormalization itself,
|
|
fused SkipLayerNormalization will have several outputs.
|
|
In this case we need to pick the one used in Attention
|
|
|
|
For example, this is the case for ViT
|
|
|
|
SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
|
|
| |
|
|
| |
|
|
+---------------------------------------------------------------------+
|
|
"""
|
|
parent_node = output_name_to_node[root_input]
|
|
if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
|
|
root_input = parent_node.output[0]
|
|
|
|
children = input_name_to_nodes[root_input]
|
|
children_types = [child.op_type for child in children]
|
|
if children_types.count("MatMul") != 3:
|
|
return
|
|
|
|
v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
|
|
if v_nodes is None:
|
|
logger.debug("fuse_attention: failed to match v path")
|
|
return
|
|
(_, _, add_v, matmul_v) = v_nodes
|
|
|
|
is_distill = False
|
|
is_distill_add = False
|
|
is_no_mask_attention = False
|
|
qk_paths = {
|
|
"path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
|
|
"path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
|
|
"path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
|
|
"path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
|
|
"path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
|
|
}
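        # path1/path2 use an additive mask (Add) applied to the scaled Q x K' scores; path3 and path4 use
        # Where-based masking and are flagged below as is_distill / is_distill_add; path5 has no mask at all.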
|
|
|
|
qk_nodes = None
|
|
for k, v in qk_paths.items():
|
|
qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
|
|
if qk_nodes is None:
|
|
continue
|
|
if k == "path3":
|
|
is_distill = True
|
|
if k == "path4":
|
|
is_distill_add = True
|
|
if k == "path5":
|
|
is_no_mask_attention = True
|
|
break
|
|
|
|
if qk_nodes is None:
|
|
logger.debug("fuse_attention: failed to match qk path")
|
|
return
|
|
|
|
add_qk = None
|
|
matmul_qk = None
|
|
where_qk = None
|
|
if is_distill:
|
|
(_, where_qk, matmul_qk, _) = qk_nodes
|
|
elif is_distill_add:
|
|
(_, add_qk, where_qk, matmul_qk) = qk_nodes
|
|
elif is_no_mask_attention:
|
|
(_, _, matmul_qk) = qk_nodes
|
|
else:
|
|
(_, add_qk, _, matmul_qk) = qk_nodes
|
|
|
|
q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
|
|
if q_nodes is None:
|
|
q_nodes = self.model.match_parent_path(
|
|
matmul_qk,
|
|
["Div", "Transpose", "Reshape", "Add", "MatMul"],
|
|
[0, 0, 0, 0, None],
|
|
)
|
|
if q_nodes is None:
|
|
logger.debug("fuse_attention: failed to match q path")
|
|
return
|
|
reshape_q = q_nodes[-3]
|
|
add_q = q_nodes[-2]
|
|
matmul_q = q_nodes[-1]
|
|
|
|
k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
|
|
if k_nodes is None:
|
|
k_nodes = self.model.match_parent_path(
|
|
matmul_qk,
|
|
["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
|
|
[1, 0, 0, 0, None],
|
|
)
|
|
if k_nodes is None:
|
|
logger.debug("fuse_attention: failed to match k path")
|
|
return
|
|
add_k = k_nodes[-2]
|
|
matmul_k = k_nodes[-1]
|
|
|
|
# Note that Cast might be removed by OnnxRuntime so we match two patterns here.
|
|
mask_nodes = None
|
|
add_qk_str = None
|
|
if is_distill:
|
|
_, mask_nodes, _ = self.model.match_parent_paths(
|
|
where_qk,
|
|
[
|
|
(["Expand", "Reshape", "Equal"], [0, 0, 0]),
|
|
(["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
|
|
(["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
|
|
],
|
|
output_name_to_node,
|
|
)
|
|
elif is_distill_add:
|
|
_, mask_nodes, _ = self.model.match_parent_paths(
|
|
where_qk,
|
|
[
|
|
(["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
|
|
(["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
|
|
],
|
|
output_name_to_node,
|
|
)
|
|
if add_qk is not None:
|
|
add_qk_str = self.get_add_qk_str(add_qk)
|
|
if add_qk_str is None:
|
|
logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
|
|
return
|
|
elif is_no_mask_attention:
|
|
pass
|
|
else:
|
|
_, mask_nodes, _ = self.model.match_parent_paths(
|
|
add_qk,
|
|
[
|
|
(
|
|
["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
|
|
[None, 0, 1, 0, 0],
|
|
),
|
|
(["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
|
|
],
|
|
output_name_to_node,
|
|
)
|
|
if not is_no_mask_attention and mask_nodes is None:
|
|
logger.debug("fuse_attention: failed to match mask path")
|
|
return
|
|
|
|
if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
|
|
_, mul_val = self.model.get_constant_input(mask_nodes[0])
|
|
if mul_val != -10000:
|
|
self.mask_filter_value = mul_val
|
|
|
|
if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
|
|
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
|
|
|
|
attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
|
|
|
|
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
|
|
if q_num_heads <= 0 or q_hidden_size <= 0:
|
|
logger.warning(
|
|
"Failed to detect num_heads and hidden_size for Attention fusion. "
|
|
"Please specify those parameters in argument."
|
|
)
|
|
return
|
|
|
|
            # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
            # input_hidden_size represents the input hidden size; the hidden sizes for Q and K are extracted separately as needed.
|
|
new_node = self.create_attention_node(
|
|
mask_index,
|
|
matmul_q,
|
|
matmul_k,
|
|
matmul_v,
|
|
add_q,
|
|
add_k,
|
|
add_v,
|
|
q_num_heads,
|
|
q_hidden_size,
|
|
root_input,
|
|
attention_last_node.output[0],
|
|
add_qk_str,
|
|
)
|
|
if new_node is None:
|
|
return
|
|
|
|
self.nodes_to_add.append(new_node)
|
|
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
|
|
|
if einsum_node is not None:
|
|
unique_index = einsum_node.input[0]
|
|
new_edge = "edge_modified_" + unique_index
|
|
|
|
shape_tensor = self.add_initializer(
|
|
name="shape_modified_tensor" + unique_index,
|
|
data_type=TensorProto.INT64,
|
|
dims=[4],
|
|
vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]),
|
|
raw=False,
|
|
)
|
|
|
|
self.model.add_node(
|
|
helper.make_node(
|
|
"Reshape",
|
|
[attention_last_node.output[0], shape_tensor.name],
|
|
[new_edge],
|
|
"reshape_modified_" + unique_index,
|
|
),
|
|
self.this_graph_name,
|
|
)
|
|
einsum_node.input[0] = new_edge
|
|
|
|
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
|
|
self.nodes_to_remove.extend(qk_nodes)
|
|
|
|
# For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
|
|
self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
|
|
self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
|
|
self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
|
|
|
|
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
|
|
self.prune_graph = True
|