I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
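# The append above lets sibling modules in this directory (for example, benchmark_helper)
# be imported by bare name even when this package is loaded from another location.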

View File

@@ -0,0 +1,40 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Get/Set CPU affinity. Currently supported only on some Unix systems.
import logging
import os
logger = logging.getLogger(__name__)
class AffinitySetting:
def __init__(self):
self.pid = os.getpid()
self.affinity = None
self.is_os_supported = hasattr(os, "sched_getaffinity") and hasattr(os, "sched_setaffinity")
if not self.is_os_supported:
logger.warning("Current OS does not support os.get_affinity() and os.set_affinity()")
def get_affinity(self):
if self.is_os_supported:
self.affinity = os.sched_getaffinity(self.pid)
def set_affinity(self):
if self.is_os_supported:
current_affinity = os.sched_getaffinity(self.pid)
if self.affinity != current_affinity:
logger.warning(
"Replacing affinity setting %s with %s",
str(current_affinity),
str(self.affinity),
)
os.sched_setaffinity(self.pid, self.affinity)
if __name__ == "__main__":
affi_helper = AffinitySetting()
affi_helper.get_affinity()
affi_helper.set_affinity()
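# Illustrative usage sketch (not part of the original module): snapshot the CPU affinity
# before running code that may silently narrow it, then warn and restore it afterwards.
def _affinity_round_trip_example(workload):
    helper = AffinitySetting()
    helper.get_affinity()  # snapshot the current CPU set
    workload()  # code (e.g. heavy imports) that may change the process affinity
    helper.set_affinity()  # logs a warning and restores the snapshot if it changed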

View File

@@ -0,0 +1,944 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the inference of pretrained transformer models.
PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
One difference is that random input_ids are generated in this benchmark.
For onnxruntime, this script converts a pretrained model to ONNX and optimizes it when the -o parameter is used.
Example commands:
Export all models to ONNX, optimize and validate them:
python benchmark.py -b 0 -o -v -i 1 2 3
Run OnnxRuntime on GPU for all models:
python benchmark.py -g
Run OnnxRuntime on GPU for all models with fp32 optimization:
python benchmark.py -g -o
Run OnnxRuntime on GPU with fp16 optimization:
python benchmark.py -g -o -p "fp16"
Run TorchScript on GPU for all models:
python benchmark.py -e torchscript -g
Run TorchScript on GPU for all models with fp16:
python benchmark.py -e torchscript -g -p "fp16"
Run ONNXRuntime and TorchScript on CPU for all models with quantization:
python benchmark.py -e torchscript onnxruntime -p "int8" -o
Run OnnxRuntime with the ROCM provider and graph optimization script:
python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm
It is recommended to use run_benchmark.sh to launch the benchmark.
"""
import argparse
import logging
import os
import timeit
from datetime import datetime
import numpy
import psutil
from benchmark_helper import (
ConfigModifier,
OptimizerInfo,
Precision,
create_onnxruntime_session,
get_latency_result,
inference_ort,
inference_ort_with_io_binding,
output_details,
output_fusion_statistics,
output_summary,
setup_logger,
)
from fusion_options import FusionOptions
from huggingface_models import MODEL_CLASSES, MODELS
from onnx_exporter import (
create_onnxruntime_input,
export_onnx_model_from_pt,
export_onnx_model_from_tf,
load_pretrained_model,
)
from packaging import version
from quantize_helper import QuantizeHelper
logger = logging.getLogger("")
cpu_count = psutil.cpu_count(logical=False)
# Set OMP environment variable before importing onnxruntime or torch.
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = str(cpu_count)
import torch # noqa: E402
from transformers import AutoConfig, AutoTokenizer, LxmertConfig # noqa: E402
def run_onnxruntime(
use_gpu,
provider,
model_names,
model_class,
config_modifier,
precision,
num_threads,
batch_sizes,
sequence_lengths,
repeat_times,
input_counts,
optimizer_info,
validate_onnx,
cache_dir,
onnx_dir,
verbose,
overwrite,
disable_ort_io_binding,
use_raw_attention_mask,
model_fusion_statistics,
model_source,
enable_arm64_bfloat16_fastmath_mlas_gemm,
args,
):
import onnxruntime
results = []
if (
use_gpu
and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers())
and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers())
and ("DmlExecutionProvider" not in onnxruntime.get_available_providers())
):
logger.error(
"Please install onnxruntime-gpu or onnxruntime-directml package instead of onnxruntime, and use a machine with GPU for testing gpu performance."
)
return results
warm_up_repeat = 0
if provider == "tensorrt":
optimizer_info = OptimizerInfo.NOOPT
warm_up_repeat = 5
if "TensorrtExecutionProvider" not in onnxruntime.get_available_providers():
logger.error(
"Please install onnxruntime-gpu-tensorrt package, and use a machine with GPU for testing gpu performance."
)
return results
if optimizer_info == OptimizerInfo.NOOPT:
logger.warning(
f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied."
)
for model_name in model_names:
all_input_names = MODELS[model_name][0]
for num_inputs in input_counts:
if num_inputs > len(all_input_names):
break
input_names = all_input_names[:num_inputs]
args.model_type = MODELS[model_name][3]
fusion_options = FusionOptions.parse(args)
if "pt" in model_source:
with torch.no_grad():
(
onnx_model_file,
is_valid_onnx_model,
vocab_size,
max_sequence_length,
) = export_onnx_model_from_pt(
model_name,
MODELS[model_name][1],
MODELS[model_name][2],
MODELS[model_name][3],
model_class,
config_modifier,
cache_dir,
onnx_dir,
input_names,
use_gpu,
precision,
optimizer_info,
validate_onnx,
use_raw_attention_mask,
overwrite,
model_fusion_statistics,
fusion_options,
)
if "tf" in model_source:
(
onnx_model_file,
is_valid_onnx_model,
vocab_size,
max_sequence_length,
) = export_onnx_model_from_tf(
model_name,
MODELS[model_name][1],
MODELS[model_name][2],
MODELS[model_name][3],
model_class,
config_modifier,
cache_dir,
onnx_dir,
input_names,
use_gpu,
precision,
optimizer_info,
validate_onnx,
use_raw_attention_mask,
overwrite,
model_fusion_statistics,
fusion_options,
)
if not is_valid_onnx_model:
continue
ort_session = create_onnxruntime_session(
onnx_model_file,
use_gpu,
provider,
enable_all_optimization=True,
num_threads=num_threads,
verbose=verbose,
enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm,
)
if ort_session is None:
continue
ort_output_names = [node_arg.name for node_arg in ort_session.get_outputs()]
output_buffers = []
device = "cuda" if use_gpu else "cpu"
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
max_last_state_size = numpy.prod(
[
max(batch_sizes),
max(sequence_lengths),
max(vocab_size, config.hidden_size),
]
)
max_pooler_size = numpy.prod([max(batch_sizes), config.hidden_size])
for batch_size in batch_sizes:
if batch_size <= 0:
continue
for sequence_length in sequence_lengths:
if max_sequence_length is not None and sequence_length > max_sequence_length:
continue
input_value_type = numpy.int64 if "pt" in model_source else numpy.int32
ort_inputs = create_onnxruntime_input(
vocab_size,
batch_size,
sequence_length,
input_names,
config,
input_value_type,
)
result_template = {
"engine": "onnxruntime",
"version": onnxruntime.__version__,
"providers": provider,
"device": device,
"optimizer": optimizer_info,
"precision": precision,
"io_binding": not disable_ort_io_binding,
"model_name": model_name,
"inputs": num_inputs,
"threads": num_threads,
"batch_size": batch_size,
"sequence_length": sequence_length,
"custom_layer_num": config_modifier.get_layer_num(),
"datetime": str(datetime.now()),
}
if config.model_type in ["vit", "swin"]:
logger.info(
f"Run onnxruntime on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
)
else:
logger.info(f"Run onnxruntime on {model_name} with input shape {[batch_size, sequence_length]}")
if disable_ort_io_binding:
result = inference_ort(
ort_session,
ort_inputs,
result_template,
repeat_times,
batch_size,
warm_up_repeat,
)
else:
# Get output sizes from a dummy ort run
ort_outputs = ort_session.run(ort_output_names, ort_inputs)
output_buffer_max_sizes = [max_last_state_size]
for i in range(len(ort_outputs)):
if i == 2 and MODELS[model_name][3] == "gpt":
# past state output max size
output_buffer_max_sizes.append(max_pooler_size)
else:
output_buffer_max_sizes.append(max_last_state_size)
data_type = numpy.longlong if "pt" in model_source else numpy.intc
result = inference_ort_with_io_binding(
ort_session,
ort_inputs,
result_template,
repeat_times,
ort_output_names,
ort_outputs,
output_buffers,
output_buffer_max_sizes,
batch_size,
device,
data_type,
warm_up_repeat,
)
logger.info(result)
results.append(result)
return results
def run_pytorch(
use_gpu,
model_names,
model_class,
config_modifier,
precision,
num_threads,
batch_sizes,
sequence_lengths,
repeat_times,
torchscript,
torch2,
cache_dir,
verbose,
):
results = []
if use_gpu and not torch.cuda.is_available():
logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
return results
torch.set_grad_enabled(False)
for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
config_modifier.modify(config)
model = load_pretrained_model(
model_name,
config=config,
cache_dir=cache_dir,
custom_model_class=model_class,
)
if config.model_type in ["vit", "swin"]:
# These models don't use sequence lengths, so just pick the first sequence length so that the summary still works
sequence_lengths = [sequence_lengths[0]]
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
logger.debug(f"Model {model}")
logger.debug(f"Number of parameters {model.num_parameters()}")
if precision == Precision.FLOAT16:
model.half()
device = torch.device("cuda:0" if use_gpu else "cpu")
model.to(device)
if precision == Precision.INT8:
model = QuantizeHelper.quantize_torch_model(model)
for batch_size in batch_sizes:
if batch_size <= 0:
continue
for sequence_length in sequence_lengths:
if config.model_type in ["vit", "swin"]:
logger.info(
f"Run PyTorch on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
)
input_ids = torch.randn(
size=(batch_size, 3, config.image_size, config.image_size),
dtype=torch.float16 if precision == Precision.FLOAT16 else torch.float32,
device=device,
)
else:
if max_input_size is not None and sequence_length > max_input_size:
continue
logger.info(f"Run PyTorch on {model_name} with input shape {[batch_size, sequence_length]}")
input_ids = torch.randint(
low=0,
high=config.vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.long,
device=device,
)
try:
inference = (
torch.jit.trace(model, input_ids) if torchscript else torch.compile(model) if torch2 else model
)
inference(input_ids)
runtimes = timeit.repeat(lambda: inference(input_ids), repeat=repeat_times, number=1) # noqa: B023
result = {
"engine": "torchscript" if torchscript else "torch2" if torch2 else "torch",
"version": torch.__version__,
"providers": "NA",
"device": "cuda" if use_gpu else "cpu",
"optimizer": "",
"precision": precision,
"io_binding": "",
"model_name": model_name,
"inputs": 1,
"threads": num_threads,
"batch_size": batch_size,
"sequence_length": sequence_length,
"custom_layer_num": config_modifier.get_layer_num(),
"datetime": str(datetime.now()),
}
result.update(get_latency_result(runtimes, batch_size))
logger.info(result)
results.append(result)
except RuntimeError as e:
logger.exception(e)
torch.cuda.empty_cache()
return results
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
from functools import wraps
import tensorflow as tf
def run_func(func):
@wraps(func)
def run_in_eager_mode(*args, **kwargs):
return func(*args, **kwargs)
@wraps(func)
@tf.function(experimental_compile=use_xla)
def run_in_graph_mode(*args, **kwargs):
return func(*args, **kwargs)
if do_eager_mode is True:
assert (
use_xla is False
), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
return run_in_eager_mode
else:
return run_in_graph_mode
return run_func
def run_tensorflow(
use_gpu,
model_names,
model_class,
config_modifier,
precision,
num_threads,
batch_sizes,
sequence_lengths,
repeat_times,
cache_dir,
verbose,
):
results = []
import tensorflow as tf
tf.config.threading.set_intra_op_parallelism_threads(num_threads)
if not use_gpu:
tf.config.set_visible_devices([], "GPU")
if use_gpu and not tf.test.is_built_with_cuda():
logger.error("Please install Tensorflow-gpu, and use a machine with GPU for testing gpu performance.")
return results
if use_gpu: # Restrict TensorFlow to only use the first GPU
physical_devices = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices(physical_devices[0], "GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)
tf.distribute.OneDeviceStrategy(device="/gpu:0")
except RuntimeError as e:
logger.exception(e)
if precision == Precision.FLOAT16 or precision == Precision.INT8:
raise NotImplementedError("Mixed precision is currently not supported.")
for model_name in model_names:
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
config_modifier.modify(config)
model = load_pretrained_model(
model_name,
config=config,
cache_dir=cache_dir,
custom_model_class=model_class,
is_tf_model=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
for batch_size in batch_sizes:
if batch_size <= 0:
continue
for sequence_length in sequence_lengths:
if max_input_size is not None and sequence_length > max_input_size:
continue
logger.info(f"Run Tensorflow on {model_name} with input shape {[batch_size, sequence_length]}")
import random
rng = random.Random()
values = [rng.randint(0, config.vocab_size - 1) for i in range(batch_size * sequence_length)]
input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
try:
# Disable both for better inference perf
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def encoder_forward():
return model(input_ids, training=False) # noqa: B023
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def encoder_decoder_forward():
return model(input_ids, decoder_input_ids=input_ids, training=False) # noqa: B023
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
def lxmert_forward():
feats = tf.random.normal([1, 1, config.visual_feat_dim]) # noqa: B023
pos = tf.random.normal([1, 1, config.visual_pos_dim]) # noqa: B023
return model( # noqa: B023
input_ids, # noqa: B023
visual_feats=feats,
visual_pos=pos,
training=False,
)
inference = encoder_forward
if config.is_encoder_decoder:
inference = encoder_decoder_forward
elif isinstance(config, LxmertConfig):
inference = lxmert_forward
inference()
runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1) # noqa: B023
result = {
"engine": "tensorflow",
"version": tf.__version__,
"providers": "NA",
"device": "cuda" if use_gpu else "cpu",
"optimizer": "",
"precision": precision,
"io_binding": "",
"model_name": model_name,
"inputs": 1,
"threads": num_threads,
"batch_size": batch_size,
"sequence_length": sequence_length,
"custom_layer_num": config_modifier.get_layer_num(),
"datetime": str(datetime.now()),
}
result.update(get_latency_result(runtimes, batch_size))
logger.info(result)
results.append(result)
except RuntimeError as e:
logger.exception(e)
from numba import cuda
device = cuda.get_current_device()
device.reset()
return results
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--models",
required=False,
nargs="+",
type=str,
default=["bert-base-cased", "roberta-base", "gpt2"],
choices=list(MODELS.keys()),
help="Pre-trained models in the list: " + ", ".join(MODELS.keys()),
)
parser.add_argument(
"--model_source",
required=False,
nargs=1,
type=str,
default="pt",
choices=["pt", "tf"],
help="Export onnx from pt or tf",
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default=None,
choices=list(MODEL_CLASSES),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
)
parser.add_argument(
"-e",
"--engines",
required=False,
nargs="+",
type=str,
default=["onnxruntime"],
choices=["onnxruntime", "torch", "torch2", "torchscript", "tensorflow"],
help="Engines to benchmark",
)
parser.add_argument(
"-c",
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--onnx_dir",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Directory to store onnx models",
)
parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device")
parser.add_argument(
"--provider",
required=False,
type=str,
default=None,
help="Execution provider to use",
)
parser.add_argument(
"-p",
"--precision",
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Print more information")
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing models",
)
parser.add_argument(
"-o",
"--optimizer_info",
type=OptimizerInfo,
default=OptimizerInfo.BYSCRIPT,
choices=list(OptimizerInfo),
help="Optimizer info: Use optimizer.py to optimize onnx model as default. Can also choose from by_ort and no_opt",
)
parser.add_argument(
"-v",
"--validate_onnx",
required=False,
action="store_true",
help="Validate ONNX model",
)
parser.add_argument(
"-f",
"--fusion_csv",
required=False,
default=None,
help="CSV file for saving summary results of graph optimization.",
)
parser.add_argument(
"-d",
"--detail_csv",
required=False,
default=None,
help="CSV file for saving detail results.",
)
parser.add_argument(
"-r",
"--result_csv",
required=False,
default=None,
help="CSV file for saving summary results.",
)
parser.add_argument(
"-i",
"--input_counts",
required=False,
nargs="+",
default=[1],
type=int,
choices=[1, 2, 3],
help="Number of ONNX model inputs. Please use 1 for fair comparison with Torch or TorchScript.",
)
parser.add_argument(
"-t",
"--test_times",
required=False,
default=100,
type=int,
help="Number of repeat times to get average inference latency.",
)
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
parser.add_argument(
"-s",
"--sequence_lengths",
nargs="+",
type=int,
default=[4, 8, 16, 32, 64, 128, 256],
)
parser.add_argument(
"--disable_ort_io_binding",
required=False,
action="store_true",
help="Disable running ONNX Runtime with binded inputs and outputs. ",
)
parser.set_defaults(disable_ort_io_binding=False)
parser.add_argument(
"-n",
"--num_threads",
required=False,
nargs="+",
type=int,
default=[0],
help="Threads to use",
)
parser.add_argument(
"--force_num_layers",
required=False,
type=int,
default=None,
help="Manually set the model's layer number",
)
parser.add_argument(
"--enable_arm64_bfloat16_fastmath_mlas_gemm",
required=False,
action="store_true",
help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ",
)
parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False)
FusionOptions.add_arguments(parser)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
setup_logger(args.verbose)
if args.precision == Precision.FLOAT16 and not args.use_gpu:
logger.error("fp16 is for GPU only")
return
if args.precision == Precision.INT8 and args.use_gpu and args.provider != "migraphx":
logger.error("int8 is for CPU only")
return
if len(args.models) == 1 and MODELS[args.models[0]][3] in ["vit", "swin"]:
args.sequence_lengths = [""]
args.num_threads = sorted({cpu_count if x <= 0 else x for x in args.num_threads})
logger.info(f"Arguments: {args}")
if not os.path.exists(args.cache_dir):
try:
os.mkdir(args.cache_dir)
except OSError:
logger.error("Creation of the directory %s failed", args.cache_dir)
enable_torch = "torch" in args.engines
enable_torch2 = "torch2" in args.engines
enable_torchscript = "torchscript" in args.engines
enable_onnxruntime = "onnxruntime" in args.engines
enable_tensorflow = "tensorflow" in args.engines
if enable_torch2 and version.parse(torch.__version__) < version.parse("2.0.0"):
logger.error(f"PyTorch version must be >=2.0.0 and you are using {torch.__version__}")
return
config_modifier = ConfigModifier(args.force_num_layers)
results = []
for num_threads in args.num_threads:
torch.set_num_threads(num_threads)
logger.debug(torch.__config__.parallel_info())
if enable_torch or enable_torch2 or enable_torchscript:
if args.input_counts != [1]:
logger.warning("--input_counts is not implemented for torch or torchscript engine.")
if enable_torchscript:
results += run_pytorch(
args.use_gpu,
args.models,
args.model_class,
config_modifier,
args.precision,
num_threads,
args.batch_sizes,
args.sequence_lengths,
args.test_times,
True,
False,
args.cache_dir,
args.verbose,
)
if enable_torch:
results += run_pytorch(
args.use_gpu,
args.models,
args.model_class,
config_modifier,
args.precision,
num_threads,
args.batch_sizes,
args.sequence_lengths,
args.test_times,
False,
False,
args.cache_dir,
args.verbose,
)
if enable_torch2:
results += run_pytorch(
args.use_gpu,
args.models,
args.model_class,
config_modifier,
args.precision,
num_threads,
args.batch_sizes,
args.sequence_lengths,
args.test_times,
False,
True,
args.cache_dir,
args.verbose,
)
if enable_tensorflow:
results += run_tensorflow(
args.use_gpu,
args.models,
args.model_class,
config_modifier,
args.precision,
num_threads,
args.batch_sizes,
args.sequence_lengths,
args.test_times,
args.cache_dir,
args.verbose,
)
model_fusion_statistics = {}
if enable_onnxruntime:
try:
use_raw_attention_mask = not args.use_mask_index
results += run_onnxruntime(
args.use_gpu,
args.provider,
args.models,
args.model_class,
config_modifier,
args.precision,
num_threads,
args.batch_sizes,
args.sequence_lengths,
args.test_times,
args.input_counts,
args.optimizer_info,
args.validate_onnx,
args.cache_dir,
args.onnx_dir,
args.verbose,
args.overwrite,
args.disable_ort_io_binding,
use_raw_attention_mask,
model_fusion_statistics,
args.model_source,
args.enable_arm64_bfloat16_fastmath_mlas_gemm,
args,
)
except Exception:
logger.exception("Exception")
time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
if model_fusion_statistics:
csv_filename = args.fusion_csv or f"benchmark_fusion_{time_stamp}.csv"
output_fusion_statistics(model_fusion_statistics, csv_filename)
if len(results) == 0:
if args.batch_sizes != [0]:
logger.warning("No any result available.")
return
csv_filename = args.detail_csv or f"benchmark_detail_{time_stamp}.csv"
output_details(results, csv_filename)
csv_filename = args.result_csv or f"benchmark_summary_{time_stamp}.csv"
output_summary(results, csv_filename, args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,646 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import csv
import logging
import os
import random
import sys
import time
import timeit
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from time import sleep
from typing import Any, Dict, List, Optional
import coloredlogs
import numpy
import torch
import transformers
from packaging import version
import onnxruntime
logger = logging.getLogger(__name__)
class Precision(Enum):
FLOAT32 = "fp32"
FLOAT16 = "fp16"
INT8 = "int8"
INT4 = "int4"
def __str__(self):
return self.value
class OptimizerInfo(Enum):
# no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
# graph optimization level is not 0 (disable all).
NOOPT = "no_opt"
BYORT = "by_ort"
BYSCRIPT = "by_script"
def __str__(self):
return self.value
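# Illustrative sketch (assumption, not part of the original file): even with NOOPT,
# ONNX Runtime still applies its own graph optimizations unless the session's graph
# optimization level is lowered as well, for example:
def _example_raw_graph_session(onnx_model_path):
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    return onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=["CPUExecutionProvider"])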
class ConfigModifier:
def __init__(self, num_layers):
self.num_layers = num_layers
def modify(self, config):
if self.num_layers is None:
return
if hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = self.num_layers
logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
if hasattr(config, "encoder_layers"):
config.encoder_layers = self.num_layers
logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
if hasattr(config, "decoder_layers "):
config.decoder_layers = self.num_layers
logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")
def get_layer_num(self):
return self.num_layers
IO_BINDING_DATA_TYPE_MAP = {
"float32": numpy.float32,
# TODO: Add more.
}
def create_onnxruntime_session(
onnx_model_path,
use_gpu,
provider=None,
enable_all_optimization=True,
num_threads=-1,
enable_profiling=False,
verbose=False,
enable_mlas_gemm_fastmath_arm64_bfloat16=False,
provider_options={}, # map execution provider name to its option # noqa: B006
):
session = None
try:
sess_options = onnxruntime.SessionOptions()
if enable_all_optimization:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
else:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
if enable_profiling:
sess_options.enable_profiling = True
if num_threads > 0:
sess_options.intra_op_num_threads = num_threads
logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")
if verbose:
sess_options.log_severity_level = 0
else:
sess_options.log_severity_level = 4
logger.debug(f"Create session for onnx model: {onnx_model_path}")
if use_gpu:
if provider == "dml":
providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
elif provider == "rocm":
providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
elif provider == "migraphx":
providers = [
"MIGraphXExecutionProvider",
"ROCMExecutionProvider",
"CPUExecutionProvider",
]
elif provider == "cuda":
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
elif provider == "tensorrt":
providers = [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
if provider_options:
providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]
if enable_mlas_gemm_fastmath_arm64_bfloat16:
sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
except Exception:
logger.error("Exception", exc_info=True) # noqa: G201
return session
def setup_logger(verbose=True):
if verbose:
coloredlogs.install(
level="DEBUG",
fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
)
else:
coloredlogs.install(fmt="%(message)s")
logging.getLogger("transformers").setLevel(logging.WARNING)
def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
if cache_dir and not os.path.exists(cache_dir):
os.makedirs(cache_dir)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if use_gpu:
if provider == "dml":
assert (
"DmlExecutionProvider" in onnxruntime.get_available_providers()
), "Please install onnxruntime-directml package to test GPU inference."
else:
assert not set(onnxruntime.get_available_providers()).isdisjoint(
["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."
logger.info(f"PyTorch Version:{torch.__version__}")
logger.info(f"Transformers Version:{transformers.__version__}")
logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
# Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
assert version.parse(torch.__version__) >= version.parse("1.10.0")
assert version.parse(transformers.__version__) >= version.parse("4.12.0")
assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
def get_latency_result(latency_list, batch_size):
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
throughput = batch_size * (1000.0 / latency_ms)
return {
"test_times": len(latency_list),
"latency_variance": f"{latency_variance:.2f}",
"latency_90_percentile": f"{numpy.percentile(latency_list, 90) * 1000.0:.2f}",
"latency_95_percentile": f"{numpy.percentile(latency_list, 95) * 1000.0:.2f}",
"latency_99_percentile": f"{numpy.percentile(latency_list, 99) * 1000.0:.2f}",
"average_latency_ms": f"{latency_ms:.2f}",
"QPS": f"{throughput:.2f}",
}
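# Minimal sketch (hypothetical helper, mirroring inference_ort below) of how latency samples
# from timeit.repeat feed into get_latency_result:
def _example_latency_summary(run_once, batch_size=1, repeat_times=100):
    latency_list = timeit.repeat(run_once, number=1, repeat=repeat_times)
    return get_latency_result(latency_list, batch_size)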
def output_details(results, csv_filename):
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"engine",
"version",
"providers",
"device",
"precision",
"optimizer",
"io_binding",
"model_name",
"inputs",
"threads",
"batch_size",
"sequence_length",
"custom_layer_num",
"datetime",
"test_times",
"QPS",
"average_latency_ms",
"latency_variance",
"latency_90_percentile",
"latency_95_percentile",
"latency_99_percentile",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for result in results:
csv_writer.writerow(result)
logger.info(f"Detail results are saved to csv file: {csv_filename}")
def output_summary(results, csv_filename, args):
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
header_names = [
"model_name",
"inputs",
"custom_layer_num",
"engine",
"version",
"providers",
"device",
"precision",
"optimizer",
"io_binding",
"threads",
]
data_names = []
for batch_size in args.batch_sizes:
if args.sequence_lengths == [""]:
data_names.append(f"b{batch_size}")
else:
for sequence_length in args.sequence_lengths:
data_names.append(f"b{batch_size}_s{sequence_length}")
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
csv_writer.writeheader()
for model_name in args.models:
for input_count in [1, 2, 3]:
for engine_name in args.engines:
for io_binding in [True, False, ""]:
for threads in args.num_threads:
row = {}
for result in results:
if (
result["model_name"] == model_name
and result["inputs"] == input_count
and result["engine"] == engine_name
and result["io_binding"] == io_binding
and result["threads"] == threads
):
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
row.update({k: "" for k in data_names})
else:
for k in header_names:
assert row[k] == headers[k]
b = result["batch_size"]
s = result["sequence_length"]
if s:
row[f"b{b}_s{s}"] = result["average_latency_ms"]
else:
row[f"b{b}"] = result["average_latency_ms"]
if row:
csv_writer.writerow(row)
logger.info(f"Summary results are saved to csv file: {csv_filename}")
def output_fusion_statistics(model_fusion_statistics, csv_filename):
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"model_filename",
"datetime",
"transformers",
"torch",
*list(next(iter(model_fusion_statistics.values())).keys()),
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for key in model_fusion_statistics:
model_fusion_statistics[key]["datetime"] = str(datetime.now())
model_fusion_statistics[key]["transformers"] = transformers.__version__
model_fusion_statistics[key]["torch"] = torch.__version__
model_fusion_statistics[key]["model_filename"] = key
csv_writer.writerow(model_fusion_statistics[key])
logger.info(f"Fusion statistics is saved to csv file: {csv_filename}")
def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
result = {}
timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat) # Dry run
latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
result.update(result_template)
result.update({"io_binding": False})
result.update(get_latency_result(latency_list, batch_size))
return result
def inference_ort_with_io_binding(
ort_session,
ort_inputs,
result_template,
repeat_times,
ort_output_names,
ort_outputs,
output_buffers,
output_buffer_max_sizes,
batch_size,
device,
data_type=numpy.longlong,
warm_up_repeat=0,
):
result = {}
# Bind inputs and outputs to onnxruntime session
io_binding = ort_session.io_binding()
# Bind inputs to device
for name in ort_inputs:
np_input = torch.from_numpy(ort_inputs[name]).to(device)
input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type)
io_binding.bind_input(
name,
np_input.device.type,
0,
input_type,
np_input.shape,
np_input.data_ptr(),
)
# Bind outputs buffers with the sizes needed if not allocated already
if len(output_buffers) == 0:
allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
for i, ort_output_name in enumerate(ort_output_names):
io_binding.bind_output(
ort_output_name,
output_buffers[i].device.type,
0,
numpy.float32,
ort_outputs[i].shape,
output_buffers[i].data_ptr(),
)
timeit.repeat(
lambda: ort_session.run_with_iobinding(io_binding),
number=1,
repeat=warm_up_repeat,
) # Dry run
latency_list = timeit.repeat(
lambda: ort_session.run_with_iobinding(io_binding),
number=1,
repeat=repeat_times,
)
result.update(result_template)
result.update({"io_binding": True})
result.update(get_latency_result(latency_list, batch_size))
return result
def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device): # noqa: N802
# Allocate output tensors with the largest test size needed. So the allocated memory can be reused
# for each test run.
for i in output_buffer_max_sizes:
output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
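# Standalone sketch of the I/O binding pattern used above (illustrative only; the input and
# output names, shapes and float32 output dtype are assumptions, not part of the original file):
def _example_io_binding_run(ort_session, input_name, input_array, output_name, output_shape, device="cuda"):
    binding = ort_session.io_binding()
    input_tensor = torch.from_numpy(input_array).to(device)
    output_tensor = torch.empty(output_shape, dtype=torch.float32, device=device)
    input_type = IO_BINDING_DATA_TYPE_MAP.get(str(input_array.dtype), numpy.longlong)
    binding.bind_input(input_name, input_tensor.device.type, 0, input_type, input_tensor.shape, input_tensor.data_ptr())
    binding.bind_output(output_name, output_tensor.device.type, 0, numpy.float32, output_tensor.shape, output_tensor.data_ptr())
    ort_session.run_with_iobinding(binding)
    return output_tensor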
def set_random_seed(seed=123):
"""Set random seed manually to get deterministic results"""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
from py3nvml.py3nvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetName,
nvmlInit,
nvmlShutdown,
)
try:
nvmlInit()
result = []
device_count = nvmlDeviceGetCount()
if not isinstance(device_count, int):
return None
for i in range(device_count):
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
if isinstance(info, str):
return None
result.append(
{
"id": i,
"name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
"total": info.total,
"free": info.free,
"used": info.used,
}
)
nvmlShutdown()
return result
except NVMLError as error:
print("Error fetching GPU information using nvml: %s", error)
return None
class MemoryMonitor(ABC):
def __init__(self, keep_measuring=True):
self.keep_measuring = keep_measuring
def measure_cpu_usage(self):
import psutil
max_usage = 0
while True:
max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
sleep(0.005) # 5ms
if not self.keep_measuring:
break
return max_usage
@abstractmethod
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
raise NotImplementedError()
class CudaMemoryMonitor(MemoryMonitor):
def __init__(self, keep_measuring=True):
super().__init__(keep_measuring)
def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
from py3nvml.py3nvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetName,
nvmlInit,
nvmlShutdown,
)
max_gpu_usage = []
gpu_name = []
try:
nvmlInit()
device_count = nvmlDeviceGetCount()
if not isinstance(device_count, int):
logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
return None
max_gpu_usage = [0 for i in range(device_count)]
gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
while True:
for i in range(device_count):
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
if isinstance(info, str):
logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
return None
max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
sleep(0.005) # 5ms
if not self.keep_measuring:
break
nvmlShutdown()
return [
{
"device_id": i,
"name": gpu_name[i],
"max_used_MB": max_gpu_usage[i],
}
for i in range(device_count)
]
except NVMLError as error:
logger.error("Error fetching GPU information using nvml: %s", error)
return None
class RocmMemoryMonitor(MemoryMonitor):
def __init__(self, keep_measuring=True):
super().__init__(keep_measuring)
rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
if os.path.exists(rocm_smi_path):
if rocm_smi_path not in sys.path:
sys.path.append(rocm_smi_path)
try:
import rocm_smi
self.rocm_smi = rocm_smi
self.rocm_smi.initializeRsmi()
except ImportError:
self.rocm_smi = None
def get_used_memory(self, dev):
if self.rocm_smi is None:
return -1
return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
def measure_gpu_usage(self):
if self.rocm_smi is None:
return None
device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
max_gpu_usage = [0 for i in range(device_count)]
gpu_name = [f"GPU{i}" for i in range(device_count)]
while True:
for i in range(device_count):
max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
time.sleep(0.005) # 5ms
if not self.keep_measuring:
break
return [
{
"device_id": i,
"name": gpu_name[i],
"max_used_MB": max_gpu_usage[i],
}
for i in range(device_count)
]
def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
memory_monitor_type = None
if monitor_type == "rocm":
memory_monitor_type = RocmMemoryMonitor
else:
memory_monitor_type = CudaMemoryMonitor
monitor = memory_monitor_type(False)
if is_gpu:
if start_memory is not None:
memory_before_test = start_memory
else:
memory_before_test = monitor.measure_gpu_usage()
if memory_before_test is None:
return None
if func is None:
return memory_before_test
with ThreadPoolExecutor() as executor:
monitor = memory_monitor_type()
mem_thread = executor.submit(monitor.measure_gpu_usage)
try:
fn_thread = executor.submit(func)
_ = fn_thread.result()
finally:
monitor.keep_measuring = False
max_usage = mem_thread.result()
if max_usage is None:
return None
logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
# When there are multiple GPUs, we will check the one with maximum usage.
max_used = 0
for i, memory_before in enumerate(memory_before_test):
before = memory_before["max_used_MB"]
after = max_usage[i]["max_used_MB"]
used = after - before
max_used = max(max_used, used)
return max_used
return None
# CPU memory
if start_memory is not None:
memory_before_test = start_memory
else:
memory_before_test = monitor.measure_cpu_usage()
if func is None:
return memory_before_test
with ThreadPoolExecutor() as executor:
monitor = memory_monitor_type()
mem_thread = executor.submit(monitor.measure_cpu_usage)
try:
fn_thread = executor.submit(func)
_ = fn_thread.result()
finally:
monitor.keep_measuring = False
max_usage = mem_thread.result()
logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
return max_usage - memory_before_test
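# Hypothetical usage sketch for measure_memory: report peak CPU memory growth (in MB) while a
# callable runs; GPU measurement works the same way with is_gpu=True and monitor_type "cuda" or "rocm".
def _example_peak_cpu_memory(workload):
    return measure_memory(is_gpu=False, func=workload)
# e.g. _example_peak_cpu_memory(lambda: ort_session.run(None, ort_inputs))  # ort_session/ort_inputs are assumed names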
def get_ort_environment_variables():
# Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
env_names = [
"ORT_DISABLE_FUSED_ATTENTION",
"ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
"ORT_DISABLE_TRT_FLASH_ATTENTION",
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
"ORT_TRANSFORMER_OPTIONS",
"ORT_CUDA_GEMM_OPTIONS",
]
env = ""
for name in env_names:
value = os.getenv(name)
if value is None:
continue
if env:
env += ","
env += f"{name}={value}"
return env

View File

@@ -0,0 +1,634 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# This tool measures the inference performance of onnxruntime on a BERT-like model with inputs such as input_ids,
# token_type_ids (optional), and attention_mask (optional).
#
# If the model does not have exactly the three inputs above, you might need to specify the input names with
# --input_ids_name, --segment_ids_name and --input_mask_name.
# Example command to run test on batch_size 1 and 2 for a model on GPU:
# python bert_perf_test.py --model bert.onnx --batch_size 1 2 --sequence_length 128 --use_gpu --samples 1000 --test_times 1
import argparse
import csv
import json
import multiprocessing
import os
import random
import statistics
import timeit
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
import numpy as np
import psutil
import torch
from bert_test_data import generate_test_data, get_bert_inputs
@dataclass
class TestSetting:
batch_size: int
sequence_length: int
test_cases: int
test_times: int
use_gpu: bool
use_io_binding: bool
provider: str
intra_op_num_threads: int
seed: int
verbose: bool
log_severity: int
average_sequence_length: int
random_sequence_length: bool
@dataclass
class ModelSetting:
model_path: str
input_ids_name: str
segment_ids_name: str
input_mask_name: str
opt_level: int
input_tuning_results: Optional[str]
output_tuning_results: Optional[str]
mask_type: int
def create_session(
model_path,
use_gpu,
provider,
intra_op_num_threads,
graph_optimization_level=None,
log_severity=2,
tuning_results_path=None,
):
import onnxruntime
onnxruntime.set_default_logger_severity(log_severity)
if use_gpu and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers()):
print(
"Warning: Please install onnxruntime-gpu package instead of onnxruntime, and use a machine with GPU for testing gpu performance."
)
if use_gpu:
if provider == "dml":
execution_providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
elif provider == "rocm":
execution_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
elif provider == "migraphx":
execution_providers = [
"MIGraphXExecutionProvider",
"ROCMExecutionProvider",
"CPUExecutionProvider",
]
elif provider == "cuda":
execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
elif provider == "tensorrt":
execution_providers = [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
execution_providers = ["CPUExecutionProvider"]
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = log_severity
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
if graph_optimization_level is None:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
elif graph_optimization_level == 0:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
elif graph_optimization_level == 1:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
elif graph_optimization_level == 2:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
elif graph_optimization_level == 99:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
else:
sess_options.graph_optimization_level = graph_optimization_level
if intra_op_num_threads is not None:
sess_options.intra_op_num_threads = intra_op_num_threads
session = onnxruntime.InferenceSession(model_path, sess_options, providers=execution_providers)
if use_gpu:
if provider == "dml":
assert "DmlExecutionProvider" in session.get_providers()
elif provider == "rocm":
assert "ROCMExecutionProvider" in session.get_providers()
elif provider == "migraphx":
assert "MIGraphXExecutionProvider" in session.get_providers()
assert "ROCMExecutionProvider" in session.get_providers()
elif provider == "cuda":
assert "CUDAExecutionProvider" in session.get_providers()
elif provider == "tensorrt":
assert "TensorrtExecutionProvider" in session.get_providers()
assert "CUDAExecutionProvider" in session.get_providers()
else:
assert "CUDAExecutionProvider" in session.get_providers()
else:
assert "CPUExecutionProvider" in session.get_providers()
if tuning_results_path is not None:
with open(tuning_results_path) as f:
session.set_tuning_results(json.load(f))
return session
def numpy_type(torch_type):
type_map = {
torch.float32: np.float32,
torch.float16: np.float16,
torch.int32: np.int32,
torch.int64: np.longlong,
}
return type_map[torch_type]
def create_input_output_tensors(inputs, outputs, device):
input_tensors = {name: torch.from_numpy(array).to(device) for name, array in inputs.items()}
output_tensors = {name: torch.from_numpy(array).to(device) for name, array in outputs.items()}
return input_tensors, output_tensors
def create_io_binding(sess, input_tensors, output_tensors):
io_binding = sess.io_binding()
for name, tensor in input_tensors.items():
io_binding.bind_input(
name,
tensor.device.type,
0,
numpy_type(tensor.dtype),
tensor.shape,
tensor.data_ptr(),
)
for name, tensor in output_tensors.items():
io_binding.bind_output(
name,
tensor.device.type,
0,
numpy_type(tensor.dtype),
tensor.shape,
tensor.data_ptr(),
)
return io_binding
def onnxruntime_inference_with_io_binding(session, all_inputs, output_names, test_setting):
results = []
latency_list = []
device = "cuda" if test_setting.use_gpu else "cpu"
for _test_case_id, inputs in enumerate(all_inputs):
result = session.run(output_names, inputs)
results.append(result)
outputs = {}
for i in range(len(output_names)):
outputs[output_names[i]] = result[i]
input_tensors, output_tensors = create_input_output_tensors(inputs, outputs, device)
io_binding = create_io_binding(session, input_tensors, output_tensors)
# warm up once
session.run_with_iobinding(io_binding)
start_time = timeit.default_timer()
session.run_with_iobinding(io_binding)
latency = timeit.default_timer() - start_time
latency_list.append(latency)
return results, latency_list
def onnxruntime_inference(session, all_inputs, output_names):
if len(all_inputs) > 0:
# Use a random input as warm up.
session.run(output_names, random.choice(all_inputs))
results = []
latency_list = []
for _test_case_id, inputs in enumerate(all_inputs):
start_time = timeit.default_timer()
result = session.run(output_names, inputs)
latency = timeit.default_timer() - start_time
results.append(result)
latency_list.append(latency)
return results, latency_list
def to_string(model_path, session, test_setting):
sess_options = session.get_session_options()
option = f"model={os.path.basename(model_path)},"
option += f"graph_optimization_level={sess_options.graph_optimization_level},intra_op_num_threads={sess_options.intra_op_num_threads},".replace(
"GraphOptimizationLevel.ORT_", ""
)
option += f"batch_size={test_setting.batch_size},sequence_length={test_setting.sequence_length},"
option += f"test_cases={test_setting.test_cases},test_times={test_setting.test_times},"
option += f"use_gpu={test_setting.use_gpu},use_io_binding={test_setting.use_io_binding},"
option += f"average_sequence_length={test_setting.average_sequence_length},"
option += f"random_sequence_length={test_setting.random_sequence_length}"
return option
def run_one_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads):
session = create_session(
model_setting.model_path,
test_setting.use_gpu,
test_setting.provider,
intra_op_num_threads,
model_setting.opt_level,
log_severity=test_setting.log_severity,
tuning_results_path=model_setting.input_tuning_results,
)
output_names = [output.name for output in session.get_outputs()]
key = to_string(model_setting.model_path, session, test_setting)
if key in perf_results:
print("skip duplicated test:", key)
return
print("Running test:", key)
all_latency_list = []
if test_setting.use_io_binding:
for _i in range(test_setting.test_times):
results, latency_list = onnxruntime_inference_with_io_binding(
session, all_inputs, output_names, test_setting
)
all_latency_list.extend(latency_list)
else:
for _i in range(test_setting.test_times):
results, latency_list = onnxruntime_inference(session, all_inputs, output_names)
all_latency_list.extend(latency_list)
# latency in milliseconds
latency_ms = np.array(all_latency_list) * 1000
average_latency = statistics.mean(latency_ms)
latency_50 = np.percentile(latency_ms, 50)
latency_75 = np.percentile(latency_ms, 75)
latency_90 = np.percentile(latency_ms, 90)
latency_95 = np.percentile(latency_ms, 95)
latency_99 = np.percentile(latency_ms, 99)
throughput = test_setting.batch_size * (1000.0 / average_latency)
perf_results[key] = (
average_latency,
latency_50,
latency_75,
latency_90,
latency_95,
latency_99,
throughput,
)
print(f"Average latency = {average_latency:.2f} ms, Throughput = {throughput:.2f} QPS")
if model_setting.output_tuning_results:
output_path = os.path.abspath(model_setting.output_tuning_results)
if os.path.exists(output_path):
old_output_path = output_path
output_path = f"""{output_path.rsplit(".json", 1)[0]}.{datetime.now().timestamp()}.json"""
print("WARNING:", old_output_path, "exists, will write to", output_path, "instead.")
trs = session.get_tuning_results()
with open(output_path, "w") as f:
json.dump(trs, f)
print("Tuning results is saved to", output_path)
def launch_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads):
process = multiprocessing.Process(
target=run_one_test,
args=(
model_setting,
test_setting,
perf_results,
all_inputs,
intra_op_num_threads,
),
)
process.start()
process.join()
def run_perf_tests(model_setting, test_setting, perf_results, all_inputs):
if test_setting.intra_op_num_threads is not None:
launch_test(
model_setting,
test_setting,
perf_results,
all_inputs,
test_setting.intra_op_num_threads,
)
return
cpu_count = psutil.cpu_count(logical=False)
logical_cores = psutil.cpu_count(logical=True)
candidate_threads = list({logical_cores, cpu_count})
for i in range(1, min(16, logical_cores)):
if i not in candidate_threads:
candidate_threads.append(i)
candidate_threads.sort(reverse=True)
for intra_op_num_threads in candidate_threads:
launch_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads)
def run_performance(model_setting, test_setting, perf_results):
input_ids, segment_ids, input_mask = get_bert_inputs(
model_setting.model_path,
model_setting.input_ids_name,
model_setting.segment_ids_name,
model_setting.input_mask_name,
)
# Do not generate random mask for performance test.
print(
f"Generating {test_setting.test_cases} samples for batch_size={test_setting.batch_size} sequence_length={test_setting.sequence_length}"
)
all_inputs = generate_test_data(
test_setting.batch_size,
test_setting.sequence_length,
test_setting.test_cases,
test_setting.seed,
test_setting.verbose,
input_ids,
segment_ids,
input_mask,
test_setting.average_sequence_length,
test_setting.random_sequence_length,
mask_type=model_setting.mask_type,
)
run_perf_tests(model_setting, test_setting, perf_results, all_inputs)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, type=str, help="bert onnx model path")
parser.add_argument(
"-b",
"--batch_size",
required=True,
type=int,
nargs="+",
help="batch size of input. Allow one or multiple values in the range of [1, 128].",
)
parser.add_argument(
"-s",
"--sequence_length",
required=True,
type=int,
help="maximum sequence length of input",
)
parser.add_argument(
"--samples",
required=False,
type=int,
default=10,
help="number of samples to be generated",
)
parser.add_argument(
"-t",
"--test_times",
required=False,
type=int,
default=0,
help="number of times to run per sample. By default, the value is 1000 / samples",
)
parser.add_argument(
"--opt_level",
required=False,
type=int,
choices=[0, 1, 2, 99],
default=99,
help="onnxruntime optimization level: 0 - disable all, 1 - basic, 2 - extended, 99 - enable all.",
)
parser.add_argument(
"--seed",
required=False,
type=int,
default=3,
help="random seed. Use the same seed to make sure test data is same in multiple tests.",
)
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="print verbose information",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"--log_severity",
required=False,
type=int,
default=2,
choices=[0, 1, 2, 3, 4],
help="0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal",
)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU")
parser.set_defaults(use_gpu=False)
parser.add_argument("--use_io_binding", required=False, action="store_true", help="use io_binding")
parser.set_defaults(use_io_binding=False)
parser.add_argument(
"--provider",
required=False,
type=str,
default=None,
help="Execution provider to use",
)
parser.add_argument(
"-n",
"--intra_op_num_threads",
required=False,
type=int,
default=None,
help=">=0, set intra_op_num_threads",
)
parser.add_argument(
"--input_ids_name",
required=False,
type=str,
default=None,
help="input name for input ids",
)
parser.add_argument(
"--segment_ids_name",
required=False,
type=str,
default=None,
help="input name for segment ids",
)
parser.add_argument(
"--input_mask_name",
required=False,
type=str,
default=None,
help="input name for attention mask",
)
parser.add_argument(
"--input_tuning_results",
default=None,
type=str,
help="tuning results (json) to be loaded before benchmark",
)
parser.add_argument(
"--output_tuning_results",
default=None,
type=str,
help="tuning results (json) to be saved after benchmark",
)
parser.add_argument(
"-a",
"--average_sequence_length",
default=-1,
type=int,
help="average sequence length excluding padding",
)
parser.add_argument(
"-r",
"--random_sequence_length",
required=False,
action="store_true",
help="use uniform random instead of fixed sequence length",
)
parser.set_defaults(random_sequence_length=False)
parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index or sequence length, 2: raw 2D mask, 3: key len, cumulated lengths of query and key)",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
if args.test_times == 0:
args.test_times = max(1, int(1000 / args.samples))
if args.average_sequence_length <= 0:
args.average_sequence_length = args.sequence_length
manager = multiprocessing.Manager()
perf_results = manager.dict()
batch_size_set = set(args.batch_size)
if not (min(batch_size_set) >= 1 and max(batch_size_set) <= 128):
raise Exception("batch_size not in range [1, 128]")
model_setting = ModelSetting(
args.model,
args.input_ids_name,
args.segment_ids_name,
args.input_mask_name,
args.opt_level,
args.input_tuning_results,
args.output_tuning_results,
args.mask_type,
)
for batch_size in batch_size_set:
test_setting = TestSetting(
batch_size,
args.sequence_length,
args.samples,
args.test_times,
args.use_gpu,
args.use_io_binding,
args.provider,
args.intra_op_num_threads,
args.seed,
args.verbose,
args.log_severity,
args.average_sequence_length,
args.random_sequence_length,
)
print("test setting", test_setting)
run_performance(model_setting, test_setting, perf_results)
# Sort the results so that the first one has smallest latency.
sorted_results = sorted(perf_results.items(), reverse=False, key=lambda x: x[1])
summary_file = os.path.join(
Path(args.model).parent,
"perf_results_{}_B{}_S{}_{}.txt".format(
"GPU" if args.use_gpu else "CPU",
"-".join([str(x) for x in sorted(list(batch_size_set))]),
args.sequence_length,
datetime.now().strftime("%Y%m%d-%H%M%S"),
),
)
with open(summary_file, "w+", newline="") as tsv_file:
tsv_writer = csv.writer(tsv_file, delimiter="\t", lineterminator="\n")
headers = None
for key, perf_result in sorted_results:
params = key.split(",")
if headers is None:
headers = [
"Latency(ms)",
"Latency_P50",
"Latency_P75",
"Latency_P90",
"Latency_P95",
"Latency_P99",
"Throughput(QPS)",
]
headers.extend([x.split("=")[0] for x in params])
tsv_writer.writerow(headers)
values = [format(x, ".2f") for x in perf_result]
values.extend([x.split("=")[1] for x in params])
tsv_writer.writerow(values)
print("Test summary is saved to", summary_file)
if __name__ == "__main__":
    # Workaround for Anaconda Jupyter. See https://stackoverflow.com/questions/45720153/python-multiprocessing-error-attributeerror-module-main-has-no-attribute
__spec__ = None
main()
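# Example command for this performance test script (an illustrative sketch, not part of the original file).
# The file name bert_perf_test.py is inferred from the import in the comparison tool later in this commit;
# the model path is hypothetical, and the flag names for --model/--batch_size/--sequence_length/--samples
# are defined earlier in this file:
#   python bert_perf_test.py --model bert.onnx --batch_size 1 --sequence_length 128 --samples 100 --use_gpu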

View File

@ -0,0 +1,642 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# This is a tool to generate test data for a BERT model.
# The test data can be used by the onnxruntime_perf_test tool to evaluate inference latency.
import argparse
import os
import random
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np
from onnx import ModelProto, TensorProto, numpy_helper
from onnx_model import OnnxModel
def fake_input_ids_data(
input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int
) -> np.ndarray:
"""Create input tensor based on the graph input of input_ids
Args:
input_ids (TensorProto): graph input of the input_ids input tensor
batch_size (int): batch size
sequence_length (int): sequence length
dictionary_size (int): vocabulary size of dictionary
Returns:
np.ndarray: the input tensor created
"""
assert input_ids.type.tensor_type.elem_type in [
TensorProto.FLOAT,
TensorProto.INT32,
TensorProto.INT64,
]
data = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)
if input_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
data = np.float32(data)
elif input_ids.type.tensor_type.elem_type == TensorProto.INT64:
data = np.int64(data)
return data
def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray:
"""Create input tensor based on the graph input of segment_ids
Args:
segment_ids (TensorProto): graph input of the token_type_ids input tensor
batch_size (int): batch size
sequence_length (int): sequence length
Returns:
np.ndarray: the input tensor created
"""
assert segment_ids.type.tensor_type.elem_type in [
TensorProto.FLOAT,
TensorProto.INT32,
TensorProto.INT64,
]
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if segment_ids.type.tensor_type.elem_type == TensorProto.FLOAT:
data = np.float32(data)
elif segment_ids.type.tensor_type.elem_type == TensorProto.INT64:
data = np.int64(data)
return data
def get_random_length(max_sequence_length: int, average_sequence_length: int):
assert average_sequence_length >= 1 and average_sequence_length <= max_sequence_length
# For uniform distribution, we find proper lower and upper bounds so that the average is in the middle.
if 2 * average_sequence_length > max_sequence_length:
return random.randint(2 * average_sequence_length - max_sequence_length, max_sequence_length)
else:
return random.randint(1, 2 * average_sequence_length - 1)
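# Illustrative bounds (not part of the original tool) chosen by get_random_length above, assuming max_sequence_length=128:
#   average_sequence_length=32  -> random.randint(1, 63)    (uniform, mean ~32)
#   average_sequence_length=100 -> random.randint(72, 128)  (uniform, mean ~100)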
def fake_input_mask_data(
input_mask: TensorProto,
batch_size: int,
sequence_length: int,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int = 2,
) -> np.ndarray:
"""Create input tensor based on the graph input of segment_ids.
Args:
input_mask (TensorProto): graph input of the attention mask input tensor
batch_size (int): batch size
sequence_length (int): sequence length
average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether to use a uniformly random sequence length instead of a fixed one
mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
2: 2D attention mask. Shape is (batch_size, sequence_length).
3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).
Returns:
np.ndarray: the input tensor created
"""
assert input_mask.type.tensor_type.elem_type in [
TensorProto.FLOAT,
TensorProto.INT32,
TensorProto.INT64,
]
if mask_type == 1: # sequence length excluding paddings
data = np.ones((batch_size), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
data[i] = get_random_length(sequence_length, average_sequence_length)
else:
for i in range(batch_size):
data[i] = average_sequence_length
elif mask_type == 2: # 2D attention mask
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
actual_seq_len = get_random_length(sequence_length, average_sequence_length)
for j in range(actual_seq_len):
data[i, j] = 1
else:
temp = np.ones((batch_size, average_sequence_length), dtype=np.int32)
data[: temp.shape[0], : temp.shape[1]] = temp
else:
assert mask_type == 3
data = np.zeros((batch_size * 3 + 2), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
data[i] = get_random_length(sequence_length, average_sequence_length)
for i in range(batch_size + 1):
data[batch_size + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
data[2 * batch_size + 1 + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
else:
for i in range(batch_size):
data[i] = average_sequence_length
for i in range(batch_size + 1):
data[batch_size + i] = i * average_sequence_length
data[2 * batch_size + 1 + i] = i * average_sequence_length
if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT:
data = np.float32(data)
elif input_mask.type.tensor_type.elem_type == TensorProto.INT64:
data = np.int64(data)
return data
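# Shape reference for the three mask types handled above (B = batch_size, S = sequence_length):
#   mask_type=1 -> shape (B,)          each entry is the unpadded sequence length
#   mask_type=2 -> shape (B, S)        1 for real tokens, 0 for padding
#   mask_type=3 -> shape (3 * B + 2,)  per-batch key lengths followed by two cumulated-length arrays
# For example, B=2 with average_sequence_length=3 and random_sequence_length=False yields [3, 3, 0, 3, 6, 0, 3, 6].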
def output_test_data(directory: str, inputs: Dict[str, np.ndarray]):
"""Output input tensors of test data to a directory
Args:
directory (str): path of a directory
inputs (Dict[str, np.ndarray]): map from input name to value
"""
if not os.path.exists(directory):
try:
os.mkdir(directory)
except OSError:
print(f"Creation of the directory {directory} failed")
else:
print(f"Successfully created the directory {directory} ")
else:
print(f"Warning: directory {directory} existed. Files will be overwritten.")
for index, (name, data) in enumerate(inputs.items()):
tensor = numpy_helper.from_array(data, name)
with open(os.path.join(directory, f"input_{index}.pb"), "wb") as file:
file.write(tensor.SerializeToString())
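# Minimal usage sketch for output_test_data (the directory name and tensor values are hypothetical):
#   output_test_data("test_data_set_0", {"input_ids": np.zeros((1, 128), dtype=np.int64)})
# writes test_data_set_0/input_0.pb in the protobuf format consumed by onnxruntime_perf_test.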
def fake_test_data(
batch_size: int,
sequence_length: int,
test_cases: int,
dictionary_size: int,
verbose: bool,
random_seed: int,
input_ids: TensorProto,
segment_ids: TensorProto,
input_mask: TensorProto,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create given number of input data for testing
Args:
batch_size (int): batch size
sequence_length (int): sequence length
test_cases (int): number of test cases
dictionary_size (int): vocabulary size of dictionary for input_ids
verbose (bool): print more information or not
random_seed (int): random seed
input_ids (TensorProto): graph input of input IDs
segment_ids (TensorProto): graph input of token type IDs
input_mask (TensorProto): graph input of attention mask
average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether to use a uniformly random sequence length instead of a fixed one
mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
with input name as key and a tensor as value
"""
assert input_ids is not None
np.random.seed(random_seed)
random.seed(random_seed)
all_inputs = []
for _test_case in range(test_cases):
input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
inputs = {input_ids.name: input_1}
if segment_ids:
inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length)
if input_mask:
inputs[input_mask.name] = fake_input_mask_data(
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length, mask_type
)
if verbose and len(all_inputs) == 0:
print("Example inputs", inputs)
all_inputs.append(inputs)
return all_inputs
def generate_test_data(
batch_size: int,
sequence_length: int,
test_cases: int,
seed: int,
verbose: bool,
input_ids: TensorProto,
segment_ids: TensorProto,
input_mask: TensorProto,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create given number of input data for testing
Args:
batch_size (int): batch size
sequence_length (int): sequence length
test_cases (int): number of test cases
seed (int): random seed
verbose (bool): print more information or not
input_ids (TensorProto): graph input of input IDs
segment_ids (TensorProto): graph input of token type IDs
input_mask (TensorProto): graph input of attention mask
average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether to use a uniformly random sequence length instead of a fixed one
mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
with input name as key and a tensor as value
"""
dictionary_size = 10000
all_inputs = fake_test_data(
batch_size,
sequence_length,
test_cases,
dictionary_size,
verbose,
seed,
input_ids,
segment_ids,
input_mask,
average_sequence_length,
random_sequence_length,
mask_type,
)
if len(all_inputs) != test_cases:
print("Failed to create test data for test.")
return all_inputs
def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
if input_index >= len(embed_node.input):
return None
input = embed_node.input[input_index]
graph_input = onnx_model.find_graph_input(input)
if graph_input is None:
parent_node = onnx_model.get_parent(embed_node, input_index)
if parent_node is not None and parent_node.op_type == "Cast":
graph_input = onnx_model.find_graph_input(parent_node.input[0])
return graph_input
def find_bert_inputs(
onnx_model: OnnxModel,
input_ids_name: Optional[str] = None,
segment_ids_name: Optional[str] = None,
input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""Find graph inputs for BERT model.
First, we will deduce inputs from EmbedLayerNormalization node.
If not found, we will guess the meaning of graph inputs based on naming.
Args:
onnx_model (OnnxModel): onnx model object
input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
Raises:
        ValueError: Graph does not have an input named input_ids_name, segment_ids_name or input_mask_name
ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
and input_mask_name
Returns:
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
segment_ids and input_mask
"""
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
if input_ids_name is not None:
input_ids = onnx_model.find_graph_input(input_ids_name)
if input_ids is None:
raise ValueError(f"Graph does not have input named {input_ids_name}")
segment_ids = None
if segment_ids_name:
segment_ids = onnx_model.find_graph_input(segment_ids_name)
if segment_ids is None:
raise ValueError(f"Graph does not have input named {segment_ids_name}")
input_mask = None
if input_mask_name:
input_mask = onnx_model.find_graph_input(input_mask_name)
if input_mask is None:
raise ValueError(f"Graph does not have input named {input_mask_name}")
expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
if len(graph_inputs) != expected_inputs:
raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
return input_ids, segment_ids, input_mask
if len(graph_inputs) != 3:
raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
if len(embed_nodes) == 1:
embed_node = embed_nodes[0]
input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
if input_mask is None:
for input in graph_inputs:
input_name_lower = input.name.lower()
if "mask" in input_name_lower:
input_mask = input
if input_mask is None:
raise ValueError("Failed to find attention mask input")
return input_ids, segment_ids, input_mask
    # Try to guess the inputs based on their names.
input_ids = None
segment_ids = None
input_mask = None
for input in graph_inputs:
input_name_lower = input.name.lower()
if "mask" in input_name_lower: # matches input with name like "attention_mask" or "input_mask"
input_mask = input
elif (
"token" in input_name_lower or "segment" in input_name_lower
): # matches input with name like "segment_ids" or "token_type_ids"
segment_ids = input
else:
input_ids = input
if input_ids and segment_ids and input_mask:
return input_ids, segment_ids, input_mask
raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
def get_bert_inputs(
onnx_file: str,
input_ids_name: Optional[str] = None,
segment_ids_name: Optional[str] = None,
input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""Find graph inputs for BERT model.
First, we will deduce inputs from EmbedLayerNormalization node.
If not found, we will guess the meaning of graph inputs based on naming.
Args:
onnx_file (str): onnx model path
input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
Returns:
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
segment_ids and input_mask
"""
model = ModelProto()
with open(onnx_file, "rb") as file:
model.ParseFromString(file.read())
onnx_model = OnnxModel(model)
return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
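# Usage sketch (the model path is hypothetical); the returned graph inputs feed generate_test_data above:
#   input_ids, segment_ids, input_mask = get_bert_inputs("bert.onnx")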
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
parser.add_argument(
"--output_dir",
required=False,
type=str,
default=None,
help="output test data path. Default is current directory.",
)
parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
parser.add_argument(
"--sequence_length",
required=False,
type=int,
default=128,
help="maximum sequence length of input",
)
parser.add_argument(
"--input_ids_name",
required=False,
type=str,
default=None,
help="input name for input ids",
)
parser.add_argument(
"--segment_ids_name",
required=False,
type=str,
default=None,
help="input name for segment ids",
)
parser.add_argument(
"--input_mask_name",
required=False,
type=str,
default=None,
help="input name for attention mask",
)
parser.add_argument(
"--samples",
required=False,
type=int,
default=1,
help="number of test cases to be generated",
)
parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="print verbose information",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"--only_input_tensors",
required=False,
action="store_true",
help="only save input tensors and no output tensors",
)
parser.set_defaults(only_input_tensors=False)
parser.add_argument(
"-a",
"--average_sequence_length",
default=-1,
type=int,
help="average sequence length excluding padding",
)
parser.add_argument(
"-r",
"--random_sequence_length",
required=False,
action="store_true",
help="use uniform random instead of fixed sequence length",
)
parser.set_defaults(random_sequence_length=False)
parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key)",
)
args = parser.parse_args()
return args
def create_and_save_test_data(
model: str,
output_dir: str,
batch_size: int,
sequence_length: int,
test_cases: int,
seed: int,
verbose: bool,
input_ids_name: Optional[str],
segment_ids_name: Optional[str],
input_mask_name: Optional[str],
only_input_tensors: bool,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create test data for a model, and save test data to a directory.
Args:
model (str): path of ONNX bert model
output_dir (str): output directory
batch_size (int): batch size
sequence_length (int): sequence length
test_cases (int): number of test cases
seed (int): random seed
verbose (bool): whether print more information
input_ids_name (str): graph input name of input_ids
segment_ids_name (str): graph input name of segment_ids
input_mask_name (str): graph input name of input_mask
        only_input_tensors (bool): only save input tensors (no output tensors)
average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether to use a uniformly random sequence length instead of a fixed one
mask_type(int): mask type
"""
input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)
all_inputs = generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
segment_ids,
input_mask,
average_sequence_length,
random_sequence_length,
mask_type,
)
for i, inputs in enumerate(all_inputs):
directory = os.path.join(output_dir, "test_data_set_" + str(i))
output_test_data(directory, inputs)
if only_input_tensors:
return
import onnxruntime
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else ["CPUExecutionProvider"]
)
session = onnxruntime.InferenceSession(model, providers=providers)
output_names = [output.name for output in session.get_outputs()]
for i, inputs in enumerate(all_inputs):
directory = os.path.join(output_dir, "test_data_set_" + str(i))
result = session.run(output_names, inputs)
for i, output_name in enumerate(output_names): # noqa: PLW2901
tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name)
with open(os.path.join(directory, f"output_{i}.pb"), "wb") as file:
file.write(tensor_result.SerializeToString())
def main():
args = parse_arguments()
if args.average_sequence_length <= 0:
args.average_sequence_length = args.sequence_length
output_dir = args.output_dir
if output_dir is None:
# Default output directory is a sub-directory under the directory of model.
p = Path(args.model)
output_dir = os.path.join(p.parent, f"batch_{args.batch_size}_seq_{args.sequence_length}")
    # Create the output directory if it does not exist. Existing test data files will be overwritten.
    path = Path(output_dir)
    path.mkdir(parents=True, exist_ok=True)
create_and_save_test_data(
args.model,
output_dir,
args.batch_size,
args.sequence_length,
args.samples,
args.seed,
args.verbose,
args.input_ids_name,
args.segment_ids_name,
args.input_mask_name,
args.only_input_tensors,
args.average_sequence_length,
args.random_sequence_length,
args.mask_type,
)
print("Test data is saved to directory:", output_dir)
if __name__ == "__main__":
main()
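# Example command (flags as defined in parse_arguments above; the file name bert_test_data.py is taken from
# the import in the comparison tool below, and the model path is hypothetical):
#   python bert_test_data.py --model bert.onnx --batch_size 1 --sequence_length 128 --samples 10 -a 64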

View File

@ -0,0 +1,246 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# This is a tool to compare the inference results of the original model and the optimized model.
import argparse
import os
import statistics
from pathlib import Path
import numpy as np
import psutil
from bert_perf_test import create_session, onnxruntime_inference
from bert_test_data import generate_test_data, get_bert_inputs, output_test_data
def run_model(model_path, all_inputs, use_gpu, disable_optimization):
import onnxruntime
graph_optimization_level = None
if disable_optimization:
graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
intra_op_num_threads = psutil.cpu_count(logical=False)
session = create_session(
model_path, use_gpu, "cuda" if use_gpu else "cpu", intra_op_num_threads, graph_optimization_level
)
output_names = [output.name for output in session.get_outputs()]
results, latency_list = onnxruntime_inference(session, all_inputs, output_names)
return results, latency_list, output_names
def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
# Validate the output of baseline and treatment, to make sure the results are similar.
diff_count = 0
max_abs_diff = 0
for test_case_id, results in enumerate(baseline_results):
case_passed = True
for i in range(len(results)):
treatment_output = treatment_results[test_case_id][i]
abs_diff = np.amax(np.abs(treatment_output - results[i]))
if verbose and abs_diff > atol:
print("abs_diff", abs_diff)
print("treatment", treatment_output)
print("baseline", results[i])
max_abs_diff = max(max_abs_diff, abs_diff)
if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
if case_passed:
case_passed = False
diff_count += 1
if verbose:
print(f"case {test_case_id} output {i}")
print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}")
print(f"abs_diff={abs_diff}")
if diff_count == 0:
print(f"100% passed for {len(baseline_results)} random inputs given thresholds (rtol={rtol}, atol={atol}).")
else:
print(
f"WARNING: {diff_count} out of {len(baseline_results)} results NOT passed for thresholds (rtol={rtol}, atol={atol})."
)
print(f"maximum absolute difference={max_abs_diff}")
    return max_abs_diff, diff_count == 0
def run_test(
baseline_model,
optimized_model,
output_dir,
batch_size,
sequence_length,
use_gpu,
test_cases,
seed,
verbose,
rtol,
atol,
input_ids_name,
segment_ids_name,
input_mask_name,
mask_type,
):
# Try deduce input names from optimized model.
input_ids, segment_ids, input_mask = get_bert_inputs(
optimized_model, input_ids_name, segment_ids_name, input_mask_name
)
# Use random mask length for accuracy test. It might introduce slight inflation in latency reported in this script.
average_sequence_length = int(sequence_length / 2) if sequence_length >= 2 else sequence_length
all_inputs = generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
segment_ids,
input_mask,
average_sequence_length,
True, # random sequence length
mask_type,
)
baseline_results, baseline_latency, output_names = run_model(
baseline_model, all_inputs, use_gpu, disable_optimization=True
)
if verbose:
print(f"baseline average latency (all optimizations disabled): {statistics.mean(baseline_latency) * 1000} ms")
    if output_dir is not None:
        for i, inputs in enumerate(all_inputs):
            directory = os.path.join(output_dir, "test_data_set_" + str(i))
            output_test_data(directory, inputs)
treatment_results, treatment_latency, treatment_output_names = run_model(
optimized_model, all_inputs, use_gpu, disable_optimization=False
)
if verbose:
print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms")
# Validate the output of baseline and treatment, to make sure the results are similar.
return compare(baseline_results, treatment_results, verbose, rtol, atol)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--baseline_model", required=True, type=str, help="baseline onnx model path.")
parser.add_argument(
"--optimized_model",
required=True,
type=str,
default=None,
help="path of the optimized model. It shall have same inputs as the baseline model.",
)
parser.add_argument(
"--output_dir",
required=False,
type=str,
default=None,
help="output test data path. If not specified, test data will not be saved.",
)
parser.add_argument("--batch_size", required=True, type=int, help="batch size of input")
parser.add_argument(
"--sequence_length",
required=True,
type=int,
help="maximum sequence length of input",
)
parser.add_argument("--rtol", required=False, type=float, default=1e-3, help="relative tolerance")
parser.add_argument("--atol", required=False, type=float, default=1e-4, help="absolute tolerance")
parser.add_argument(
"--samples",
required=False,
type=int,
default=100,
help="number of test cases to be generated",
)
parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="print verbose information",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"--input_ids",
required=False,
type=str,
default=None,
help="input name for input ids",
)
parser.add_argument(
"--segment_ids",
required=False,
type=str,
default=None,
help="input name for segment ids",
)
parser.add_argument(
"--input_mask",
required=False,
type=str,
default=None,
help="input name for attention mask",
)
parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index or sequence length, 2: raw 2D mask, 3: key len, cumulated lengths of query and key)",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
if args.output_dir is not None:
# create the output directory if not existed
path = Path(args.output_dir)
path.mkdir(parents=True, exist_ok=True)
run_test(
args.baseline_model,
args.optimized_model,
args.output_dir,
args.batch_size,
args.sequence_length,
args.use_gpu,
args.samples,
args.seed,
args.verbose,
args.rtol,
args.atol,
args.input_ids,
args.segment_ids,
args.input_mask,
args.mask_type,
)
if __name__ == "__main__":
main()
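# Example command (a sketch; the script file name and model paths are assumptions, flags are from parse_arguments above):
#   python compare_bert_results.py --baseline_model bert.onnx --optimized_model bert_opt.onnx \
#       --batch_size 1 --sequence_length 128 --samples 100 --rtol 1e-3 --atol 1e-4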

View File

@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
class Operators:
ATTENTION = "Attention"
LAYERNORM = "LayerNormalization"
MULTI_HEAD_ATTENTION = "MultiHeadAttention"
PACKEDATTENTION = "PackedAttention"
PACKED_MULTI_HEAD_ATTENTION = "PackedMultiHeadAttention"
REMOVEPADDING = "RemovePadding"
RESTOREPADDING = "RestorePadding"
SKIPLAYERNORM = "SkipLayerNormalization"
class AttentionInputIDs:
INPUT = 0
WEIGHTS = 1
BIAS = 2
MASK_INDEX = 3
PAST = 4
RELATIVE_POSITION_BIAS = 5
PAST_SEQUENCE_LENGTH = 6
class AttentionOutputIDs:
OUTPUT = 0
PRESENT = 1
class MultiHeadAttentionInputIDs:
QUERY = 0
KEY = 1
VALUE = 2
BIAS = 3
KEY_PADDING_MASK = 4
RELATIVE_POSITION_BIAS = 5
PAST_KEY = 6
PAST_VALUE = 7
class MultiHeadAttentionOutputIDs:
OUTPUT = 0
PRESENT_KEY = 1
PRESENT_VALUE = 2

File diff suppressed because it is too large

View File

@ -0,0 +1,205 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import glob
import os
import requests
TFMODELS = {
"bert-base-uncased": (
"bert",
"BertConfig",
"",
"https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip",
),
"bert-base-cased": (
"bert",
"BertConfig",
"",
"https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip",
),
"bert-large-uncased": (
"bert",
"BertConfig",
"",
"https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip",
),
"albert-base": (
"albert",
"AlbertConfig",
"",
"https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz",
),
"albert-large": (
"albert",
"AlbertConfig",
"",
"https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz",
),
"gpt-2-117M": (
"gpt2",
"GPT2Config",
"GPT2Model",
"https://storage.googleapis.com/gpt-2/models/117M",
),
"gpt-2-124M": (
"gpt2",
"GPT2Config",
"GPT2Model",
"https://storage.googleapis.com/gpt-2/models/124M",
),
}
def download_compressed_file(tf_ckpt_url, ckpt_dir):
r = requests.get(tf_ckpt_url)
compressed_file_name = tf_ckpt_url.split("/")[-1]
compressed_file_dir = os.path.join(ckpt_dir, compressed_file_name)
with open(compressed_file_dir, "wb") as f:
f.write(r.content)
return compressed_file_dir
def get_ckpt_prefix_path(ckpt_dir):
# get prefix
sub_folder_dir = None
for o in os.listdir(ckpt_dir):
sub_folder_dir = os.path.join(ckpt_dir, o)
break
if os.path.isfile(sub_folder_dir):
sub_folder_dir = ckpt_dir
unique_file_name = str(glob.glob(sub_folder_dir + "/*data-00000-of-00001"))
prefix = (unique_file_name.rpartition(".")[0]).split("/")[-1]
return os.path.join(sub_folder_dir, prefix)
def download_tf_checkpoint(model_name, tf_models_dir="tf_models"):
import pathlib
base_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), tf_models_dir)
ckpt_dir = os.path.join(base_dir, model_name)
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
tf_ckpt_url = TFMODELS[model_name][3]
import re
if re.search(".zip$", tf_ckpt_url) is not None:
zip_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
# unzip file
import zipfile
with zipfile.ZipFile(zip_dir, "r") as zip_ref:
zip_ref.extractall(ckpt_dir)
os.remove(zip_dir)
return get_ckpt_prefix_path(ckpt_dir)
elif re.search(".tar.gz$", tf_ckpt_url) is not None:
tar_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
# untar file
import tarfile
with tarfile.open(tar_dir, "r") as tar_ref:
tar_ref.extractall(ckpt_dir)
os.remove(tar_dir)
return get_ckpt_prefix_path(ckpt_dir)
else:
for filename in [
"checkpoint",
"model.ckpt.data-00000-of-00001",
"model.ckpt.index",
"model.ckpt.meta",
]:
r = requests.get(tf_ckpt_url + "/" + filename)
with open(os.path.join(ckpt_dir, filename), "wb") as f:
f.write(r.content)
return get_ckpt_prefix_path(ckpt_dir)
def init_pytorch_model(model_name, tf_checkpoint_path):
config_name = TFMODELS[model_name][1]
config_module = __import__("transformers", fromlist=[config_name])
model_config = getattr(config_module, config_name)
parent_path = tf_checkpoint_path.rpartition("/")[0]
config_path = glob.glob(parent_path + "/*config.json")
config = model_config() if len(config_path) == 0 else model_config.from_json_file(str(config_path[0]))
if not TFMODELS[model_name][2]:
from transformers import AutoModelForPreTraining
init_model = AutoModelForPreTraining.from_config(config)
else:
        model_category_name = TFMODELS[model_name][2]
        module = __import__("transformers", fromlist=[model_category_name])
        model_category = getattr(module, model_category_name)
        init_model = model_category(config)
return config, init_model
def convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2):
load_tf_weight_func_name = "load_tf_weights_in_" + TFMODELS[model_name][0]
module = __import__("transformers", fromlist=[load_tf_weight_func_name])
if is_tf2 is False:
load_tf_weight_func = getattr(module, load_tf_weight_func_name)
else:
if TFMODELS[model_name][0] != "bert":
raise NotImplementedError("Only support tf2 ckeckpoint for Bert model")
from transformers import convert_bert_original_tf2_checkpoint_to_pytorch
load_tf_weight_func = convert_bert_original_tf2_checkpoint_to_pytorch.load_tf2_weights_in_bert
    # The transformers team may unify the argument order of these functions in the future.
model = (
load_tf_weight_func(init_model, config, tf_checkpoint_path)
if is_tf2 is False
else load_tf_weight_func(init_model, tf_checkpoint_path, config)
)
model.eval()
return model
def tf2pt_pipeline(model_name, is_tf2=False):
if model_name not in TFMODELS:
raise NotImplementedError(model_name + " not implemented")
tf_checkpoint_path = download_tf_checkpoint(model_name)
config, init_model = init_pytorch_model(model_name, tf_checkpoint_path)
model = convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2)
    # The returned config and model can then be used by the benchmark.
return config, model
def tf2pt_pipeline_test():
    # For testing on Linux only
import logging
import torch
logger = logging.getLogger("")
for model_name in TFMODELS:
config, model = tf2pt_pipeline(model_name)
        assert config.model_type == TFMODELS[model_name][0]
input = torch.randint(low=0, high=config.vocab_size - 1, size=(4, 128), dtype=torch.long)
try:
model(input)
except RuntimeError as e:
logger.exception(e)
if __name__ == "__main__":
tf2pt_pipeline_test()
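# Usage sketch: download a TF checkpoint and convert it to a PyTorch model for benchmarking, e.g.
#   config, model = tf2pt_pipeline("bert-base-uncased")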

View File

@ -0,0 +1,387 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import logging
import os
from typing import List, Union
import coloredlogs
from constants import (
AttentionInputIDs,
AttentionOutputIDs,
MultiHeadAttentionInputIDs,
MultiHeadAttentionOutputIDs,
Operators,
)
from onnx import helper, load_model
from onnx_model import NodeProto, OnnxModel
from shape_infer_helper import SymbolicShapeInferenceHelper
logger = logging.getLogger(__name__)
class PackingAttentionBase:
def __init__(self, model: OnnxModel, attention_op_type: str):
self.model: OnnxModel = model
self.nodes_to_remove: List = []
self.nodes_to_add: List = []
self.prune_graph: bool = False
self.node_name_to_graph_name: dict = {}
self.this_graph_name: str = self.model.model.graph.name
self.attention_op_type = attention_op_type
self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
def _try_getting_attention_mask(self) -> Union[str, None]:
mask_index = (
AttentionInputIDs.MASK_INDEX
if self.attention_op_type == Operators.ATTENTION
else MultiHeadAttentionInputIDs.KEY_PADDING_MASK
)
first_attention_node = self._try_getting_first_attention()
# check if attention has mask
if not first_attention_node or len(first_attention_node.input) <= mask_index:
return None
attention_mask = first_attention_node.input[mask_index]
# check if all attention nodes have same mask
for node in self.attention_nodes:
if len(node.input) <= mask_index or node.input[mask_index] != attention_mask:
return None
return attention_mask
def _try_getting_first_attention(self) -> Union[NodeProto, None]:
if len(self.attention_nodes) <= 0:
return None
return self.attention_nodes[0]
def _try_getting_last_layernorm(self) -> Union[NodeProto, None]:
last_layernorm_node = None
for node in self.model.nodes():
if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM:
last_layernorm_node = node
return last_layernorm_node
def _are_attentions_supported(self) -> bool:
raise NotImplementedError()
def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
new_node = helper.make_node(
Operators.REMOVEPADDING,
inputs=inputs,
outputs=outputs,
name=self.model.create_node_name(Operators.REMOVEPADDING),
)
new_node.domain = "com.microsoft"
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
new_node = helper.make_node(
Operators.RESTOREPADDING,
inputs=inputs,
outputs=outputs,
name=self.model.create_node_name(Operators.RESTOREPADDING),
)
new_node.domain = "com.microsoft"
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
raise NotImplementedError()
def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
if self.attention_op_type == Operators.ATTENTION:
return first_attention_node.input[AttentionInputIDs.INPUT]
return None
def convert(self, use_symbolic_shape_infer: bool = True) -> None:
logger.debug("start converting to packing model...")
if not self._are_attentions_supported():
return
attention_mask = self._try_getting_attention_mask()
if not attention_mask:
return
first_attention_node = self._try_getting_first_attention()
last_layernorm_node = self._try_getting_last_layernorm()
if not last_layernorm_node:
return
# insert RemovePadding
input_to_remove_padding = self._get_input_to_remove_padding(first_attention_node)
if not input_to_remove_padding:
return
output_without_padding = input_to_remove_padding + "_no_padding"
token_offset = input_to_remove_padding + "_token_offset"
cumulated_seq_len = input_to_remove_padding + "_cumulated_seq_len"
max_seq_len = input_to_remove_padding + "_max_seq_len"
self._insert_removepadding_node(
[input_to_remove_padding, attention_mask],
[output_without_padding, token_offset, cumulated_seq_len, max_seq_len],
)
self.model.replace_input_of_all_nodes(input_to_remove_padding, output_without_padding)
logger.debug("inserted RemovePadding before Attention")
# insert RestorePadding
restorepadding_input = last_layernorm_node.output[0] + "_restore_input"
self._insert_restorepadding_node([restorepadding_input, token_offset], [last_layernorm_node.output[0]])
self.model.replace_output_of_all_nodes(last_layernorm_node.output[0], restorepadding_input)
logger.debug(f"inserted RestorePadding after last {last_layernorm_node.op_type} layer")
# insert PackedAttention
self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")
self.model.remove_nodes(self.nodes_to_remove)
self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
if self.prune_graph:
self.model.prune_graph()
elif self.nodes_to_remove or self.nodes_to_add:
self.model.update_graph()
self.model.clean_shape_infer()
if use_symbolic_shape_infer:
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
# are not recognized by onnx shape inference.
shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
if inferred_model:
self.model.model = inferred_model
class PackingAttention(PackingAttentionBase):
def __init__(self, model: OnnxModel):
super().__init__(model, Operators.ATTENTION)
def _are_attentions_supported(self) -> bool:
for node in self.attention_nodes:
if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
return False
if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
return False
unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
if unidirection_attr is not None and unidirection_attr != 0:
return False
if len(node.input) > AttentionInputIDs.PAST and not node.input[AttentionInputIDs.PAST]:
return False
if (
len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
and not node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
):
return False
return True
def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
for attention in self.attention_nodes:
relative_pos_bias = (
attention.input[AttentionInputIDs.RELATIVE_POSITION_BIAS]
if len(attention.input) > AttentionInputIDs.RELATIVE_POSITION_BIAS
else ""
)
packed_attention = helper.make_node(
Operators.PACKEDATTENTION,
inputs=[
attention.input[AttentionInputIDs.INPUT],
attention.input[AttentionInputIDs.WEIGHTS],
attention.input[AttentionInputIDs.BIAS],
token_offset,
cumulative_sequence_length,
relative_pos_bias,
],
outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
name=self.model.create_node_name(Operators.PACKEDATTENTION),
)
attributes = []
for attr in attention.attribute:
if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]:
attributes.append(attr)
packed_attention.attribute.extend(attributes)
packed_attention.domain = "com.microsoft"
self.nodes_to_add.append(packed_attention)
self.nodes_to_remove.append(attention)
self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name
logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
class PackingMultiHeadAttention(PackingAttentionBase):
def __init__(self, model: OnnxModel):
super().__init__(model, Operators.MULTI_HEAD_ATTENTION)
def _check_empty_input(self, node, index: int, name: str):
"""Check a node does not have given input."""
if len(node.input) > index:
if len(node.input[index]) > 0:
logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
return False
return True
def _check_empty_output(self, node, index: int, name: str):
"""Check a node does not have given input."""
if len(node.output) > index:
if len(node.output[index]) > 0:
logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
return False
return True
def _are_attentions_supported(self) -> bool:
for node in self.attention_nodes:
for attr in node.attribute:
if attr.name not in ["num_heads", "mask_filter_value", "scale"]:
logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
return False
if node.input[MultiHeadAttentionInputIDs.KEY] and not node.input[MultiHeadAttentionInputIDs.VALUE]:
logger.error("packed kv format is not supported in PackedMultiHeadAttention")
return False
if not (
self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
                and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
):
return False
return True
def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
gated_relative_pos_bias_count = 0
for mha in self.attention_nodes:
relative_pos_bias = (
mha.input[MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS]
if len(mha.input) > MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS
else ""
)
packed_mha = helper.make_node(
Operators.PACKED_MULTI_HEAD_ATTENTION,
inputs=[
mha.input[MultiHeadAttentionInputIDs.QUERY],
mha.input[MultiHeadAttentionInputIDs.KEY],
mha.input[MultiHeadAttentionInputIDs.VALUE],
mha.input[MultiHeadAttentionInputIDs.BIAS],
token_offset,
cumulative_sequence_length,
relative_pos_bias,
],
outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
)
attributes = []
for attr in mha.attribute:
if attr.name in ["num_heads", "mask_filter_value", "scale"]:
attributes.append(attr)
packed_mha.attribute.extend(attributes)
packed_mha.domain = "com.microsoft"
self.nodes_to_add.append(packed_mha)
self.nodes_to_remove.append(mha)
self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name
# Append token_offset input to GatedRelativePositionBias
if relative_pos_bias:
rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS)
if (
rel_pos_bias_node
and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
and len(rel_pos_bias_node.input) == 6
):
rel_pos_bias_node.input.append(token_offset)
gated_relative_pos_bias_count += 1
logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)
def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
# When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
matmul = self.model.get_parent(first_attention_node, 0)
if matmul and matmul.op_type == "MatMul":
return matmul.input[0]
return None
class PackingMode:
def __init__(self, model: OnnxModel):
self.model = model
def convert(self, use_symbolic_shape_infer: bool = True) -> None:
if self.model.get_nodes_by_op_type(Operators.ATTENTION):
if self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
return None
packing = PackingAttention(self.model)
return packing.convert(use_symbolic_shape_infer)
elif self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
packing = PackingMultiHeadAttention(self.model)
return packing.convert(use_symbolic_shape_infer)
else:
logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
return None
def _parse_arguments():
parser = argparse.ArgumentParser(
description="Convert to packing mode tool for ONNX Runtime. It converts BERT like model to use packing mode."
)
parser.add_argument("--input", required=True, type=str, help="input onnx model path")
parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
parser.set_defaults(verbose=False)
parser.add_argument(
"--use_external_data_format",
required=False,
action="store_true",
help="use external data format to store large model (>2GB)",
)
parser.set_defaults(use_external_data_format=False)
args = parser.parse_args()
return args
def _setup_logger(verbose):
if verbose:
coloredlogs.install(
level="DEBUG",
fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
)
else:
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
def main():
args = _parse_arguments()
_setup_logger(args.verbose)
logger.debug(f"arguments:{args}")
if os.path.realpath(args.input) == os.path.realpath(args.output):
logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
model = load_model(args.input)
packing_mode = PackingMode(OnnxModel(model))
packing_mode.convert()
packing_mode.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)
if __name__ == "__main__":
main()
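# Example command (a sketch; the script file name and model paths are assumptions, flags are from _parse_arguments above):
#   python convert_to_packing_mode.py --input bert_fused.onnx --output bert_packed.onnx --use_external_data_format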

View File

@ -0,0 +1,104 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import onnx
class DynamoOnnxHelper:
"""
Helper class for processing ONNX models exported by torch Dynamo.
"""
def __init__(self, model: onnx.ModelProto):
self.model = model
def update_edges(self, edge_mapping: dict) -> None:
"""
Updates the edges in the model according to the given mapping.
"""
for node in self.model.graph.node:
for i in range(len(node.input)):
if node.input[i] in edge_mapping:
node.input[i] = edge_mapping[node.input[i]]
for i in range(len(node.output)):
if node.output[i] in edge_mapping:
node.output[i] = edge_mapping[node.output[i]]
for graph_input in self.model.graph.input:
if graph_input.name in edge_mapping:
graph_input.name = edge_mapping[graph_input.name]
for graph_output in self.model.graph.output:
if graph_output.name in edge_mapping:
graph_output.name = edge_mapping[graph_output.name]
def unroll_function(self, func_name: str) -> None:
"""
Unrolls the function with the given name in the model.
"""
logging.info(f"Unrolling function {func_name}...")
nodes_to_remove = []
nodes_to_add = []
edges_to_remove = []
edges_to_add = []
for node in self.model.graph.node:
if node.op_type == func_name:
nodes_to_remove.append(node)
edges_to_remove.extend(list(node.input) + list(node.output))
func_to_remove = None
for f in self.model.functions:
if f.name == func_name:
nodes_to_add.extend(list(f.node))
edges_to_add.extend(list(f.input) + list(f.output))
func_to_remove = f
assert len(edges_to_remove) == len(edges_to_add)
for node in nodes_to_remove:
self.model.graph.node.remove(node)
for node in nodes_to_add:
self.model.graph.node.append(node)
if func_to_remove is not None:
self.model.functions.remove(func_to_remove)
edge_mapping = {}
for i in range(len(edges_to_remove)):
k = edges_to_remove[i]
v = edges_to_add[i]
if k != v:
edge_mapping[k] = v
return self.update_edges(edge_mapping)
def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
"""
Removes the function in the model.
"""
edge_mapping = {}
nodes_to_remove = []
for node in self.model.graph.node:
if node.op_type.find(func_name) != -1:
edge_mapping[node.input[input_id]] = node.output[output_id]
nodes_to_remove.append(node)
for node in nodes_to_remove:
self.model.graph.node.remove(node)
self.update_edges(edge_mapping)
def remove_dropout_layer(self) -> None:
"""
Removes the dropout layer in the model.
"""
logging.info("Removing dropout layer...")
self.remove_function("Dropout", 0, 0)
def remove_lm_head_layer(self) -> None:
"""
Removes the LM head layer in the model.
"""
logging.info("Removing LM head layer...")
# bugbug: need to copy the right vi over
self.remove_function("Linear_lm_head", 2, 0)
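# Usage sketch for DynamoOnnxHelper (the model paths are hypothetical):
#   helper = DynamoOnnxHelper(onnx.load("model_dynamo.onnx"))
#   helper.remove_dropout_layer()
#   onnx.save(helper.model, "model_dynamo_cleaned.onnx")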

View File

@ -0,0 +1,501 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# This file is modified from https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py
# Modifications:
# (1) Update default value of min_positive_val and max_finite_val
# (2) keep_io_types can be list of names
# (3) convert initializers if needed to preserve precision
# (4) add force_fp16_initializers option
# (5) handle Resize and GroupNorm with mixed float inputs
# (6) allow convert_float_to_float16 to accept model path
import itertools
import logging
import os
import tempfile
from typing import Dict
import numpy as np
import onnx
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx.shape_inference import infer_shapes, infer_shapes_path
from packaging import version
logger = logging.getLogger(__name__)
def _npfloat16_to_int(np_list):
"""
Convert numpy float16 to python int.
:param np_list: numpy float16 list
:return int_list: python int list
"""
return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]
def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""
Convert float32 numpy array to float16 without changing sign or finiteness.
Positive values less than min_positive_val are mapped to min_positive_val.
Positive finite values greater than max_finite_val are mapped to max_finite_val.
Similar for negative values. NaN, 0, inf, and -inf are unchanged.
"""
def between(a, b, c):
return np.logical_and(a < b, b < c)
if np_array[np.where(np_array > 0)].shape[0] > 0:
positive_max = np_array[np.where(np_array > 0)].max()
positive_min = np_array[np.where(np_array > 0)].min()
if positive_max >= max_finite_val:
logger.debug(f"the float32 number {positive_max} will be truncated to {max_finite_val}")
if positive_min <= min_positive_val:
logger.debug(f"the float32 number {positive_min} will be truncated to {min_positive_val}")
if np_array[np.where(np_array < 0)].shape[0] > 0:
negative_max = np_array[np.where(np_array < 0)].max()
negative_min = np_array[np.where(np_array < 0)].min()
if negative_min <= -max_finite_val:
logger.debug(f"the float32 number {negative_min} will be truncated to {-max_finite_val}")
if negative_max >= -min_positive_val:
logger.debug(f"the float32 number {negative_max} will be truncated to {-min_positive_val}")
np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array)
np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array)
np_array = np.where(between(max_finite_val, np_array, float("inf")), max_finite_val, np_array)
np_array = np.where(between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array)
return np.float16(np_array)
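# Illustrative clamping behavior of convert_np_to_float16 with the default thresholds:
#   1e-9 -> 5.96e-08  (small positive values are raised to min_positive_val)
#   1e9  -> 65504.0   (large finite values are clamped to max_finite_val)
#   0.0, inf, -inf and NaN pass through unchanged (apart from the cast to float16)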
def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""Convert tensor float to float16.
Args:
tensor (TensorProto): the tensor to convert.
        min_positive_val (float, optional): minimal positive value. Defaults to 5.96e-08.
        max_finite_val (float, optional): maximal finite value. Defaults to 65504.0.
Raises:
ValueError: input type is not TensorProto.
Returns:
TensorProto: the converted tensor.
"""
if not isinstance(tensor, TensorProto):
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
if tensor.data_type == TensorProto.FLOAT:
tensor.data_type = TensorProto.FLOAT16
# convert float_data (float type) to float16 and write to int32_data
if tensor.float_data:
float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val)
int_list = _npfloat16_to_int(float16_data)
tensor.int32_data[:] = int_list
tensor.float_data[:] = []
# convert raw_data (bytes type)
if tensor.raw_data:
# convert n.raw_data to float
float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
# convert float to float16
float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val)
# convert float16 to bytes and write back to raw_data
tensor.raw_data = float16_list.tobytes()
return tensor
def make_value_info_from_tensor(tensor):
shape = numpy_helper.to_array(tensor).shape
return helper.make_tensor_value_info(tensor.name, tensor.data_type, shape)
DEFAULT_OP_BLOCK_LIST = [
"ArrayFeatureExtractor",
"Binarizer",
"CastMap",
"CategoryMapper",
"DictVectorizer",
"FeatureVectorizer",
"Imputer",
"LabelEncoder",
"LinearClassifier",
"LinearRegressor",
"Normalizer",
"OneHotEncoder",
"RandomUniformLike",
"SVMClassifier",
"SVMRegressor",
"Scaler",
"TreeEnsembleClassifier",
"TreeEnsembleRegressor",
"ZipMap",
"NonMaxSuppression",
"TopK",
"RoiAlign",
"Range",
"CumSum",
"Min",
"Max",
"Upsample",
]
# Some operators have inputs whose data type is fixed as float. Key is op_type, value is list of input indices.
# Note that DirectML allows float16 gamma and beta in GroupNorm. The force_fp16_inputs parameter can be used to override this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}
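# Usage sketch for the fixed-float inputs above (convert_float_to_float16 is defined below; the model path is hypothetical):
#   fp16_model = convert_float_to_float16("model.onnx", keep_io_types=True)
# Resize scales (input 2) and GroupNorm gamma/beta (inputs 1 and 2) then stay in float32 unless
# overridden via force_fp16_inputs, e.g. force_fp16_inputs={"GroupNorm": [1, 2]}.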
class InitializerTracker:
"""Class for keeping track of initializer."""
def __init__(self, initializer: TensorProto):
self.initializer = initializer
self.fp32_nodes = []
self.fp16_nodes = []
def add_node(self, node: NodeProto, is_node_blocked):
if is_node_blocked:
self.fp32_nodes.append(node)
else:
self.fp16_nodes.append(node)
def convert_float_to_float16(
model,
min_positive_val=5.96e-08,
max_finite_val=65504.0,
keep_io_types=False,
disable_shape_infer=False,
op_block_list=None,
node_block_list=None,
force_fp16_initializers=False,
force_fp16_inputs=None,
use_bfloat16_as_blocked_nodes_dtype=False,
):
"""Convert tensor float type in the input ONNX model to tensor float16.
Args:
model (ModelProto or str): The ONNX model or path of the model to convert.
min_positive_val (float, optional): minimal positive value. Defaults to 5.96e-08.
max_finite_val (float, optional): maximal finite value of float16. Defaults to 65504.
keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names.
If True, model inputs/outputs should be left as float32.
Defaults to False.
disable_shape_infer (bool, optional): Skips running onnx shape/type inference.
Useful if shape inference has been done. Defaults to False.
op_block_list (List[str], optional): List of op types to leave as float32.
Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
force_fp16_initializers(bool): force converting all float initializers to float16.
            Defaults to False, which converts only the ones needed to avoid precision loss.
force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if
            this script's preference is to keep them in float32.
Raises:
ValueError: input type is not ModelProto.
Returns:
ModelProto: converted model.
"""
assert (
min_positive_val >= 5.96e-08
), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05"
assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504"
force_fp16_inputs_dict = {} if force_fp16_inputs is None else force_fp16_inputs
if isinstance(model, str):
model_path = model
if version.parse(onnx.__version__) >= version.parse("1.8.0") and not disable_shape_infer:
            # shape_infer_model_path should be in the same folder as model_path
with tempfile.NamedTemporaryFile(dir=os.path.dirname(model_path)) as tmpfile:
shape_infer_model_path = tmpfile.name
# infer_shapes_path can be used for model >2GB, and infer_shapes cannot.
infer_shapes_path(model_path, shape_infer_model_path)
model = onnx.load(shape_infer_model_path)
disable_shape_infer = True
else:
model = onnx.load(model_path)
if not isinstance(model, ModelProto):
raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}")
func_infer_shape = None
if not disable_shape_infer and version.parse(onnx.__version__) >= version.parse("1.2.0"):
try:
func_infer_shape = infer_shapes
finally:
pass
# create blocklists
if op_block_list is None:
op_block_list = DEFAULT_OP_BLOCK_LIST
if node_block_list is None:
node_block_list = []
op_block_list = set(op_block_list)
node_block_list = set(node_block_list)
logger.debug(
f"fp16 parameters: min_positive_val={min_positive_val} max_finite_val={max_finite_val} keep_io_types={keep_io_types} disable_shape_infer={disable_shape_infer} op_block_list={op_block_list} node_block_list={node_block_list} force_fp16_initializers={force_fp16_initializers}"
)
# create a queue for BFS
queue = []
value_info_list = []
node_list = []
    # Some operators (like Resize or GroupNorm) have their data type fixed as float for some inputs.
    # When such a node is converted to float16, there are mixed types: some inputs are float32 and some are float16.
    # This list keeps track of such nodes that are not in the block list.
mixed_float_type_node_list = []
# type inference on input model
if func_infer_shape is not None:
model = func_infer_shape(model)
queue.append(model)
name_mapping = {}
graph_io_to_skip = set()
io_casts = set()
fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT]
fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT]
if isinstance(keep_io_types, list):
fp32_inputs = [n for n in fp32_inputs if n in keep_io_types]
fp32_outputs = [n for n in fp32_outputs if n in keep_io_types]
elif not keep_io_types:
fp32_inputs = []
fp32_outputs = []
for i, n in enumerate(model.graph.input):
if n.name in fp32_inputs:
output_name = "graph_input_cast_" + str(i)
name_mapping[n.name] = output_name
graph_io_to_skip.add(n.name)
node_name = "graph_input_cast" + str(i)
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(n)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
# add Cast node (from tensor(float) to tensor(float16) after graph input
new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)]
model.graph.node.extend(new_node)
value_info_list.append(new_value_info)
io_casts.add(node_name)
for i, n in enumerate(model.graph.output):
if n.name in fp32_outputs:
input_name = "graph_output_cast_" + str(i)
name_mapping[n.name] = input_name
graph_io_to_skip.add(n.name)
node_name = "graph_output_cast" + str(i)
# add Cast node (from tensor(float16) to tensor(float) before graph output
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(n)
new_value_info.name = input_name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)]
model.graph.node.extend(new_node)
value_info_list.append(new_value_info)
io_casts.add(node_name)
fp32_initializers: Dict[str, InitializerTracker] = {}
while queue:
next_level = []
for q in queue:
# if q is model, push q.graph (GraphProto)
if isinstance(q, ModelProto):
next_level.append(q.graph)
# if q is model.graph, push q.node.attribute (AttributeProto)
if isinstance(q, GraphProto):
for n in q.initializer: # TensorProto type
if n.data_type == TensorProto.FLOAT:
assert n.name not in fp32_initializers
fp32_initializers[n.name] = InitializerTracker(n)
for n in q.node:
                    # If n is in the block list (i.e. it does not support float16), do not convert the node;
                    # save it for further processing below.
if n.name in io_casts:
continue
for i in range(len(n.input)):
if n.input[i] in name_mapping:
n.input[i] = name_mapping[n.input[i]]
for i in range(len(n.output)):
if n.output[i] in name_mapping:
n.output[i] = name_mapping[n.output[i]]
is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
for i, input_name in enumerate(n.input):
if input_name in fp32_initializers:
# For Resize/GroupNorm, only the first input can be float16
use_fp32_weight = is_node_blocked or (
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
and i not in force_fp16_inputs_dict.get(n.op_type, [])
)
fp32_initializers[input_name].add_node(n, use_fp32_weight)
if is_node_blocked:
node_list.append(n)
else:
if n.op_type == "Cast":
for attr in n.attribute:
if attr.name == "to" and attr.i == TensorProto.FLOAT:
attr.i = TensorProto.FLOAT16
break
if n.op_type in [
"EyeLike",
"Multinomial",
"RandomNormal",
"RandomNormalLike",
"RandomUniform",
"RandomUniformLike",
"SequenceEmpty",
"Bernoulli",
]:
has_dtype = False
for attr in n.attribute:
if attr.name == "dtype":
has_dtype = True
if attr.i == TensorProto.FLOAT:
attr.i = TensorProto.FLOAT16
                            # The dtype attribute is optional and defaults to FLOAT in the following operators,
                            # so we need to add the dtype attribute to specify float16.
if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype:
n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)])
# For Resize/GroupNorm, attribute data type cannot be changed
if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict:
for attr in n.attribute:
next_level.append(attr) # noqa: PERF402
else:
mixed_float_type_node_list.append(n)
# if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
# and process node.attribute.t and node.attribute.tensors (TensorProto)
if isinstance(q, AttributeProto):
next_level.append(q.g)
for n in q.graphs:
next_level.append(n) # noqa: PERF402
q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val))
for n in q.tensors:
n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) # noqa: PLW2901
# if q is graph, process input, output and value_info (ValueInfoProto)
if isinstance(q, GraphProto):
# Note that float initializers tracked by fp32_initializers will be processed later.
# for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to
# tensor(float16) except map and seq(map). And save them in value_info_list for further processing
for n in itertools.chain(q.input, q.output, q.value_info):
if n.type.tensor_type.elem_type == TensorProto.FLOAT:
if n.name not in graph_io_to_skip:
n.type.tensor_type.elem_type = TensorProto.FLOAT16
value_info_list.append(n)
if n.type.HasField("sequence_type"):
if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT:
if n.name not in graph_io_to_skip:
n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16
value_info_list.append(n)
queue = next_level
for value in fp32_initializers.values():
# By default, to avoid precision loss, do not convert an initializer to fp16 when it is used only by fp32 nodes.
if force_fp16_initializers or value.fp16_nodes:
value.initializer = convert_tensor_float_to_float16(value.initializer, min_positive_val, max_finite_val)
value_info_list.append(make_value_info_from_tensor(value.initializer))
if value.fp32_nodes and not force_fp16_initializers:
logger.info(
f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}"
)
    # Some operators have their data type fixed as float for some inputs. Add a float16-to-float Cast for those inputs.
for node in mixed_float_type_node_list:
for i, input_name in enumerate(node.input):
if i not in ALWAYS_FLOAT_INPUTS[node.op_type] or i in force_fp16_inputs_dict.get(node.op_type, []):
continue
for value_info in value_info_list:
if input_name == value_info.name:
# create new value_info for current node's new input name
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(value_info)
output_name = node.name + "_input_cast_" + str(i)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
# add Cast node (from tensor(float16) to tensor(float) before current node
node_name = node.name + "_input_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
model.graph.node.extend(new_node)
# change current node's input name
node.input[i] = output_name
break
accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT
    # process the nodes in the block list that don't support tensor(float16)
for node in node_list:
# if input's name is in the value_info_list meaning input is tensor(float16) type,
# insert a float16 to float Cast node before the node,
# change current node's input name and create new value_info for the new name
for i in range(len(node.input)):
input_name = node.input[i]
for value_info in value_info_list:
if input_name == value_info.name:
# create new value_info for current node's new input name
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(value_info)
output_name = node.name + "_input_cast_" + str(i)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = accuracy_type
                    # add Cast node (from tensor(float16) to the blocked-node data type: tensor(float) or tensor(bfloat16)) before the current node
node_name = node.name + "_input_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)]
model.graph.node.extend(new_node)
# change current node's input name
node.input[i] = output_name
break
# if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float to
# float16 Cast node after the node, change current node's output name and create new value_info for the new name
for i in range(len(node.output)):
output = node.output[i]
for value_info in value_info_list:
if output == value_info.name:
# create new value_info for current node's new output
new_value_info = model.graph.value_info.add()
new_value_info.CopyFrom(value_info)
input_name = node.name + "_output_cast_" + str(i)
new_value_info.name = input_name
new_value_info.type.tensor_type.elem_type = accuracy_type
# add Cast node (from tensor(float) to tensor(float16) after current node
node_name = node.name + "_output_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]
model.graph.node.extend(new_node)
# change current node's input name
node.output[i] = input_name
break
return model
def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""Measure the maximum absolute difference after converting a float tensor to float16."""
if not isinstance(tensor, TensorProto):
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
if tensor.data_type != TensorProto.FLOAT:
raise ValueError("Expected tensor data type is float.")
float32_data = None
if tensor.float_data:
float32_data = np.array(tensor.float_data)
if tensor.raw_data:
float32_data = np.frombuffer(tensor.raw_data, dtype="float32")
if float32_data is None:
raise RuntimeError("external data not loaded!")
float16_data = convert_np_to_float16(float32_data, min_positive_val, max_finite_val)
return np.amax(np.abs(float32_data - np.float32(float16_data)))
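# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): how these helpers are typically combined.
# The model paths, the extra "Softmax" block-list entry and the force_fp16_inputs value below are
# illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    fp32_model = onnx.load("model_fp32.onnx")  # hypothetical input path
    # Optional: inspect the float16 rounding error of one float initializer before converting.
    for initializer in fp32_model.graph.initializer:
        if initializer.data_type == TensorProto.FLOAT and (initializer.raw_data or initializer.float_data):
            print(initializer.name, "max abs diff:", float_to_float16_max_diff(initializer))
            break
    fp16_model = convert_float_to_float16(
        fp32_model,
        keep_io_types=True,  # keep graph inputs/outputs as float32
        op_block_list=[*DEFAULT_OP_BLOCK_LIST, "Softmax"],  # example: also keep Softmax in float32
        force_fp16_inputs={"GroupNorm": [1, 2]},  # example: allow fp16 gamma/beta (see note above)
    )
    onnx.save(fp16_model, "model_fp16.onnx")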

File diff suppressed because it is too large

View File

@ -0,0 +1,257 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple
from fusion_attention import AttentionMask, FusionAttention
from fusion_options import AttentionMaskFormat
from onnx import NodeProto
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionAttentionClip(FusionAttention):
"""
Fuse Attention subgraph of Clip into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
):
attention_mask = AttentionMask(model)
attention_mask.mask_format = AttentionMaskFormat.NoMask
super().__init__(
model,
hidden_size,
num_heads,
attention_mask,
use_multi_head_attention=False,
search_op_types=["SkipLayerNormalization"],
)
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
"""Detect num_heads and hidden_size for ONNX model from MiDaS
Args:
reshape_q (NodeProto): reshape node for q
Returns:
Tuple[int, int]: num_heads and hidden_size
"""
concat = self.model.match_parent(reshape_q, "Concat", 1)
if concat is None or len(concat.input) != 4:
return self.num_heads, self.hidden_size
# The shape is a tensor like [?, ?, num_heads, head_size]
num_head_value = self.model.get_constant_value(concat.input[2])
if num_head_value is None:
return self.num_heads, self.hidden_size # Fall back to user specified value
if len(num_head_value) != 1 or num_head_value[0] <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
num_heads = num_head_value[0]
head_size_value = self.model.get_constant_value(concat.input[3])
if head_size_value is None:
return self.num_heads, self.hidden_size # Fall back to user specified value
if len(head_size_value) != 1 or head_size_value[0] <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
head_size = head_size_value[0]
hidden_size = num_heads * head_size
if self.num_heads > 0 and num_heads != self.num_heads:
if self.num_heads_warning:
logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
self.num_heads_warning = False # Do not show the warning more than once
if self.hidden_size > 0 and hidden_size != self.hidden_size:
if self.hidden_size_warning:
logger.warning(
f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
)
self.hidden_size_warning = False # Do not show the warning more than once
return num_heads, hidden_size
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
skip_input_index = None
node_before_layer_norm = None
for i in [1, 0]:
parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i)
if parent is not None:
skip_input_index = i
node_before_layer_norm = parent
root_input = None
if node_before_layer_norm is not None:
root_input = node_before_layer_norm.output[0]
else:
# Deal with the first attention after the embedding layer.
for i in [0, 1]:
node_before_layer_norm = None
node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
if node_before_layer_norm_1 is not None:
# Add -----------+
# | |
# LayerNorm |
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_1
elif node_before_layer_norm_2 is not None:
# Add
# |
# LayerNorm --------+
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_2
if node_before_layer_norm is None:
continue
child = self.model.find_first_child_by_type(
node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False
)
if child is None:
continue
root_input = child.output[0]
skip_input_index = i
break
if skip_input_index is None:
return
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
[1 - skip_input_index, None, None, 0, 0, 0],
)
if qkv_nodes is None:
return
(_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
v_nodes = self.model.match_parent_path(
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(_, _, reshape_v, add_v, matmul_v) = v_nodes
add_mask = None
add_mask_indices = []
qk_nodes = None
qk_nodes_1 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
[0, 0, 0, None, 0],
return_indice=add_mask_indices,
)
qk_nodes_2 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "MatMul"],
[0, 0],
)
if qk_nodes_1 is not None:
qk_nodes = qk_nodes_1
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]
(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
elif qk_nodes_2 is not None:
qk_nodes = qk_nodes_2
(_softmax_qk, matmul_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return
q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
(_, _transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
(_transpose_k, _reshape_k, _, _, add_k, matmul_k) = k_nodes
if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input:
logger.debug("fuse_attention: expect to have same input to q, k and v matmul")
return
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if num_heads <= 0 or hidden_size <= 0:
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
return
attention_last_node = reshape_qkv
if add_mask is not None:
            # Here we do not match the whole subgraph since it is very complex. Instead, we just check a key path
            # of the causal mask computation.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
# If the model is exported with batch_size == 1, there is no Concat node
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
new_node = self.create_attention_node(
mask_index=None,
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
input=root_input,
output=attention_last_node.output[0],
add_qk_str=None,
scale=None,
causal=(add_mask is not None),
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
# Use prune graph to remove nodes since they are shared by all attention nodes.
self.prune_graph = True
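# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): this fusion is normally driven by the
# transformer optimizer pipeline. It assumes SkipLayerNormalization fusion has already been applied
# (the search op type above); the file path and sizes are illustrative assumptions.
if __name__ == "__main__":
    import onnx
    clip_model = OnnxModel(onnx.load("clip_text_encoder.onnx"))  # hypothetical path
    FusionAttentionClip(clip_model, hidden_size=768, num_heads=12).apply()
    clip_model.save_model_to_file("clip_text_encoder_fused.onnx")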

File diff suppressed because it is too large

View File

@ -0,0 +1,301 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union
import numpy as np
from fusion_base import Fusion
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionAttentionVae(Fusion):
"""
    Fuse Attention subgraph of VAE Decoder into one Attention node.
"""
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
super().__init__(model, "Attention", ["Softmax"])
self.hidden_size = hidden_size
self.num_heads = num_heads
# Flags to show warning only once
self.num_heads_warning = True
self.hidden_size_warning = True
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]:
"""Detect num_heads and hidden_size from a reshape node.
Args:
reshape_q (NodeProto): reshape node for Q
add_q (NodeProto): add node for Q
Returns:
Tuple[int, int]: num_heads and hidden_size
"""
concat = self.model.get_parent(reshape_q, 1)
if concat is None or len(concat.input) != 4:
return self.num_heads, self.hidden_size # Fall back to user specified value
value = self.model.get_constant_value(concat.input[2])
if not (value is not None and isinstance(value, np.ndarray) and value.size == 1):
return self.num_heads, self.hidden_size # Fall back to user specified value
num_heads = int(value)
if num_heads <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
_, bias = self.model.get_constant_input(add_q)
if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1:
return self.num_heads, self.hidden_size # Fall back to user specified value
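        # The Q bias has one element per output channel of the Q projection, so its length equals the hidden size.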
hidden_size = bias.shape[0]
if self.num_heads > 0 and num_heads != self.num_heads:
if self.num_heads_warning:
logger.warning(
"Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads
)
self.num_heads_warning = False # Do not show the warning more than once
if self.hidden_size > 0 and hidden_size != self.hidden_size:
if self.hidden_size_warning:
logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size)
self.hidden_size_warning = False # Do not show the warning more than once
return num_heads, hidden_size
def create_attention_node(
self,
q_matmul: NodeProto,
q_add: NodeProto,
k_matmul: NodeProto,
k_add: NodeProto,
v_matmul: NodeProto,
v_add: NodeProto,
num_heads: int,
hidden_size: int,
input_name: str,
output_name: str,
) -> Union[NodeProto, None]:
"""Create an Attention node.
Args:
            q_matmul (NodeProto): MatMul node in the fully connected layer for Q
            q_add (NodeProto): Add bias node in the fully connected layer for Q
            k_matmul (NodeProto): MatMul node in the fully connected layer for K
            k_add (NodeProto): Add bias node in the fully connected layer for K
            v_matmul (NodeProto): MatMul node in the fully connected layer for V
            v_add (NodeProto): Add bias node in the fully connected layer for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
input_name (str): input name
output_name (str): output name
Returns:
Union[NodeProto, None]: the node created or None if failed.
"""
if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name:
logger.debug(
"For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
q_matmul.input[0],
k_matmul.input[0],
v_matmul.input[0],
)
return None
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
return None
q_weight_tensor = self.model.get_initializer(q_matmul.input[1])
k_weight_tensor = self.model.get_initializer(k_matmul.input[1])
v_weight_tensor = self.model.get_initializer(v_matmul.input[1])
if not (q_weight_tensor and k_weight_tensor and v_weight_tensor):
return None
        q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
        has_bias = q_bias_tensor is not None and k_bias_tensor is not None and v_bias_tensor is not None
        if has_bias:
            q_bias = numpy_helper.to_array(q_bias_tensor)
            k_bias = numpy_helper.to_array(k_bias_tensor)
            v_bias = numpy_helper.to_array(v_bias_tensor)
            q_bias_shape = np.prod(q_bias.shape)
            k_bias_shape = np.prod(k_bias.shape)
            v_bias_shape = np.prod(v_bias.shape)
            assert q_bias_shape == k_bias_shape == v_bias_shape
# Sometimes weights are stored in fp16
if q_weight_tensor.data_type == 10:
logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
return None
q_weight = numpy_helper.to_array(q_weight_tensor)
k_weight = numpy_helper.to_array(k_weight_tensor)
v_weight = numpy_helper.to_array(v_weight_tensor)
        # q, k and v weights are expected to have the same shape
if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape:
return None
qw_in_size = q_weight.shape[0]
kw_in_size = k_weight.shape[0]
vw_in_size = v_weight.shape[0]
assert qw_in_size == kw_in_size and kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
raise ValueError(
f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
"Please provide a correct input hidden size or pass in 0"
)
        # All the matrices can have the same shape, or the q and k matrices can have the same shape with v being different.
# For 2d weights, the shapes would be [in_size, out_size].
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
qw_out_size = np.prod(q_weight.shape[1:])
qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1)
qkv_weight_dim = 3 * int(qw_out_size)
attention_node_name = self.model.create_node_name("Attention")
        if has_bias:
            qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0)
            qkv_bias_dim = 3 * q_bias_shape
        else:
            # No bias, use zeros so that the fused Attention node still gets a bias input
            qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
            qkv_bias_dim = 3 * hidden_size
        self.add_initializer(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[qw_in_size, qkv_weight_dim],
            vals=qkv_weight,
        )
        self.add_initializer(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )
attention_inputs = [
input_name,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output_name],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
self.increase_counter("Attention (self attention)")
return attention_node
def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node):
matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False)
if matmul_qkv is None:
return
reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False)
if reshape_qkv is None:
return
transpose_qkv = self.model.find_first_child_by_type(
reshape_qkv, "Transpose", input_name_to_nodes, recursive=False
)
if transpose_qkv is None:
return
reshape_out = self.model.find_first_child_by_type(
transpose_qkv, "Reshape", input_name_to_nodes, recursive=False
)
if reshape_out is None:
return
matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False)
if matmul_out is None:
return
add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False)
if add_out is None:
return
transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False)
if transpose_out is None:
return
v_nodes = self.model.match_parent_path(
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(_, _, _, add_v, matmul_v) = v_nodes
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
if qk_nodes is not None:
(_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return
q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None]
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
(_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
(_, _, _, _, add_k, matmul_k) = k_nodes
attention_last_node = reshape_out
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q)
if q_num_heads <= 0:
logger.debug("fuse_attention: failed to detect num_heads")
return
        # The number of heads is the same for all paths, so we pass q_num_heads when creating the Attention node.
new_node = self.create_attention_node(
matmul_q,
add_q,
matmul_k,
add_k,
matmul_v,
add_v,
q_num_heads,
q_hidden_size,
matmul_q.input[0],
attention_last_node.output[0],
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
# Use prune graph to remove nodes since they are shared by all attention nodes.
self.prune_graph = True
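# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original file) of the packed QKV layout built in create_attention_node:
# Q/K/V weights of shape [in_size, out_size] are stacked on a new axis and written as a single
# [in_size, 3 * out_size] initializer. The sizes below are illustrative assumptions.
if __name__ == "__main__":
    in_size, out_size = 4, 4
    q_w = np.full((in_size, out_size), 1.0, dtype=np.float32)
    k_w = np.full((in_size, out_size), 2.0, dtype=np.float32)
    v_w = np.full((in_size, out_size), 3.0, dtype=np.float32)
    packed = np.stack((q_w, k_w, v_w), axis=1)  # shape: [in_size, 3, out_size]
    print(packed.shape, "->", packed.reshape(in_size, 3 * out_size).shape)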

View File

@ -0,0 +1,640 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from onnx import TensorProto, helper
from onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class FusionBartAttention(FusionAttention):
"""
Fuse Bart Attention subgraph into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)
def check_runtime_shape_path(
self,
reshape_qkv_2,
reshape_qkv_1,
reshape_q_2,
reshape_k_2,
reshape_v_2,
root_input,
):
concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
if concat_qkv_2_path is None:
return False
concat_qkv_2 = concat_qkv_2_path[0]
reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None:
return False
_, gather_1, shape_1 = reshape_qkv_2_path_1
_, gather_2, shape_2 = reshape_qkv_2_path_2
if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
return False
reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0])
reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0])
if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None:
return False
if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name:
return False
reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None:
return False
mul_q = reshape_q_2_path[-1]
mul_k = reshape_k_2_path[-1]
mul_v = reshape_v_2_path[-1]
gather_1_out = gather_1.output[0]
if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
return False
return True
def check_runtime_shape_path_openai(
self,
reshape_qkv_2,
matmul_qkv,
add_qk,
matmul_qk,
add_q,
):
reshape_qkv_2_path = self.model.match_parent_path(
reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0]
)
if reshape_qkv_2_path is None:
return False
else:
if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]:
return False
matmul_qk_path_1 = self.model.match_parent_path(
matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0]
)
matmul_qk_path_2 = self.model.match_parent_path(
matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0]
)
if matmul_qk_path_1 is None or matmul_qk_path_2 is None:
return False
mul_1 = matmul_qk_path_1[0]
mul_2 = matmul_qk_path_2[0]
if mul_1.input[1] != mul_2.input[1]:
return False
if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]:
return False
# For decoder attentions only
if add_qk is not None:
add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1])
if add_qk_path is None:
return False
slice_q_path_1 = self.model.match_parent_path(
add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0]
)
slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
if slice_q_path_1 is None and slice_q_path_2 is None:
return False
_, unsqueeze_1, _, _ = slice_q_path_1
unsqueeze_2, _, _ = slice_q_path_2
if unsqueeze_1.input[0] != unsqueeze_2.input[0]:
return False
if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]:
return False
return True
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Track if fusion is occurring for OpenAI implementation of Whisper
model_impl_openai = False
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0, 0, 0],
)
qkv_nodes_openai = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 1, 0, 0, 0],
)
if qkv_nodes is not None:
(
add_out,
matmul_out,
reshape_qkv_2,
transpose_qkv,
reshape_qkv_1,
matmul_qkv,
) = qkv_nodes
elif qkv_nodes_openai is not None:
qkv_nodes = qkv_nodes_openai
(
add_out,
matmul_out,
reshape_qkv_2,
transpose_qkv,
matmul_qkv,
) = qkv_nodes
# Set model implementation to openai
model_impl_openai = True
else:
return
other_inputs = []
for input in normalize_node.input:
if input not in output_name_to_node:
continue
if input == qkv_nodes[0].output[0]:
continue
other_inputs.append(input)
if len(other_inputs) != 1:
return
root_input = other_inputs[0]
# Sometimes the input name to the attention MatMul nodes does not match the input name to the end
# SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
# nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
# children nodes for each of its output names.
"""
root_input
+---------------------------------------------------+
| |
| |
SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
"""
skip_layernorm = output_name_to_node[root_input]
# For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
# child is the LayerNormalization node.
if skip_layernorm.op_type == "Add":
skip_layernorm = self.model.get_children(skip_layernorm)[0]
for output in skip_layernorm.output:
if not output:
continue
children = input_name_to_nodes[output]
children_types = [child.op_type for child in children]
if children_types.count("MatMul") >= 1:
root_input = output
break
graph_input_names = set([node.name for node in self.model.graph().input])
graph_output_names = set([node.name for node in self.model.graph().output])
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, None],
)
v_nodes_openai = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, None],
)
v_nodes_with_past_self_attn = self.model.match_parent_path(
# Decoder attention with past value concatenated before MatMul
matmul_qkv,
["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 1, 0, 0, None],
)
v_nodes_with_past_cross_attn = self.model.match_parent_path(
# Decoder attention with past value directly used in MatMul
matmul_qkv,
["Reshape"],
[1],
)
v_nodes_with_past_cross_attn_openai = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "Reshape", "Transpose"],
[1, 0, 0, 0],
)
past_v, present_v = "", ""
reshape_v_2, add_v = None, None
if v_nodes is not None:
(reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
present_v = transpose_v.output[0]
elif v_nodes_openai is not None:
v_nodes = v_nodes_openai
(transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
# Find the child path to access the correct present_v values
# Openai impl provides present/past v values in 3D format
# whereas ort MultiHeadAttention expects v values in 4D, hence the
# additional Reshape and Transpose nodes are added
# For encoder attention types
# Add -> Reshape -> Transpose -> Present_V
reshape_path = self.model.match_child_path(
add_v,
["Reshape", "Transpose"],
exclude=[reshape_v_1],
)
# For decoder attention types
# add_v_node Reshape <- Transpose <-Past_V
# \ /
# \ /
# -> Concat <-
# |
# |--> Reshape -> Transpose -> Present_V
concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"])
if reshape_path is not None:
(_, transpose_add_v) = reshape_path
if transpose_add_v.output[0] in graph_output_names:
present_v = transpose_add_v.output[0]
if concat_path is not None:
(concat_v, _, transpose_concat_v) = concat_path
if transpose_concat_v.output[0] in graph_output_names:
present_v = transpose_concat_v.output[0]
concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0])
_, transpose_concat_v_in = concat_nodes
past_v = transpose_concat_v_in.input[0]
elif v_nodes_with_past_self_attn is not None:
(reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn
v_nodes = v_nodes_with_past_self_attn
past_v = concat_v.input[0]
present_v = concat_v.output[0]
elif (
v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names
):
v_nodes = v_nodes_with_past_cross_attn
past_v = v_nodes[-1].input[0]
present_v = v_nodes[-1].output[0]
if present_v not in graph_output_names:
identity_node_v = list(
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
)
present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
elif (
v_nodes_with_past_cross_attn_openai is not None
and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names
):
v_nodes = v_nodes_with_past_cross_attn_openai
past_v = v_nodes[-1].input[0]
present_v = v_nodes[-1].output[0]
if present_v not in graph_output_names:
identity_node_v = list(
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
)
present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
else:
logger.debug("fuse_attention: failed to match v path")
return
past_v = past_v if past_v in graph_input_names else ""
present_v = present_v if present_v in graph_output_names else ""
qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
qk_nodes_2 = self.model.match_parent_path(
matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0]
)
qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
add_qk = None
if qk_nodes_1 is not None:
_, matmul_qk = qk_nodes_1
qk_nodes = qk_nodes_1
elif qk_nodes_2 is not None:
_, _, add_qk, _, matmul_qk = qk_nodes_2
qk_nodes = qk_nodes_2
elif qk_nodes_2_openai is not None:
_, add_qk, matmul_qk = qk_nodes_2_openai
qk_nodes = qk_nodes_2_openai
else:
return
q_nodes = self.model.match_parent_path(
matmul_qk,
["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
[0, 0, 0, 0, 0, 1],
)
q_nodes_openai = self.model.match_parent_path(
matmul_qk,
["Mul", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 1],
)
reshape_q_2 = None
if q_nodes is not None:
reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes
elif q_nodes_openai is not None:
q_nodes = q_nodes_openai
mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes
else:
return
k_nodes_with_bias = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, 0, 1],
)
k_nodes_with_bias_openai = self.model.match_parent_path(
matmul_qk,
["Mul", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0],
)
k_nodes_no_bias = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0, 0],
)
k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path(
# Decoder attention with past key concatenated before MatMul
matmul_qk,
["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 1, 0, 0],
)
k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path(
# Decoder attention with past key directly used in MatMul
matmul_qk,
["Transpose", "Reshape"],
[1, 0],
)
k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path(
# Decoder attention with past key directly used in MatMul
matmul_qk,
["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
[1, 0, 0, 0, 0],
)
past_k, present_k = "", ""
reshape_k_2, reshape_k_1, matmul_k = None, None, None
if k_nodes_with_bias is not None:
_, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias
k_nodes = k_nodes_with_bias
elif k_nodes_with_bias_openai is not None:
mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai
k_nodes = k_nodes_with_bias_openai
present_k = matmul_k.output[0]
# Find the child path to access the correct present_k values
# Openai impl provides present/past k values in 3D format
# whereas ort MultiHeadAttention expects k values in 4D, hence the
# additional Reshape and Transpose nodes are added
# For encoder attention types
# Matmul -> Reshape -> Transpose -> Present_K
reshape_path = self.model.match_child_path(
matmul_k,
["Reshape", "Transpose"],
exclude=[reshape_k_1],
)
# For decoder attention types
# matmul_k_node Reshape <- Transpose <- Past_K
# \ /
# \ /
# -> Concat <-
# |
# |--> Reshape -> Transpose -> Present_K
concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"])
if reshape_path is not None:
(_, transpose_matmul_k) = reshape_path
if transpose_matmul_k.output[0] in graph_output_names:
present_k = transpose_matmul_k.output[0]
if concat_path is not None:
(concat_k, _, transpose_concat_k) = concat_path
if transpose_concat_k.output[0] in graph_output_names:
present_k = transpose_concat_k.output[0]
concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0])
_, transpose_concat_k_in = concat_nodes
past_k = transpose_concat_k_in.input[0]
elif k_nodes_no_bias is not None:
_, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias
k_nodes = k_nodes_no_bias
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
present_k = transpose_k_1.output[0]
elif k_nodes_no_bias_with_past_self_attn is not None:
_, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn
k_nodes = k_nodes_no_bias_with_past_self_attn
past_k = concat_k.input[0]
present_k = concat_k.output[0]
elif (
k_nodes_no_bias_with_past_cross_attn is not None
and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names
):
k_nodes = k_nodes_no_bias_with_past_cross_attn
past_k = k_nodes[-1].input[0]
present_k = k_nodes[-1].output[0]
if present_k not in graph_output_names:
identity_node_k = list(
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
)
present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
elif (
k_nodes_no_bias_with_past_cross_attn_openai is not None
and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names
):
k_nodes = k_nodes_no_bias_with_past_cross_attn_openai
past_k = k_nodes[-1].input[0]
present_k = k_nodes[-1].output[0]
if present_k not in graph_output_names:
identity_node_k = list(
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
)
present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
else:
return
past_k = past_k if past_k in graph_input_names else ""
present_k = present_k if present_k in graph_output_names else ""
if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn):
# Create empty Add node for attention graph
bias_dim = self.model.get_initializer(add_v.input[0]).dims[0]
empty_bias_name = "empty_bias"
empty_tensor = self.model.get_initializer(empty_bias_name)
if empty_tensor is None:
self.add_initializer(
empty_bias_name,
TensorProto.FLOAT,
dims=[bias_dim],
vals=np.array([0.0] * bias_dim, dtype=np.float32),
)
add_name = self.model.create_node_name("Add")
add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)
if (
model_impl_openai
and not past_k
and not self.check_runtime_shape_path_openai(
reshape_qkv_2,
matmul_qkv,
add_qk,
matmul_qk,
add_q,
)
):
return
elif (
not model_impl_openai
and not past_k
and not self.check_runtime_shape_path(
reshape_qkv_2,
reshape_qkv_1,
reshape_q_2,
reshape_k_2,
reshape_v_2,
root_input,
)
):
return
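        # Decoder cross attention with past: K and V come straight from the past_k/past_v graph inputs,
        # so no K/V MatMul was matched above (matmul_k stays None and matmul_v is never bound).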
three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals()
one_root_input = (
not three_root_inputs
and matmul_k.input[0] == root_input
and matmul_q.input[0] == root_input
and matmul_v.input[0] == root_input
)
two_root_inputs = (
not three_root_inputs
and matmul_q.input[0] == root_input
and matmul_k.input[0] == matmul_v.input[0]
and matmul_k.input[0] != matmul_q.input[0]
)
# There are 5 types of attention:
# 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1
# 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2
# 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value
# 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1
# 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
encoder_attention = one_root_input and qk_nodes == qk_nodes_1
decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai)
decoder_attention_with_past = (
(encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v
)
decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1
# For decoder_attention, the attention mask needs to be included in the attention node
mask_index = None
if decoder_attention:
mask_nodes_bart = self.model.match_parent_path(
add_qk,
["Where"],
[1],
)
mask_nodes_whisper = self.model.match_parent_path(
add_qk,
["Expand", "Unsqueeze", "Unsqueeze", "Where"],
[1, 0, 0, 0],
)
if mask_nodes_whisper is not None:
mask_index = mask_nodes_whisper[0].output[-1]
elif mask_nodes_bart is not None:
mask_index = mask_nodes_bart[0].output[-1]
if (
encoder_attention
or decoder_attention
or decoder_attention_with_past
or decoder_cross_attention
or decoder_cross_attention_with_past
):
attention_last_node = reshape_qkv_2
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)
if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
return
new_node = None
if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
# Note: Decoder attention with past key and past value is fused as multihead attention
# rather than attention because multihead attention supports separate past key and past
# value whereas attention supports concatenated past key and past value.
new_node = (
self.create_multihead_attention_node(
matmul_q,
matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
add_q,
add_k if decoder_cross_attention or decoder_attention_with_past else None,
add_v if decoder_cross_attention or decoder_attention_with_past else None,
num_heads,
hidden_size,
attention_last_node.output[0],
past_k=past_k if decoder_attention_with_past else "",
past_v=past_v if decoder_attention_with_past else "",
present_k=present_k,
present_v=present_v,
packed_qkv=decoder_attention_with_past,
)
if self.use_multi_head_attention
else None
)
else:
# Temporarily set multihead attention flag to false
use_multi_head_attention_ground_truth = self.use_multi_head_attention
self.use_multi_head_attention = False
new_node = self.create_attention_node(
None,
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
num_heads,
hidden_size,
root_input,
attention_last_node.output[0],
add_qk_str=mask_index if decoder_attention else None,
past_k=past_k,
past_v=past_v,
present_k=present_k,
present_v=present_v,
)
self.use_multi_head_attention = use_multi_head_attention_ground_truth
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
# When using multihead attention, keep MatMul nodes in original graph
if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
if q_nodes[-1].op_type == "MatMul":
q_nodes.pop()
if k_nodes[-1].op_type == "MatMul":
k_nodes.pop()
if v_nodes[-1].op_type == "MatMul":
v_nodes.pop()
if self.disable_multi_head_attention_bias and (
decoder_cross_attention or decoder_cross_attention_with_past
):
if q_nodes[-1].op_type == "Add":
q_nodes.pop()
if k_nodes[-1].op_type == "Add":
k_nodes.pop()
if v_nodes[-1].op_type == "Add":
v_nodes.pop()
self.nodes_to_remove.extend(q_nodes)
self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
self.prune_graph = True

View File

@ -0,0 +1,137 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from collections import defaultdict
from logging import getLogger
from typing import Any, Dict, List, Optional, Sequence, Union
import numpy as np
from onnx import NodeProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class Fusion:
"""
Base class for Graph Fusion
"""
def __init__(
self,
model: OnnxModel,
fused_op_type: str,
search_op_types: Union[str, List[str]],
description: str = "",
):
self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
self.fused_op_type: str = fused_op_type
self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
self.model: OnnxModel = model
self.nodes_to_remove: List = []
self.nodes_to_add: List = []
self.prune_graph: bool = False
self.node_name_to_graph_name: dict = {}
self.this_graph_name: Optional[str] = None
# It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
self.fused_count: defaultdict = defaultdict(int)
def increase_counter(self, fused_op_name: str):
"""
Increase counter of a fused operator.
"""
self.fused_count[fused_op_name] += 1
def fuse(
self,
node: NodeProto,
input_name_to_nodes: Dict[str, List[NodeProto]],
output_name_to_node: Dict[str, NodeProto],
):
"""Interface for fusion that starts from a node"""
raise NotImplementedError
def apply(self):
"""
Apply graph fusion on the whole model graph.
        It searches nodes of the given operator types, and starts fusion from each of those nodes.
"""
logger.debug(f"start {self.description} fusion...")
input_name_to_nodes = self.model.input_name_to_nodes()
output_name_to_node = self.model.output_name_to_node()
        # This assumes that two search ops will not be fused at the same time!
for search_op_type in self.search_op_types:
for node in self.model.get_nodes_by_op_type(search_op_type):
graph = self.model.get_graph_by_node(node)
if graph is None:
raise Exception("Can not find node in any graph")
self.this_graph_name = graph.name
self.fuse(node, input_name_to_nodes, output_name_to_node)
op_list = [node.op_type for node in self.nodes_to_add]
if self.fused_count:
for key, value in self.fused_count.items():
if value:
logger.info(f"Fused {key}: {value}")
else:
count = op_list.count(self.fused_op_type)
if count > 0:
logger.info(f"Fused {self.description}: {count}")
self.model.remove_nodes(self.nodes_to_remove)
self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
if self.prune_graph:
self.model.prune_graph()
elif self.nodes_to_remove or self.nodes_to_add:
self.model.update_graph()
def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
if raw:
np_type = helper.tensor_dtype_to_np_dtype(data_type)
if not isinstance(vals, np.ndarray):
bytes = np.array(vals, dtype=np_type).tobytes()
else:
bytes = vals.astype(np_type).tobytes()
tensor = helper.make_tensor(
name=name,
data_type=data_type,
dims=dims,
vals=bytes,
raw=True,
)
else:
tensor = helper.make_tensor(
name=name,
data_type=data_type,
dims=dims,
vals=vals,
raw=False,
)
self.model.add_initializer(tensor, self.this_graph_name)
return tensor
def add_nodes_to_remove(self, nodes: List[NodeProto]):
# Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
# When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
# is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
# iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
# Since path A's shared nodes are removed, path B's shared nodes are not removed because they
# were previously removed for path A. This causes an error to print in remove_node that a node
# has failed to be removed.
#
# To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
# We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
# be scenarios where the nodes need to be removed in a specific order and converting to a set would
# lose this order.
for node in nodes:
if node not in self.nodes_to_remove:
self.nodes_to_remove.append(node)
def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
for node in nodes:
if node not in self.nodes_to_remove and node not in nodes_to_keep:
self.nodes_to_remove.append(node)
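# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original file): a minimal, hypothetical subclass showing the typical
# flow of a fusion built on this base class: match a parent pattern from a search node, queue the fused
# node in nodes_to_add, record its graph in node_name_to_graph_name, and queue the matched nodes for
# removal. Collapsing Identity -> Identity chains is only an illustration; a real fusion would also
# verify that the intermediate output has no other consumers before removing nodes.
class _ExampleIdentityChainFusion(Fusion):
    def __init__(self, model: OnnxModel):
        # Search for Identity nodes; fuse() is called once per Identity node in the graph.
        super().__init__(model, "Identity", "Identity", description="example chain collapse")
    def fuse(self, node: NodeProto, input_name_to_nodes, output_name_to_node):
        parents = self.model.match_parent_path(node, ["Identity"], [0], output_name_to_node)
        if parents is None:
            return
        fused = helper.make_node(
            "Identity",
            inputs=[parents[0].input[0]],
            outputs=[node.output[0]],
            name=self.model.create_node_name("Identity", "FusedIdentity_"),
        )
        self.increase_counter("Identity")
        self.nodes_to_add.append(fused)
        self.node_name_to_graph_name[fused.name] = self.this_graph_name
        self.add_nodes_to_remove([parents[0], node])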

View File

@ -0,0 +1,58 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict
from fusion_base import Fusion
from numpy import ndarray
from onnx import helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionBiasAdd(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "BiasAdd", "Add")
def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
"""
Fuse Add bias and Add skip connection into BiasAdd
"""
nodes = self.model.match_parent_path(
add_node,
["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
[0, None, 0, 0, 0],
output_name_to_node,
)
if nodes is None:
return
bias_node = nodes[0]
skip_layer_norm = nodes[-1]
# Check skip connection is from SkipLayerNormalization output
if add_node.input[1] not in skip_layer_norm.output:
return
bias_index, bias_value = self.model.get_constant_input(bias_node)
if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)):
return
if bias_value.ndim != 1:
return
self.nodes_to_remove.extend([add_node, bias_node])
node_name = self.model.create_node_name("BiasAdd")
fused_node = helper.make_node(
"BiasAdd",
inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]],
outputs=[add_node.output[0]],
name=node_name,
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[node_name] = self.this_graph_name
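# ---------------------------------------------------------------------------
# Hedged reference sketch (not part of the original file): the fused com.microsoft BiasAdd node created
# above is expected to compute output = input + bias + skip, with the 1-D bias broadcast over the last
# dimension. The shapes in the comment below are illustrative assumptions.
def _bias_add_reference(value: ndarray, bias: ndarray, skip: ndarray) -> ndarray:
    # value and skip: [batch, seq_len, hidden]; bias: [hidden]
    return value + bias + skip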

View File

@ -0,0 +1,66 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionBiasGelu(Fusion):
def __init__(self, model: OnnxModel, is_fastgelu):
if is_fastgelu:
super().__init__(model, "FastGelu", "FastGelu", "add bias")
else:
super().__init__(model, "BiasGelu", "Gelu")
def fuse(self, node, input_name_to_nodes, output_name_to_node):
gelu_op_type = node.op_type
fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"
if len(node.input) != 1:
return
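# Expected pattern (derived from the match below): MatMul --> Add(1-D bias initializer) --> Gelu/FastGelu.
# The Add and the activation are replaced by one BiasGelu (or FastGelu with bias) node that takes the
# MatMul output and the bias initializer as its two inputs.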
nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
if nodes is None:
return
(add, matmul) = nodes
bias_weight = None
# bias should be one-dimensional
bias_index = -1
for i, input in enumerate(add.input):
initializer = self.model.get_initializer(input)
if initializer is None:
continue
bias_index = i
bias_weight = NumpyHelper.to_array(initializer)
break
if bias_weight is None:
return
if len(bias_weight.shape) != 1:
return
subgraph_nodes = [node, add]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
):
return
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
fuse_op_type,
inputs=[matmul.output[0], add.input[bias_index]],
outputs=node.output,
name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name

View File

@ -0,0 +1,111 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionBiasSplitGelu(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "BiasSplitGelu", "Gelu")
def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
"""
[root] --->Add --------------------> Slice ---------------> Mul -->
| ^ ^
| | |
+----------------------------+---Slice --> Gelu---+
| | ^
| |-----|
| | |
| Mul Mul
| ^ ^
v | |
Shape ---> Gather --> Add --> Div --+
"""
if gelu_node.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[gelu_node.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_gelu = children[0]
slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
if slice_before_gelu is None:
return
if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3:
return
add_output = slice_before_gelu.input[0]
start_index_nodes = self.model.match_parent_path(
slice_before_gelu,
["Div", "Add", "Gather", "Shape", "Add"],
[1, 0, 0, 0, 0],
output_name_to_node, # Mul(1) is optional
)
if start_index_nodes is None:
start_index_nodes = self.model.match_parent_path(
slice_before_gelu,
["Mul", "Div", "Add", "Gather", "Shape", "Add"],
[1, 0, 0, 0, 0, 0],
output_name_to_node,
)
if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output:
return
end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)
if (
end_index_nodes is None or end_index_nodes[1] not in start_index_nodes
):  # the Div is the parent of both Mul nodes
return
slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
if slice_before_mul is None:
return
if (
slice_before_mul.input[2] != slice_before_gelu.input[1]
): # end index of slice_before_mul is start index of slice_before_gelu
return
subgraph_nodes = [
*start_index_nodes,
end_index_nodes[0],
mul_after_gelu,
gelu_node,
slice_before_mul,
slice_before_gelu,
]
subgraph_output = mul_after_gelu.output[0]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
):
logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
return
add_node = start_index_nodes[-1]
bias_index, _value = self.model.get_constant_input(add_node)
if not isinstance(bias_index, int):
return
self.nodes_to_remove.extend(subgraph_nodes)
node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
fused_node = helper.make_node(
"BiasSplitGelu",
inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]],
outputs=[subgraph_output],
name=node_name,
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[node_name] = self.this_graph_name

View File

@ -0,0 +1,143 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from fusion_attention import AttentionMask, FusionAttention
from onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class FusionConformerAttention(FusionAttention):
"""
Fuse Conformer Attention subgraph into one MultiHeadAttention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(model, hidden_size, num_heads, attention_mask)
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 1, 0, 0, 0],
)
if qkv_nodes is not None:
(
_,
_,
reshape_qkv,
transpose_qkv,
matmul_qkv,
) = qkv_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qkv path")
return
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 1, 0, 0, 1],
)
add_v = None
if v_nodes is not None:
(concat_v, _, _, add_v, matmul_v) = v_nodes
concat_parent = self.model.get_parent(concat_v, 0, None)
present_v = concat_v.output[0]
past_v = concat_parent.output[0]
else:
logger.debug("fuse_conformer_attention: failed to match v path")
return
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
if qk_nodes is not None:
_, add_qk, matmul_qk = qk_nodes
else:
logger.debug("fuse_conformer_attention: failed to match qk path")
return
q_nodes = self.model.match_parent_path(
matmul_qk,
["Div", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 1],
)
if q_nodes is not None:
_, _, reshape_q, add_q, matmul_q = q_nodes
else:
logger.debug("fuse_conformer_attention: failed to match q path")
return
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 1, 0, 0, 1],
)
matmul_k = None
if k_nodes is not None:
_, concat_k, _, _, add_k, matmul_k = k_nodes
concat_parent = self.model.get_parent(concat_k, 0, None)
past_k = concat_parent.output[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_conformer_attention: failed to match k path")
return
attention_last_node = reshape_qkv
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
return
new_node = self.create_multihead_attention_node(
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
num_heads,
hidden_size,
attention_last_node.output[0],
add_qk=add_qk.input[1],
past_k=past_k,
past_v=past_v,
present_k=present_k,
present_v=present_v,
)
if new_node is None:
logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
# When using MultiHeadAttention, keep the MatMul nodes in the original graph
if q_nodes[-1].op_type == "MatMul":
q_nodes.pop()
if k_nodes[-1].op_type == "MatMul":
k_nodes.pop()
if v_nodes[-1].op_type == "MatMul":
v_nodes.pop()
self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
self.prune_graph = True

View File

@ -0,0 +1,811 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List, Optional, Tuple, Union
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionEmbedLayerNoMask(Fusion):
"""
Fuse embedding layer into one node (EmbedLayerNormalization).
It supports the following model types: BERT, DistilBert, ALBert.
"""
def __init__(self, model: OnnxModel, description: str = "no mask"):
super().__init__(
model,
"EmbedLayerNormalization",
["LayerNormalization", "SkipLayerNormalization"],
description,
)
self.utils = FusionUtils(model)
self.shape_infer = None
self.shape_infer_done = False
# The following will be reset in each fuse call of FusionEmbedLayerNormalization
self.attention = None
self.embed_node = None
def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
if gather_0_path is None:
return None
gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
if gather_1_path is None:
return None
return gather_0_path[0], gather_1_path[0]
def check_attention_subgraph(
self,
layernorm: NodeProto,
input_name_to_nodes: Dict[str, List[NodeProto]],
is_distil_bert: bool,
) -> bool:
"""Check that LayerNormalization has a child of Attention node or subgraph like Attention.
Args:
layernorm (NodeProto): LayerNormalization node
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
is_distil_bert (bool): whether it is DistilBert or not
Returns:
bool: whether there is Attention node or subgraph like Attention
"""
self.attention = self.model.find_first_child_by_type(
layernorm, "Attention", input_name_to_nodes, recursive=False
)
if self.attention is not None:
return True
if layernorm.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[layernorm.output[0]]
children_types = sorted([child.op_type for child in children])
# Try to find MultiHeadAttention
if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
for node in children:
if node.op_type == "SkipLayerNormalization":
path1 = self.model.match_parent_path(
node,
["Add", "MatMul", "MultiHeadAttention", "MatMul"],
[None, None, 0, 0],
)
if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
self.cross_attention = path1[2]
return True
# In case user disables attention fusion, check whether subgraph looks like Attention.
# For Albert, there is MatMul+Add after embedding layer before attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])
# Two Shape nodes might be merged by ORT
if is_distil_bert:
# SkipLayerNormalization might exist when the model has been optimized by ORT first.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != [
"Add",
"MatMul",
"MatMul",
"MatMul",
] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
return True
def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
""" Match position embedding path from input_ids to Gather for DistilBert.
Pattern is like the following:
(input_ids)
|
Shape
| \
| Gather (indices=1)
| |
| Cast (optional)
| |
| Range (start=0, end=*, delta=1)
| |
| Unsqueeze
| /
Expand
|
Gather
"""
# remove after tests pass
path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
if path1 is None:
path1 = self.model.match_parent_path(
position_embedding_gather,
["Expand", "Where", "Reshape", "Shape"],
[1, 1, 2, 0],
)
if path1 is None:
return False
expand, shape = path1[0], path1[-1]
if shape.input[0] != input_ids:
return False
_, path2, _ = self.model.match_parent_paths(
expand,
[
(["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
(["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
],
output_name_to_node,
)
if path2 is None:
return False
range_node = path2[1]
if not (
self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
):
return False
gather_node = path2[-2]
if not (self.utils.check_node_input_value(gather_node, 1, 1)):
return False
shape_node = path2[-1]
if shape_node.input[0] != input_ids:
return False
return True
def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
"""Match position embedding path from input_ids to Gather for Roberta.
Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
(input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
| ^
V |
+------------------------------+
Roberta new pattern from transformers v4.9:
(input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
| ^
V |
+-------------------------------------------+
start_node = position_embedding_gather
start_index = 1
# match optional Cast node.
parent = self.model.get_parent(start_node, start_index, output_name_to_node)
if parent is None:
return
if parent.op_type == "Cast":
if OnnxModel.get_node_attribute(parent, "to") != 7:
return
start_node = parent
start_index = 0
i, path, return_indices = self.model.match_parent_paths(
start_node,
[ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
(['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
output_name_to_node)
if path is not None:
# constant input of Add shall be 1.
i, value = self.model.get_constant_input(path[0])
if value != 1:
return False
_, self.padding_word_id = self.model.get_constant_input(path[-1])
return input_ids == path[-1].input[0]
"""
return False
def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
""" Match position embedding path from input_ids to Gather for BERT.
BERT Embedding Layer Pattern:
(input_ids)
/ \
/ Shape
/ |
/ Gather (indices=1)
/ |
/ Add (optional, B=0)
/ |
Gather (segment_ids) Unsqueeze (axes=0)
\\ | |
\\ Gather Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
\\ / |
Add Gather
\\ /
Add
|
LayerNormalization
"""
path = self.model.match_parent_path(
position_embedding_gather,
["Slice", "Unsqueeze"],
[1, 2],
output_name_to_node,
)
if path is None:
return False
slice, unsqueeze = path
slice_weight = self.model.get_constant_value(slice.input[0])
if not (
slice_weight is not None
and len(slice_weight.shape) == 2
and slice_weight.shape[0] == 1
and self.utils.check_node_input_value(slice, 1, [0])
and self.utils.check_node_input_value(slice, 3, [1])
and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
):
return False
opset_version = self.model.get_opset_version()
if opset_version < 13:
if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
return False
else:
if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
return False
node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
if node is None:
return False
if node.op_type == "Add":
if not self.utils.check_node_input_value(node, 1, 0):
return False
gather = self.model.get_parent(node, 0, output_name_to_node)
else:
gather = node
if gather is None or gather.op_type != "Gather":
return False
if not (self.utils.check_node_input_value(gather, 1, 1)):
return False
shape = self.model.get_parent(gather, 0, output_name_to_node)
if shape is None or shape.op_type != "Shape":
return False
return input_ids == shape.input[0]
def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
return True
# TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
# related: https://github.com/huggingface/transformers/issues/10736
# if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
# return True
if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
return True
return False
def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
"""Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
input_ids = word_embedding_gather.input[1]
segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
position_ids = position_embedding_gather.input[1]
if not self.shape_infer_done:
self.shape_infer = self.model.infer_runtime_shape(update=True)
self.shape_infer_done = True
if self.shape_infer is not None:
input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
assert input_ids_shape and position_ids_shape
if not (
len(input_ids_shape) == 2
and len(position_ids_shape) == 2
and input_ids_shape[1] == position_ids_shape[1]
):
logger.info(
f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
)
return False
if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
logger.info(
f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
)
return False
word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
if word_embedding_table is None or len(word_embedding_table.shape) != 2:
logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
return False
position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
if (
position_embedding_table is None
or len(position_embedding_table.shape) != 2
or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
):
logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
return False
if segment_ids:
segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
if (
segment_embedding_table is None
or len(segment_embedding_table.shape) != 2
or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
):
logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
return False
# In the normal case, the word embedding table is the largest and the segment embedding table is the smallest, while the position embedding table is in between.
# TODO: use other information (like initializer names) to identify different embedding weights automatically.
if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
logger.warning(
f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
)
if segment_ids:
if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
logger.warning(
f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
)
if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
logger.warning(
f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
)
return True
def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]:
"""Cast a graph input or node input to int32.
Args:
input_name (str): name of graph input or node input
Returns:
A tuple of casted input name and the cast node.
int32_output (str): if the input is already int32, this is the input name; otherwise it is the output name of the Cast node.
input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
"""
input_cast_node = None
graph_input = self.model.find_graph_input(input_name)
if graph_input is not None:
if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
else:
int32_output = input_name
else:
int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
return int32_output, input_cast_node
def create_fused_node(
self,
input_ids: str,
layernorm: NodeProto,
word_embedding_gather: NodeProto,
position_embedding_gather: NodeProto,
segment_embedding_gather: Union[None, NodeProto],
position_ids: Optional[str] = None,
embedding_sum_output=False,
embedding_sum_name=None,
):
"""Create an EmbedLayerNormalization node. Note that segment embedding is optional.
Args:
input_ids (str): input_ids for word embeddings
layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
word_embedding_gather (NodeProto): the Gather node for word embedding
position_embedding_gather (NodeProto): the Gather node for position embedding
segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
Returns:
NodeProto: the EmbedLayerNormalization node created.
"""
nodes_to_add = []
input_ids, _ = self.cast_to_int32(input_ids)
node_name = self.model.create_node_name("EmbedLayerNormalization")
if layernorm.op_type == "LayerNormalization":
gamma = layernorm.input[1]
beta = layernorm.input[2]
else: # SkipLayerNormalization
gamma = layernorm.input[2]
beta = layernorm.input[3]
embed_node_inputs = None
if segment_embedding_gather is not None:
segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
embed_node_inputs = [
input_ids,
segment_ids,
word_embedding_gather.input[0],
position_embedding_gather.input[0],
segment_embedding_gather.input[0],
gamma,
beta,
]
else: # no segment embedding
embed_node_inputs = [
input_ids,
"",
word_embedding_gather.input[0],
position_embedding_gather.input[0],
"",
gamma,
beta,
]
if position_ids is not None:
# Adding an empty input for mask before position_ids
embed_node_inputs.append("")
position_ids, _ = self.cast_to_int32(position_ids)
embed_node_inputs.append(position_ids)
embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
if embedding_sum_output:
name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
embed_node_outputs.append(name)
embed_node = helper.make_node(
"EmbedLayerNormalization",
embed_node_inputs,
outputs=embed_node_outputs,
name=node_name,
)
embed_node.domain = "com.microsoft"
# Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
for att in layernorm.attribute:
if att.name == "epsilon":
embed_node.attribute.extend([att])
# Set default value to 1e-12 if no attribute is found.
# OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model only works with 1.3.0 or later.
if len(embed_node.attribute) == 0:
embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
# Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
nodes_to_add.append(embed_node)
for node in nodes_to_add:
self.node_name_to_graph_name[node.name] = self.this_graph_name
self.nodes_to_add.extend(nodes_to_add)
self.embed_node = embed_node
return embed_node
def finish_fusion(self, layernorm, embed_node):
self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
# use prune_graph to remove nodes that are no longer needed
self.prune_graph = True
def is_skip_layer_norm_with_sum_output(self, node):
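# SkipLayerNormalization can expose the pre-normalization sum (input + skip [+ bias]) as an optional
# fourth output (index 3); an empty name at that index means the sum output is unused.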
return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
def fuse_gpt2(
self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
):
# graph checks
# gpt2 has optional segment embedding, subgraph pattern is like
# input_ids position_ids
# | |
# token_ids Gather Gather
# | \ /
# Gather (optional) Add _ _ _ _ _
# \ | |
# LayerNormalization |
# | |
# Attention |
# | |
# Matmul |
# | /
# Add /
# \ /
# Add
two_gather = self.match_two_gather(add_before_layernorm)
if two_gather is None:
return False
word_embedding_gather, position_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
position_ids = position_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
return False
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
return False
# If the layernorm node is SkipLayerNormalization, we need to look at its optional fourth output.
# If the add_before_layernorm node is an Add node, the embedding sum is the first output of that node.
# If the add_before_layernorm node is a SkipLayerNormalization node, the embedding sum is its
# optional fourth output (index 3).
# When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are the same node.
if layernorm.op_type == "SkipLayerNormalization":
need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
sum_output_index = 3
node_with_sum_output = layernorm
sum_output = layernorm.output[3] if need_embedding_sum_output else None
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
else: # layernorm.op_type == "LayerNormalization"
node_with_sum_output = add_before_layernorm
sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
sum_output = (
add_before_layernorm.output[sum_output_index]
if len(add_before_layernorm.output) > sum_output_index
else None
)
is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
is_sum_used_by_multiple_nodes = (
sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
)
need_embedding_sum_output = (sum_output is not None) and (
add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
)
# make the fused node
embed_node = self.create_fused_node(
input_ids,
layernorm,
word_embedding_gather,
position_embedding_gather,
optional_segment_gather,
position_ids,
embedding_sum_output=need_embedding_sum_output,
embedding_sum_name=sum_output if is_sum_graph_output else None,
)
if need_embedding_sum_output:
node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
if not is_sum_graph_output:
self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
self.finish_fusion(layernorm, embed_node)
return True
def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
"""Fuse embedding layer for DistilBert
Args:
layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
output_name_to_node (Dict[str, NodeProto]): map from output name to the node that produces it
"""
# DistilBert has no segment embedding, subgraph pattern is like
# input_ids
# | \
# | (position_embedding_subgraph)
# | |
# Gather Gather
# \ /
# Add
# |
# LayerNormalization
two_gather = self.match_two_gather(add_before_layernorm)
if two_gather is None:
return False
word_embedding_gather, position_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
return False
if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
return False
if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
return False
embed_node = self.create_fused_node(
input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
)
self.finish_fusion(layernorm, embed_node)
return True
def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
"""Fuse embedding layer for Bert
Args:
layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
output_name_to_node (Dict[str, NodeProto]): map from output name to the node that produces it
"""
add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
if add_2_gather is None:
return False
two_gather = self.match_two_gather(add_2_gather[0])
if two_gather is None:
return False
word_embedding_gather, segment_embedding_gather = two_gather
input_ids = word_embedding_gather.input[1]
if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
return False
position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
if position_embedding_path is None:
return False
position_embedding_gather = position_embedding_path[0]
if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
return False
# position and segment are switched
temp = segment_embedding_gather
segment_embedding_gather = position_embedding_gather
position_embedding_gather = temp
if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
return False
embed_node = self.create_fused_node(
input_ids,
layernorm,
word_embedding_gather,
position_embedding_gather,
segment_embedding_gather,
)
self.finish_fusion(layernorm, embed_node)
return True
def fuse(self, node, input_name_to_nodes, output_name_to_node):
first_add_path = self.model.match_parent_path(node, ["Add"], [0])
if node.op_type == "LayerNormalization":
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = None
else: # SkipLayerNormalization
gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
if gather_0_path is None and gather_1_path is not None:
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_1_path[0]
elif gather_0_path is not None and gather_1_path is None:
first_add_path = self.model.match_parent_path(node, ["Add"], [1])
if first_add_path is None:
return
add_before_layernorm = first_add_path[0]
optional_segment_gather = gather_0_path[0]
else:
add_before_layernorm = node # Add is fused into SkipLayerNormalization
optional_segment_gather = None
if self.fuse_gpt2(
node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
):
return
if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
return
if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
return
class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
def __init__(self, model: OnnxModel, use_mask_index=False):
super().__init__(model, "with mask")
self.use_mask_index = use_mask_index
def replace_mask(self, mask_int32, attention_nodes):
# Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
# segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
embed_node = self.embed_node
if len(embed_node.input) == 7:
embed_node.input.append(mask_int32)
logger.debug("append mask to %s", embed_node.name)
elif len(embed_node.input) > 7 and not embed_node.input[7]:
embed_node.input[7] = mask_int32
logger.debug("replace mask in %s", embed_node.name)
else:
logger.debug("skip mask in %s", embed_node.name)
return
for attention_node in attention_nodes:
logger.debug("update mask_index in %s", attention_node.name)
if attention_node.op_type == "Attention":
attention_node.input[3] = embed_node.output[1]
elif attention_node.op_type == "MultiHeadAttention":
attention_node.input[4] = embed_node.output[1]
def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Reset attention and embed_node so that we know fusion is successful when they are not None.
self.attention = None
self.cross_attention = None
self.embed_node = None
super().fuse(node, input_name_to_nodes, output_name_to_node)
if self.embed_node is None:
return
if not self.use_mask_index:
logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
self.increase_counter("EmbedLayerNormalization(no mask)")
return
if self.attention is None and self.cross_attention is None:
logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
self.increase_counter("EmbedLayerNormalization(no mask)")
return
if self.attention:
mask_int32 = self.attention.input[3]
else:
mask_int32 = self.cross_attention.input[4]
children_nodes = input_name_to_nodes[mask_int32]
if self.model.find_graph_input(mask_int32):
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")
return
if mask_int32 not in output_name_to_node:
logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
self.increase_counter("EmbedLayerNormalization(no mask)")
return
node = output_name_to_node[mask_int32]
if node.op_type in ["ReduceSum", "Cast"]:
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
if node.op_type == "ReduceSum":
mask_int32 = node.input[0]
if len(children_nodes) == len(attention_nodes):
self.nodes_to_remove.append(node)
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")

View File

@ -0,0 +1,360 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, Optional
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionFastGelu(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "FastGelu", "Tanh")
def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if self.fuse_1(tanh_node, input_name_to_nodes, output_name_to_node):
return
if self.fuse_2(tanh_node, input_name_to_nodes, output_name_to_node):
return
if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
return
def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
"""
Fuse Gelu with tanh into one node:
+---------------------------+
| |
| v
[root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul
| (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
| |
+------> Mul(B=0.5)--------------------------------------------+
Note that the constant input of Add or Mul can be either the first or the second input, e.g. either A=0.5 or B=0.5 is fine.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]
if not self.model.has_constant_input(add_after_tanh, 1.0):
return
if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_tanh = children[0]
mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
if mul_half is None:
return
i = self.model.find_constant_input(mul_half, 0.5)
if i < 0:
return
root_input = mul_half.input[0 if i == 1 else 1]
# root_node could be None when root_input is graph input
root_node = self.model.get_parent(mul_half, 0 if i == 1 else 1, output_name_to_node)
mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return
i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return
add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return
mul_after_pow = self.model.match_parent(
add_before_tanh,
"Mul",
None,
output_name_to_node,
exclude=[root_node] if root_node else [],
)
if mul_after_pow is None:
return
i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return
pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
if pow is None:
return
if not self.model.has_constant_input(pow, 3.0):
return
if pow.input[0] != root_input:
return
subgraph_nodes = [
mul_after_tanh,
mul_half,
add_after_tanh,
tanh_node,
mul_before_tanh,
add_before_tanh,
mul_after_pow,
pow,
]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_after_tanh.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_after_tanh.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
"""
This pattern comes from TensorFlow models.
Fuse Gelu with tanh into one node:
+---------------------------+
| |
| v
[root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul(B=0.5)-->Mul-->
| (Y=3) (B=0.0447...) (B=0.7978...) (B=1) ^
| |
+---------------------------------------------------------------------------+
Note that the constant input of Add or Mul can be either the first or the second input, e.g. either A=0.5 or B=0.5 is fine.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]
if not self.model.has_constant_input(add_after_tanh, 1.0):
return
if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_half = children[0]
i = self.model.find_constant_input(mul_half, 0.5)
if i < 0:
return
if mul_half.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[mul_half.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_mul_half = children[0]
root_node = self.model.get_parent(
mul_after_mul_half,
0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1,
output_name_to_node,
)
if root_node is None:
return
mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return
i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return
add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return
mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node])
if mul_after_pow is None:
return
i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return
pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
if pow is None:
return
if not self.model.has_constant_input(pow, 3.0):
return
if pow.input[0] != root_node.output[0]:
return
subgraph_nodes = [
mul_after_mul_half,
mul_half,
add_after_tanh,
tanh_node,
mul_before_tanh,
add_before_tanh,
mul_after_pow,
pow,
]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_after_mul_half.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_node.output[0]],
outputs=mul_after_mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
"""
OpenAI's gelu implementation, also used in Megatron:
Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
Fuse subgraph into a FastGelu node:
+------------ Mul (B=0.79788456) -------------------+
| |
+-------------------------------+ |
| | |
| v v
[root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul-->
| ^
| |
+-----------> Mul (B=0.5) --------------------------------------------------------+
"""
if tanh_node.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]
if not self.model.has_constant_input(add_after_tanh, 1.0):
return
if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_last = children[0]
mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
if mul_half is None:
return
i = self.model.find_constant_input(mul_half, 0.5)
if i < 0:
return
root_input = mul_half.input[0 if i == 1 else 1]
mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return
add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
if add_1 is None:
return
j = self.model.find_constant_input(add_1, 1.0)
if j < 0:
return
mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
if mul_7978 is None:
return
k = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
if k < 0:
return
if mul_7978.input[0 if k == 1 else 1] != root_input:
return
mul_before_add_1 = self.model.match_parent(add_1, "Mul", 0 if j == 1 else 1, output_name_to_node)
if mul_before_add_1 is None:
return
if mul_before_add_1.input[0] == root_input:
another = 1
elif mul_before_add_1.input[1] == root_input:
another = 0
else:
return
mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", another, output_name_to_node)
if mul_0447 is None:
return
m = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
if m < 0:
return
if mul_0447.input[0 if m == 1 else 1] != root_input:
return
subgraph_nodes = [
mul_0447,
mul_before_add_1,
add_1,
mul_before_tanh,
tanh_node,
add_after_tanh,
mul_7978,
mul_half,
mul_last,
]
if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_last.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return
self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_last.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True

Some files were not shown because too many files have changed in this diff.