I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

@@ -0,0 +1,98 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
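# Example usage sketch (illustrative values; the script file name is hypothetical, the flags are defined in user_command() below):
#   python export.py -m facebook/bart-base --max_length 60 --num_beams 4 -o onnx_models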
import argparse
import logging
import os
import sys
from utils import (
chain_enc_dec_with_beamsearch,
export_summarization_edinit,
export_summarization_enc_dec_past,
onnx_inference,
)
# GLOBAL ENVS
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("generate")
def print_args(args):
for arg in vars(args):
logger.info(f"{arg}: {getattr(args, arg)}")
def user_command():
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
parent_parser.add_argument("--early_stopping", type=bool, default=False, help="default to False")
parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")
parent_parser.add_argument("--no_encoder", action="store_true")
parent_parser.add_argument("--no_decoder", action="store_true")
parent_parser.add_argument("--no_chain", action="store_true")
parent_parser.add_argument("--no_inference", action="store_true")
required_args = parent_parser.add_argument_group("required input arguments")
required_args.add_argument(
"-m",
"--model_dir",
type=str,
required=True,
help="The directory contains input huggingface model. \
An official model like facebook/bart-base is also acceptable.",
)
args = parent_parser.parse_args()
print_args(args)
return args
if __name__ == "__main__":
args = user_command()
if args.opset_version < 14:
raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")
os.makedirs(args.output, exist_ok=True)
# beam search op only supports CPU for now
args.device = "cpu"
logger.info("ENV: CPU ...")
if not args.input_text:
args.input_text = (
"PG&E stated it scheduled the blackouts in response to forecasts for high winds "
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
if not args.no_encoder:
logger.info("========== EXPORTING ENCODER ==========")
export_summarization_edinit.export_encoder(args)
if not args.no_decoder:
logger.info("========== EXPORTING DECODER ==========")
export_summarization_enc_dec_past.export_decoder(args)
if not args.no_chain:
logger.info("========== CONVERTING MODELS ==========")
chain_enc_dec_with_beamsearch.convert_model(args)
if not args.no_inference:
logger.info("========== INFERENCING WITH ONNX MODEL ==========")
onnx_inference.run_inference(args)

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

@@ -0,0 +1,329 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# This script evaluates accuracy of ONNX models for question-answering task on SQuAD data set.
# Example to evaluate raw and optimized model for CUDA in Linux:
# pip3 install datasets evaluate optimum transformers onnxruntime-gpu
#
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding
#
# python3 -m onnxruntime.transformers.optimizer \
# --input ./bert-large-uncased-whole-word-masking-finetuned-squad/model.onnx \
# --output ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
#
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding \
# --onnx ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
#
# Snippet of example output in A100:
# {'exact': 86.65089877010406, 'f1': 92.99433524952254, 'total': 10570, 'HasAns_exact': 86.65089877010406
# 'total_time_in_seconds': 81.69239814393222, 'samples_per_second': 129.387804008115,
# 'latency_in_seconds': 0.007728703703304846, 'provider': 'CUDAExecutionProvider',
# 'pretrained_model_name': 'bert-large-uncased-whole-word-masking-finetuned-squad',
# 'batch_size': 1, 'sequence_length': 384, 'use_io_binding': True}
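#
# A CPU variant is sketched below (illustrative; assumes the CPU onnxruntime package instead of onnxruntime-gpu,
# and uses the --provider flag defined further down in this script):
# python3 eval_squad.py -m distilbert-base-cased-distilled-squad -s 384 -b 1 --provider CPUExecutionProvider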
import argparse
import csv
import os
import time
try:
from importlib.metadata import PackageNotFoundError, version
except ImportError:
from importlib_metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any, Dict, List, Optional
from datasets import load_dataset
from evaluate import evaluator
from optimum.onnxruntime import ORTModelForQuestionAnswering
from optimum.version import __version__ as optimum_version
from packaging import version as version_check
from transformers import AutoTokenizer, pipeline
if version_check.parse(optimum_version) < version_check.parse("1.13.1"):
raise ImportError(f"Please install optimum>=1.13.1. Current version: {optimum_version}.")
PRETRAINED_SQUAD_MODELS = [
"bert-large-uncased-whole-word-masking-finetuned-squad",
"deepset/roberta-base-squad2",
"distilbert-base-cased-distilled-squad",
]
def get_package_version(package_name: str):
try:
return version(package_name)
except PackageNotFoundError:
return None
def load_onnx_model(
model_id: str, onnx_path: Optional[str] = None, provider="CUDAExecutionProvider", use_io_binding: bool = False
):
"""Load onnx model given pretrained model name and optional ONNX model path. If onnx_path is None,
the default onnx model from optimum will be used.
Args:
model_id (str): pretrained model name or checkpoint path
onnx_path (Optional[str], optional): path of onnx model to evaluate. Defaults to None.
Returns:
model: ORTModel for the onnx model
onnx_path: the path of onnx model
"""
if onnx_path is None:
# Export onnx to a sub-directory named by the model id
model = ORTModelForQuestionAnswering.from_pretrained(
model_id, export=True, provider=provider, use_io_binding=use_io_binding
)
save_onnx_dir = os.path.join(".", model_id)
model.save_pretrained(save_onnx_dir)
onnx_path = os.path.join(save_onnx_dir, "model.onnx")
print("Model is exported to onnx file:", onnx_path)
else:
model = ORTModelForQuestionAnswering.from_pretrained(
os.path.dirname(onnx_path),
file_name=Path(onnx_path).name,
provider=provider,
use_io_binding=use_io_binding,
# provider_options={"enable_skip_layer_norm_strict_mode": True},
)
return model, onnx_path
def output_details(results: List[Dict[str, Any]], csv_filename: str):
"""Output a CSV file with detail of each test results.
Args:
results (List[Dict[str, Any]]): list of JSON results.
csv_filename (str): path of output CSV file
"""
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"pretrained_model_name",
"onnx_path",
"provider",
"disable_fused_attention",
"batch_size",
"sequence_length",
"use_io_binding",
"exact",
"f1",
"total",
"HasAns_exact",
"HasAns_f1",
"HasAns_total",
"best_exact",
"best_exact_thresh",
"best_f1",
"best_f1_thresh",
"total_time_in_seconds",
"samples_per_second",
"latency_in_seconds",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for result in results:
csv_writer.writerow(result)
csv_file.flush()
print(f"Detail results are saved to csv file: {csv_filename}")
def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name: str):
"""Output a CSV file with summary of a metric on combinations of batch_size and sequence_length.
Args:
results (List[Dict[str, Any]]): list of JSON results.
csv_filename (str): path of output CSV file
metric_name (str): the metric to summarize
"""
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
header_names = [
"pretrained_model_name",
"onnx_path",
"provider",
"disable_fused_attention",
"use_io_binding",
]
model_list = list({result["onnx_path"] for result in results})
model_list.sort()
batch_sizes = list({result["batch_size"] for result in results})
batch_sizes.sort()
sequence_lengths = list({result["sequence_length"] for result in results})
sequence_lengths.sort()
key_names = []
for sequence_length in sequence_lengths:
for batch_size in batch_sizes:
key_names.append(f"b{batch_size}_s{sequence_length}")
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + key_names)
csv_writer.writeheader()
for model in model_list:
row = {}
# Metric value for given pair of batch_size and sequence_length.
# Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only.
values = {}
values.update({k: "" for k in key_names})
for result in results:
if result["onnx_path"] == model and result[metric_name]:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
batch_size = result["batch_size"]
sequence_length = result["sequence_length"]
key = f"b{batch_size}_s{sequence_length}"
if key in key_names:
values[key] = result[metric_name]
if row:
for key in key_names:
row[key] = values.get(key, "")
csv_writer.writerow(row)
csv_file.flush()
print(f"Summary results for {metric_name} are saved to csv file: {csv_filename}")
def main():
args = parse_arguments()
print(args)
for name in ["onnxruntime-gpu", "onnxruntime", "onnx", "torch", "transformers", "optimum", "datasets", "evaluate"]:
package_version = get_package_version(name)
if package_version:
print(f"{name} version", package_version)
pretrained_model_name = args.model_name
if args.onnx and not os.path.exists(args.onnx):
raise RuntimeError(f"Onnx model path does not exist: {args.onnx}")
disable_fused_attention = os.environ.get("ORT_DISABLE_FUSED_ATTENTION", "0") == "1"
all_results = []
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
for sequence_length in args.sequence_lengths:
tokenizer.model_max_length = sequence_length
tokenizer.doc_stride = min(sequence_length // 2, 128)
if args.onnx is None:
print("Exporting onnx model. It might take a few minutes...")
start_time = time.time()
ort_model, onnx_path = load_onnx_model(pretrained_model_name, args.onnx, args.provider, args.use_io_binding)
latency = time.time() - start_time
print(f"Onnx model exported or loaded in {latency:.1f} seconds")
print(ort_model.config)
if sequence_length > ort_model.config.max_position_embeddings:
raise RuntimeError("sequence length should not be larger than {ort_model.config.max_position_embeddings}")
qa_pipeline = pipeline(
"question-answering", model=ort_model, tokenizer=tokenizer, question_first=True, batch_size=args.batch_size
)
task_evaluator = evaluator("question-answering")
print("Loading dataset...")
start_time = time.time()
squad_dataset = load_dataset("squad", split=f"validation[:{args.total}]" if args.total > 0 else "validation")
latency = time.time() - start_time
print(f"Dataset loaded in {latency:.1f} seconds")
print("Evaluating squad_v2 with ORT. It might take a few minutes...")
start_time = time.time()
result = task_evaluator.compute(
model_or_pipeline=qa_pipeline,
data=squad_dataset,
metric="squad_v2",
squad_v2_format=True,
)
latency = time.time() - start_time
print(f"Evaluation done in {latency:.1f} seconds")
result["provider"] = args.provider
result["disable_fused_attention"] = disable_fused_attention
result["pretrained_model_name"] = pretrained_model_name
result["onnx_path"] = onnx_path
result["batch_size"] = args.batch_size
result["sequence_length"] = sequence_length
result["use_io_binding"] = args.use_io_binding
print(result)
all_results.append(result)
output_details(all_results, "detail.csv")
for metric_name in ["f1", "exact", "samples_per_second"]:
output_summary(all_results, f"{metric_name}.csv", metric_name)
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=False,
type=str,
default=PRETRAINED_SQUAD_MODELS[0],
help=f"Checkpoint directory or pre-trained model names in the list: {PRETRAINED_SQUAD_MODELS}",
)
parser.add_argument(
"-s",
"--sequence_lengths",
nargs="+",
type=int,
default=[384],
help="Sequence lengths for onnx model inputs. It could have multiple values.",
)
parser.add_argument(
"-b",
"--batch_size",
type=int,
default=1,
help="batch size for inference.",
)
parser.add_argument("-t", "--total", type=int, default=0, help="Total samples to test. 0 means all samples.")
parser.add_argument(
"--onnx",
required=False,
type=str,
default=None,
help="Optional onnx model path. If not specified, optimum will be used to export onnx model for testing.",
)
parser.add_argument(
"--provider",
required=False,
default="CUDAExecutionProvider",
help="Select which Execution Provider to use for runs. Default is CUDAExecutionProvider.",
)
parser.add_argument("--use_io_binding", required=False, action="store_true", help="Use IO Binding for GPU.")
parser.set_defaults(use_io_binding=False)
args = parser.parse_args(argv)
return args
if __name__ == "__main__":
main()

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

@@ -0,0 +1,413 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script benchmarks gpt2 model with past state.
# For gpt2 model without past state, use benchmark.py to measure performance.
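# Example usage sketch (illustrative; the script file name is hypothetical, all flags come from parse_arguments below):
#   python benchmark_gpt2.py -m gpt2 --use_gpu -o -p fp16 -b 1 8 --sequence_lengths 1 -s 8 32 128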
import argparse
import csv
import logging
import os
from datetime import datetime
import psutil
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
get_ort_environment_variables,
prepare_environment,
setup_logger,
)
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
from packaging import version
from quantize_helper import QuantizeHelper
from transformers import AutoConfig
from transformers import __version__ as transformers_version
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name selected in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default="GPT2LMHeadModel",
choices=list(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--onnx_dir",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Directory to store onnx models",
)
parser.add_argument(
"--test_times",
required=False,
default=100,
type=int,
help="Number of repeat times to get average inference latency.",
)
parser.add_argument(
"-v",
"--validate_onnx",
required=False,
action="store_true",
help="Validate ONNX model",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument(
"--stage",
type=int,
default=0,
required=False,
choices=[0, 1, 2],
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
"1 - decode the first token when past_sequence_length is zero; "
"2 - decode the remaining tokens when past_sequence_length is not zero; "
"0 - one onnx model for both stages 1 and 2. "
"Note that we will optimize 1 and 2 differently for best performance.",
)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"-p",
"--precision",
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
)
parser.add_argument("--torchscript", required=False, action="store_true", help="use Torchscript")
parser.set_defaults(torchscript=False)
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1], help="batch size")
parser.add_argument(
"--sequence_lengths",
nargs="+",
type=int,
default=[1],
help="sequence lengths (excluding past)",
)
parser.add_argument(
"-s",
"--past_sequence_lengths",
nargs="+",
type=int,
default=[8, 16, 32, 64, 128, 256],
help="past sequence lengths",
)
parser.add_argument(
"-r",
"--result_csv",
required=False,
default=None,
help="CSV file for saving summary results.",
)
parser.add_argument("--thread_num", required=False, type=int, default=-1, help="Threads to use")
parser.add_argument("--include_copy_output_latency", required=False, action="store_true")
parser.set_defaults(include_copy_output_latency=False)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("--output_torch_latency", required=False, action="store_true")
parser.set_defaults(output_torch_latency=False)
parser.add_argument("--disable_io_binding", required=False, action="store_true")
parser.set_defaults(disable_io_binding=False)
args = parser.parse_args(argv)
return args
def main(args):
if version.parse(transformers_version) < version.parse(
"3.1.0"
): # past_key_values name does not exist in 3.0.2 or older
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
logger.info(f"Arguments:{args}")
if args.precision == Precision.FLOAT16:
assert args.optimize_onnx and args.use_gpu, "fp16 requires --optimize_onnx --use_gpu"
if args.precision == Precision.INT8:
assert not args.use_gpu, "quantization only supports CPU"
if args.stage == 1:
assert args.past_sequence_lengths == [0], "past_sequence_lengths shall be 0 for stage==1 (init decoder)"
torch.set_num_threads(psutil.cpu_count(logical=True) if args.thread_num <= 0 else args.thread_num)
print(torch.__config__.parallel_info())
cache_dir = args.cache_dir
output_dir = args.onnx_dir
prepare_environment(cache_dir, output_dir, args.use_gpu)
model_class = MODEL_CLASSES[args.model_class][0]
gpt2helper = Gpt2Helper
config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=args.torchscript, cache_dir=cache_dir)
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
# This script does not support float16 for PyTorch.
# if args.float16:
# model.half()
device = torch.device("cuda:0" if args.use_gpu else "cpu")
model.to(device)
use_external_data_format = config.n_layer > 24 # TODO: find a way to check model size > 2GB
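# (Models with more than 24 layers, e.g. gpt2-large or gpt2-xl, are assumed here to exceed the 2GB protobuf limit in fp32.)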
onnx_model_paths = gpt2helper.get_onnx_paths(
output_dir,
args.model_name_or_path,
args.model_class,
has_past=True,
new_folder=use_external_data_format,
)
onnx_model_path = onnx_model_paths["raw"]
use_padding = MODEL_CLASSES[args.model_class][2]
gpt2helper.export_onnx(
model,
device,
onnx_model_path,
args.verbose,
use_external_data_format,
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
if args.optimize_onnx or args.precision != Precision.FLOAT32:
onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
gpt2helper.optimize_onnx(
onnx_model_paths["raw"],
onnx_model_path,
args.precision == Precision.FLOAT16,
model.config.num_attention_heads,
model.config.hidden_size,
use_external_data_format,
auto_mixed_precision=True,
stage=args.stage,
)
if args.precision == Precision.INT8:
logger.info("quantizing model...")
QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_paths["int8"], use_external_data_format)
model = QuantizeHelper.quantize_torch_model(model)
logger.info("finished quantizing model")
onnx_model_path = onnx_model_paths["int8"]
if args.torchscript:
model = gpt2helper.torchscript(
model,
config,
device,
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
session = create_onnxruntime_session(
onnx_model_path,
args.use_gpu,
enable_all_optimization=False,
num_threads=args.thread_num,
verbose=args.verbose,
)
if session is None:
return
# Allocate output buffers for IO Binding
max_output_shapes = gpt2helper.get_output_shapes(
max(args.batch_sizes),
max(args.past_sequence_lengths),
max(args.sequence_lengths),
config,
args.model_class,
)
output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16)
csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S"))
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"model_name",
"model_class",
"stage",
"environment_variables",
"gpu",
"precision",
"optimizer",
"torchscript",
"batch_size",
"sequence_length",
"past_sequence_length",
"disable_io_binding",
"torch_latency",
"onnxruntime_latency",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for batch_size in args.batch_sizes:
for sequence_length in args.sequence_lengths:
for past_sequence_length in args.past_sequence_lengths:
assert batch_size > 0 and sequence_length > 0 and past_sequence_length >= 0
logger.debug(
"Running test for batch_size=%d sequence_length=%d past_sequence_length=%d ...",
batch_size,
sequence_length,
past_sequence_length,
)
dummy_inputs = gpt2helper.get_dummy_inputs(
batch_size,
past_sequence_length,
sequence_length,
config.num_attention_heads,
config.hidden_size,
config.n_layer,
config.vocab_size,
device,
float16=(args.precision == Precision.FLOAT16),
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
output_shapes = gpt2helper.get_output_shapes(
batch_size,
past_sequence_length,
sequence_length,
config,
args.model_class,
)
try:
if args.validate_onnx or args.output_torch_latency:
outputs, torch_latency = gpt2helper.pytorch_inference(model, dummy_inputs, args.test_times)
# Dump Torch output shape
for i, value in enumerate(outputs):
if isinstance(value, tuple):
logger.debug(
f"torch output {i} is tuple of size {len(value)}, shape {value[0].shape}"
)
else:
logger.debug(f"torch output {i} shape {value.shape}")
else:
outputs = None
torch_latency = None
if args.disable_io_binding:
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference(
session, dummy_inputs, args.test_times
)
else:
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference_with_binded_io(
session,
dummy_inputs,
output_buffers,
output_shapes,
args.test_times,
return_numpy=False,
include_copy_output_latency=args.include_copy_output_latency,
)
if args.validate_onnx:
copy_outputs = ort_outputs
if not args.disable_io_binding:
# Results of IO binding might be in GPU. Copy outputs to CPU for comparison.
copy_outputs = []
for output in ort_outputs:
copy_outputs.append(output.cpu().numpy())
if gpt2helper.compare_outputs(
outputs,
copy_outputs,
model_class=args.model_class,
rtol=DEFAULT_TOLERANCE[args.precision],
atol=DEFAULT_TOLERANCE[args.precision],
):
logger.info(
f"Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})."
)
logger.info(
"batch_size=%d, sequence_length=%d, past_sequence_length=%d, onnxruntime_latency=%.2f %s %s",
batch_size,
sequence_length,
past_sequence_length,
ort_latency,
"(disable_io_binding)" if args.disable_io_binding else "",
", torch_latency={torch_latency}" if torch_latency else "",
)
row = {
"model_name": args.model_name_or_path,
"model_class": args.model_class,
"stage": args.stage,
"environment_variables": get_ort_environment_variables(),
"gpu": args.use_gpu,
"precision": args.precision,
"optimizer": args.optimize_onnx,
"torchscript": args.torchscript,
"batch_size": batch_size,
"sequence_length": sequence_length,
"past_sequence_length": past_sequence_length,
"disable_io_binding": args.disable_io_binding,
"torch_latency": f"{torch_latency:.2f}" if torch_latency else "None",
"onnxruntime_latency": f"{ort_latency:.2f}",
}
csv_writer.writerow(row)
except Exception:
logger.error("Exception", exc_info=True) # noqa: G201
return None
logger.info(f"Results are saved to file {csv_filename}")
return csv_filename
if __name__ == "__main__":
args = parse_arguments()
setup_logger(args.verbose)
main(args)

@@ -0,0 +1,557 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
This converts GPT2 model to onnx. Examples:
(1) Convert pretrained model 'gpt2' to ONNX
python convert_to_onnx.py -m gpt2 --output gpt2.onnx
(2) Convert pretrained model 'distilgpt2' to ONNX, and use optimizer to get float16 model.
python convert_to_onnx.py -m distilgpt2 --output distilgpt2_fp16.onnx -o -p fp16
(3) Convert a model check point to ONNX, and run optimization and int8 quantization
python convert_to_onnx.py -m ./my_model_checkpoint/ --output my_model_int8.onnx -o -p int8
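(4) Illustrative sketch (not from the original examples): convert 'gpt2' with automatic mixed precision on GPU
python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu --auto_mixed_precision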
"""
import argparse
import csv
import json
import logging
import os
import shutil
import sys
from pathlib import Path
import numpy
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
get_ort_environment_variables,
prepare_environment,
setup_logger,
)
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
from gpt2_tester import Gpt2Tester
from packaging import version
from quantize_helper import QuantizeHelper
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from onnxruntime import __version__ as ort_version
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default="GPT2LMHeadModel",
choices=list(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory, or model path ends with .onnx",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--provider",
required=False,
default=None,
choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"],
help="use dml, rocm, cuda, tensorrt or migraphx for respective backend",
)
parser.add_argument(
"--tolerance",
required=False,
type=float,
default=0,
help="the absolute and relative tolerance for parity verification",
)
parser.add_argument(
"--input_test_file",
"-i",
required=False,
type=str,
default="",
help="Path to the file with inputs to test with",
)
parser.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization",
)
parser.add_argument(
"-t",
"--test_cases",
required=False,
type=int,
default=1000,
help="Number of test cases per run for parity",
)
parser.add_argument(
"-r",
"--test_runs",
required=False,
type=int,
default=10,
help="Number of runs for parity. It is used for significance test.",
)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument("--overwrite", required=False, action="store_true")
parser.set_defaults(overwrite=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int32 instead of int64 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"-s",
"--stage",
type=int,
default=0,
required=False,
choices=[0, 1, 2],
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
"1 - decode the first token when past_sequence_length is zero; "
"2 - decode the remaining tokens when past_sequence_length is not zero; "
"0 - one onnx model for both stages 1 and 2. "
"Note that we will optimize 1 and 2 differently for best performance.",
)
fp16_option_group = parser.add_argument_group(
'float to float16 conversion parameters that work when "--precision fp16" is specified'
)
fp16_option_group.add_argument(
"-a",
"--auto_mixed_precision",
required=False,
action="store_true",
help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.",
)
fp16_option_group.set_defaults(auto_mixed_precision=False)
fp16_option_group.add_argument(
"--keep_io_types",
required=False,
action="store_true",
help="Use float32 for past inputs, present and logits outputs.",
)
fp16_option_group.set_defaults(keep_io_types=False)
fp16_option_group.add_argument(
"--io_block_list",
nargs="+",
default=[],
help="List of inputs or outputs in float32 instead of float16",
)
fp16_option_group.add_argument(
"--op_block_list",
nargs="+",
default=[],
help="List of operators (like Add LayerNormalization SkipLayerNormalization EmbedLayerNormalization FastGelu) "
"to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--node_block_list",
nargs="+",
default=[],
help="List of node names to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--force_fp16_initializers",
required=False,
action="store_true",
help="Convert all float initializers to float16.",
)
fp16_option_group.set_defaults(force_fp16_initializers=False)
args = parser.parse_args(argv)
return args
def get_onnx_model_size(onnx_path: str, use_external_data_format: bool):
if not use_external_data_format:
return os.path.getsize(onnx_path)
else:
return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob("*")])
def get_latency_name(batch_size, sequence_length, past_sequence_length):
return f"average_latency(batch_size={batch_size},sequence_length={sequence_length},past_sequence_length={past_sequence_length})"
def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"):
result = {}
if version.parse(transformers_version) < version.parse(
"3.1.0"
): # past_key_values name does not exist in 3.0.2 or older
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
args = parse_arguments(argv)
setup_logger(args.verbose)
if not experiment_name:
experiment_name = " ".join(argv if argv else sys.argv[1:])
if args.tolerance == 0:
args.tolerance = DEFAULT_TOLERANCE[args.precision]
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision != Precision.FLOAT32:
assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
if args.precision == Precision.INT8:
assert not args.use_gpu, "quantization only supports CPU"
model_class = MODEL_CLASSES[args.model_class][0]
use_padding = MODEL_CLASSES[args.model_class][2]
gpt2helper = Gpt2Helper
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir)
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
device = torch.device("cuda:0" if args.use_gpu else "cpu")
model.eval().to(device)
if (not args.use_external_data_format) and (config.n_layer > 24):
logger.info("Try --use_external_data_format when model size > 2GB")
onnx_model_paths = gpt2helper.get_onnx_paths(
output_dir,
args.model_name_or_path,
args.model_class,
new_folder=(args.precision == Precision.INT8),
remove_existing=["fp32", "fp16", "int8"],
) # Do not remove raw model to save time in parity test
raw_onnx_model = onnx_model_paths["raw"]
int_data_type = torch.int64 if args.use_int64_inputs else torch.int32
if os.path.exists(raw_onnx_model) and not args.overwrite:
logger.warning(f"Skip exporting ONNX model since it existed: {raw_onnx_model}")
else:
logger.info(f"Exporting ONNX model to {raw_onnx_model}")
gpt2helper.export_onnx(
model,
device,
raw_onnx_model,
args.verbose,
args.use_external_data_format,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
)
fp16_params = {"keep_io_types": args.keep_io_types}
if args.io_block_list:
fp16_params["keep_io_types"] = args.io_block_list
if args.node_block_list:
fp16_params["node_block_list"] = args.node_block_list
if args.op_block_list:
fp16_params["op_block_list"] = args.op_block_list
if args.force_fp16_initializers:
fp16_params["force_fp16_initializers"] = args.force_fp16_initializers
is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types
optimized_ops = ""
all_ops = ""
if args.optimize_onnx or args.precision != Precision.FLOAT32:
output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
logger.info(f"Optimizing model to {output_path}")
m = gpt2helper.optimize_onnx(
raw_onnx_model,
output_path,
args.precision == Precision.FLOAT16,
model.config.num_attention_heads,
model.config.hidden_size,
args.use_external_data_format,
auto_mixed_precision=args.auto_mixed_precision,
stage=args.stage,
**fp16_params,
)
nodes = m.nodes()
op_list = {node.op_type for node in nodes}
all_ops = ",".join(op_list)
# print optimized operators
optimized_op_counter = m.get_fused_operator_statistics()
if optimized_op_counter:
optimized_ops = ",".join([key for key in optimized_op_counter if optimized_op_counter[key] > 0])
else:
output_path = raw_onnx_model
if args.precision == Precision.INT8:
logger.info("quantizing model...")
QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format)
model = QuantizeHelper.quantize_torch_model(model)
logger.info("finished quantizing model")
output_path = onnx_model_paths["int8"]
if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format:
shutil.move(output_path, args.output)
output_path = args.output
logger.info(f"Output path: {output_path}")
model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024) # noqa: N806
session = create_onnxruntime_session(
output_path, args.use_gpu, args.provider, enable_all_optimization=True, verbose=args.verbose
)
if args.model_class == "GPT2LMHeadModel" and session is not None:
parity_result = gpt2helper.test_parity(
session,
model,
device,
is_io_float16,
rtol=args.tolerance,
atol=args.tolerance,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
test_cases_per_run=args.test_cases,
total_runs=args.test_runs,
stage=args.stage,
verbose=args.verbose,
)
# An example configuration for testing performance
batch_size = 8
sequence_length = 32 if args.stage == 1 else 1
past_sequence_length = 0 if args.stage == 1 else 32
latency = gpt2helper.test_performance(
session,
model,
device,
is_io_float16,
total_runs=100,
use_io_binding=True,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
batch_size=batch_size,
sequence_length=sequence_length,
past_sequence_length=past_sequence_length,
)
if args.precision == Precision.FLOAT16:
logger.info(f"fp16 conversion parameters:{fp16_params}")
# Write results to file
latency_name = get_latency_name(batch_size, sequence_length, past_sequence_length)
csv_file_existed = os.path.exists(csv_filename)
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"experiment",
"run_id",
"model_name",
"model_class",
"stage",
"gpu",
"precision",
"optimizer",
"test_cases",
"runs",
"keep_io_types",
"io_block_list",
"op_block_list",
"node_block_list",
"force_fp16_initializers",
"auto_mixed_precision",
"optimized_operators",
"operators",
"environment_variables",
"onnxruntime",
latency_name,
"top1_match_rate",
"onnx_size_in_MB",
"diff_50_percentile",
"diff_90_percentile",
"diff_95_percentile",
"diff_99_percentile",
"diff_pass_rate",
"nan_rate",
"top1_match_rate_per_run",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
if not csv_file_existed:
csv_writer.writeheader()
row = {
"experiment": experiment_name,
"run_id": run_id,
"model_name": args.model_name_or_path,
"model_class": args.model_class,
"stage": args.stage,
"gpu": args.use_gpu,
"precision": args.precision,
"optimizer": args.optimize_onnx,
"test_cases": args.test_cases,
"runs": args.test_runs,
"keep_io_types": args.keep_io_types,
"io_block_list": args.io_block_list,
"op_block_list": args.op_block_list,
"node_block_list": args.node_block_list,
"force_fp16_initializers": args.force_fp16_initializers,
"auto_mixed_precision": args.auto_mixed_precision,
"optimized_operators": optimized_ops,
"operators": all_ops,
"environment_variables": get_ort_environment_variables(),
"onnxruntime": ort_version,
latency_name: f"{latency:.2f}",
"diff_50_percentile": parity_result["max_diff_percentile_50"],
"diff_90_percentile": parity_result["max_diff_percentile_90"],
"diff_95_percentile": parity_result["max_diff_percentile_95"],
"diff_99_percentile": parity_result["max_diff_percentile_99"],
"diff_pass_rate": parity_result["diff_pass_rate"],
"nan_rate": parity_result["nan_rate"],
"top1_match_rate": parity_result["top1_match_rate"],
"top1_match_rate_per_run": parity_result["top1_match_rate_per_run"],
"onnx_size_in_MB": f"{model_size_in_MB}",
}
logger.info(f"result: {row}")
result.update(row)
csv_writer.writerow(row)
if args.input_test_file:
test_inputs = []
# Each line of test file is a JSON string like:
# {"input_ids": [[14698, 257, 1310, 13688, 319, 326]]}
with open(args.input_test_file) as read_f:
for _, line in enumerate(read_f):
line = line.rstrip() # noqa: PLW2901
data = json.loads(line)
input_ids = torch.from_numpy(numpy.asarray(data["input_ids"], dtype=numpy.int64)).to(device)
if use_padding:
if "attention_mask" in data:
numpy_float = numpy.float16 if is_io_float16 else numpy.float32
attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to(
device
)
else:
padding = -1
attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32)
input_ids.masked_fill_(input_ids == padding, 0)
if "position_ids" in data:
position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to(
device
)
else:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(position_ids < 0, 0)
inputs = {
"input_ids": input_ids.to(int_data_type),
"position_ids": position_ids.to(int_data_type),
"attention_mask": attention_mask.to(int_data_type),
}
else:
inputs = {"input_ids": input_ids.to(int_data_type)}
test_inputs.append(inputs)
Gpt2Tester.test_generation(
session,
model,
device,
test_inputs,
precision=args.precision,
model_class=args.model_class,
top_k=20,
top_k_no_order=True,
max_steps=24,
max_inputs=0,
verbose=args.verbose,
save_test_data=3,
save_test_data_dir=Path(output_path).parent,
)
logger.info(f"Done. Output model: {output_path}")
return result
if __name__ == "__main__":
main()

File diff suppressed because it is too large.

@@ -0,0 +1,513 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script uses different configurations in mixed precision conversion for GPT-2 model, and
# measures the inference latency, top 1 match rate (compared to PyTorch FP32 model) and ONNX model size.
# It outputs a csv file with Mann-Whitney U test and T-test results on each pair of experiments, where
# pvalue < 0.05 means the two experiments have a significant difference in top 1 match rate.
# Users can use this script to select the best mixed precision model according to these metrics.
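# Example usage sketch (illustrative; the script file name is hypothetical, flags are defined in parse_arguments below):
#   python gpt2_parity.py -m gpt2 --use_gpu --test_cases 500 --runs 40 --csv gpt2_parity_results.csv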
import argparse
import csv
import datetime
import json
import logging
import os
import onnx
import scipy.stats
from benchmark_helper import get_ort_environment_variables, setup_logger
from convert_to_onnx import main
from gpt2_helper import PRETRAINED_GPT2_MODELS, Gpt2Helper
from onnx_model import OnnxModel
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--csv",
required=False,
type=str,
default="gpt2_parity_results.csv",
help="path of csv file to save the result",
)
parser.add_argument(
"--test_cases",
required=False,
type=int,
default=500,
help="number of test cases per run",
)
parser.add_argument("--runs", required=False, type=int, default=40, help="number of repeated runs")
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--all",
required=False,
action="store_true",
help="run all combinations of mixed precision",
)
parser.set_defaults(all=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument(
"--skip_test",
required=False,
action="store_true",
help="do not run test, and only rank experiments based on existing csv file",
)
parser.set_defaults(skip_test=False)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing csv file",
)
parser.set_defaults(overwrite=False)
args = parser.parse_args(argv)
return args
class ParityTask:
def __init__(self, test_cases, total_runs, csv_path):
self.total_runs = total_runs
self.test_cases = test_cases
self.csv_path = csv_path
self.results = []
self.run_id = 0
def run(self, argv, experiment_name):
start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
run_id = f"{start_time}_{self.run_id}"
self.run_id += 1
try:
result = main(
[*argv, "-t", f"{self.test_cases}", "-r", f"{self.total_runs}"],
experiment_name=experiment_name,
run_id=run_id,
csv_filename=self.csv_path,
)
if result:
self.results.append(result)
except Exception:
logger.exception(f"Failed to run experiment {experiment_name}")
result = None
return result
def load_results_from_csv(csv_path):
rows = []
with open(csv_path, newline="") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
rows.append(row) # noqa: PERF402
return rows
def get_latency(row):
for name in row:
if name.startswith("average_latency(batch_size="):
return float(row[name])
raise RuntimeError("Failed to get average_latency from output")
def score(row):
"""Scoring function based on 3 metrics. The larger score is better."""
latency_in_ms = get_latency(row)
top1_match_rate = float(row["top1_match_rate"])
onnx_size_in_MB = float(row["onnx_size_in_MB"]) # noqa: N806
# A simple scoring function: cost of 0.1ms latency ~ 0.1% match rate ~ 100MB size
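# Worked example with illustrative numbers: top1_match_rate=0.9990, latency=2.5 ms, onnx_size_in_MB=650
#   score = 0.9990 * 1000 - 2.5 * 10 - 650 / 100 = 999.0 - 25.0 - 6.5 = 967.5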
return top1_match_rate * 1000 - latency_in_ms * 10 - onnx_size_in_MB / 100
def print_wins(wins, rows, test_name):
print()
print("*" * 10)
row_map = {}
for row in rows:
row_map[row["run_id"]] = row
sorted_wins = dict(
sorted(
wins.items(),
key=lambda item: (item[1], score(row_map[item[0]])),
reverse=True,
)
)
logger.debug(f"{test_name} Wins:{sorted_wins}")
logger.info(f"Based on {test_name} wins and a scoring function, the ranking:")
rank = 0
previous_value = -1
for count, (key, value) in enumerate(sorted_wins.items()):
if value != previous_value:
rank = count
previous_value = value
for row in rows:
if row["run_id"] == key:
logger.info(
"{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format( # noqa: G001
rank,
value,
key,
get_latency(row),
float(row["top1_match_rate"]),
row["onnx_size_in_MB"],
row["experiment"],
get_ort_environment_variables(),
)
)
break
def run_significance_test(rows, output_csv_path):
"""Run U test and T test."""
utest_wins = {}
ttest_wins = {}
for row in rows:
run_id = row["run_id"]
utest_wins[run_id] = 0
ttest_wins[run_id] = 0
with open(output_csv_path, "w", newline="") as csvfile:
column_names = [
"model_name",
"run_id_1",
"experiment_1",
"top1_match_rate_1",
"run_id_2",
"experiment_2",
"top1_match_rate_2",
"U_statistic",
"U_pvalue",
"T_statistic",
"T_pvalue",
]
writer = csv.DictWriter(csvfile, fieldnames=column_names)
writer.writeheader()
required_match_columns = ["model_name", "test_cases", "runs"]
num_results = len(rows)
for i in range(num_results - 1):
result1 = rows[i]
if isinstance(result1["top1_match_rate_per_run"], str):
a = json.loads(result1["top1_match_rate_per_run"])
else:
a = result1["top1_match_rate_per_run"]
for j in range(i + 1, num_results, 1):
result2 = rows[j]
all_matched = True
for column in required_match_columns:
if result1[column] != result2[column]:
all_matched = False
break
if not all_matched:
continue
if isinstance(result2["top1_match_rate_per_run"], str):
b = json.loads(result2["top1_match_rate_per_run"])
else:
b = result2["top1_match_rate_per_run"]
try:
utest_statistic, utest_pvalue = scipy.stats.mannwhitneyu(
a, b, use_continuity=True, alternative="two-sided"
) # TODO: shall we use one-sided: less or greater according to "top1_match_rate"
except ValueError: # ValueError: All numbers are identical in mannwhitneyu
utest_statistic = None
utest_pvalue = None
ttest_statistic, ttest_pvalue = scipy.stats.ttest_ind(a, b, axis=None, equal_var=True)
if utest_pvalue is not None and utest_pvalue < 0.05:
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
utest_wins[result1["run_id"]] += 1
else:
utest_wins[result2["run_id"]] += 1
if ttest_pvalue < 0.05:
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
ttest_wins[result1["run_id"]] += 1
else:
ttest_wins[result2["run_id"]] += 1
row = {
"model_name": result1["model_name"],
"run_id_1": result1["run_id"],
"experiment_1": result1["experiment"],
"top1_match_rate_1": float(result1["top1_match_rate"]),
"run_id_2": result2["run_id"],
"experiment_2": result2["experiment"],
"top1_match_rate_2": float(result2["top1_match_rate"]),
"U_statistic": utest_statistic,
"U_pvalue": utest_pvalue,
"T_statistic": ttest_statistic,
"T_pvalue": ttest_pvalue,
}
writer.writerow(row)
logger.info(f"U-Test and T-Test results are output to {output_csv_path}")
print_wins(utest_wins, rows, "U-Test")
print_wins(ttest_wins, rows, "T-Test")
def get_last_matmul_node_name(raw_onnx_model: str):
model = onnx.load(raw_onnx_model)
onnx_model = OnnxModel(model)
output_name_to_node = onnx_model.output_name_to_node()
assert model.graph.output[0].name in output_name_to_node
node = output_name_to_node[model.graph.output[0].name]
if node.op_type == "MatMul":
logger.info(f"Found last MatMul node for logits: {node.name}")
return node.name
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
return None
def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list):
model = args.model_name_or_path
parameters = f"-m {model} -o --use_gpu -p fp16".split()
if args.use_external_data_format:
parameters.append("--use_external_data_format")
parameters += [
"--io_block_list",
"logits",
"--node_block_list",
last_matmul_node_name,
]
if op_block_list:
parameters.extend(["--op_block_list", *op_block_list])
return parameters
def run_candidate(
task: ParityTask,
args,
last_matmul_node_name,
op_block_list=["FastGelu", "LayerNormalization"], # noqa: B006
):
parameters = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list)
op_block_list_str = ",".join(sorted(op_block_list))
if op_block_list:
name = f"Mixed precision baseline + {op_block_list_str} in FP32"
else:
name = f"Mixed precision baseline (logits output and last MatMul node {last_matmul_node_name} in FP32)"
env_vars = get_ort_environment_variables()
if env_vars:
name = name + f" ({env_vars})"
task.run(parameters, name)
def get_baselines(args):
model = args.model_name_or_path
fp32_baseline = f"-m {model} -o -p fp32".split()
if args.use_gpu:
fp32_baseline.append("--use_gpu")
if args.use_external_data_format:
fp32_baseline.append("--use_external_data_format")
fp16_baseline = f"-m {model} -o --use_gpu -p fp16".split()
if args.use_external_data_format:
fp16_baseline.append("--use_external_data_format")
return fp32_baseline, fp16_baseline
def run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops):
"""Step 0 is to check which operator in FP16 causes most loss"""
fp32_logits = ["--io_block_list", "logits"]
task.run(fp16_baseline + fp32_logits, "FP16 except logits")
fp32_io = ["--keep_io_types"]
task.run(fp16_baseline + fp32_io, "Graph I/O FP32, Other FP16")
# Only weights in FP16
task.run(
fp16_baseline + fp32_io + ["--op_block_list"] + [o for o in all_ops] + ["--force_fp16_initializers"],
"FP32 except weights in FP16",
)
optimized_ops_results = []
op_list = optimized_ops
for op in op_list:
op_block_list = ["--op_block_list"] + [o for o in op_list if o != op]
result = task.run(fp16_baseline + fp32_io + op_block_list, f"FP32 except {op} in FP16")
if result:
optimized_ops_results.append(result)
# Check which optimized operator causes the most loss in precision
min_result = min(optimized_ops_results, key=lambda y: y["top1_match_rate"])
print("step 0: optimized operator causes the most loss in precision", min_result)
def run_tuning_step1(task, mixed_precision_baseline, optimized_ops):
"""Step 1 is to figure out which optimized operator in FP32 could benefit most"""
for op in optimized_ops:
op_block_list = ["--op_block_list", op]
task.run(
mixed_precision_baseline + op_block_list,
f"Mixed precision baseline + {op} in FP32",
)
def run_tuning_step2(task, mixed_precision_baseline, optimized_ops):
"""Assumed that you have run step 0 and 1 to figure out that Logits FP32 and some operators shall be in FP32,
This step will try add one more operator.
"""
candidate_fp32_ops = ["FastGelu", "LayerNormalization", "SkipLayerNormalization"]
fp32_ops = [x for x in candidate_fp32_ops if x in optimized_ops]
for op in optimized_ops:
if op not in fp32_ops:
op_block_list = [*fp32_ops, op]
task.run(
[*mixed_precision_baseline, "--op_block_list", *op_block_list],
"Mixed precision baseline + {},{} in FP32".format(",".join(fp32_ops), op),
)
def run_parity(task: ParityTask, args):
onnx_model_paths = Gpt2Helper.get_onnx_paths(
"onnx_models",
args.model_name_or_path,
new_folder=args.use_external_data_format,
remove_existing=[],
)
fp32_baseline, fp16_baseline = get_baselines(args)
result = task.run(fp32_baseline, "FP32 baseline")
optimized_ops = []
if result and ("optimized_operators" in result) and result["optimized_operators"]:
optimized_ops = result["optimized_operators"].split(",")
else:
raise RuntimeError("Failed to get optimized operators")
all_ops = []
if result and ("operators" in result) and result["operators"]:
all_ops = result["operators"].split(",")
else:
raise RuntimeError("Failed to get operators")
# The following tests for fp16 requires GPU
if not args.use_gpu:
logger.info("skip mixed precision since --use_gpu is not specified")
return
task.run(fp16_baseline, "FP16 baseline")
last_matmul_node_name = get_last_matmul_node_name(onnx_model_paths["raw"])
# Mixed precision baseline
run_candidate(task, args, last_matmul_node_name, op_block_list=[])
def get_fp32_ops(x):
return [op for op in x if op in all_ops]
if args.all:
run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops)
mixed_precision_baseline = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list=[])
run_tuning_step1(task, mixed_precision_baseline, optimized_ops)
run_tuning_step2(task, mixed_precision_baseline, optimized_ops)
else:
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(["SkipLayerNormalization", "LayerNormalization", "Add"]),
)
run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu"])
# Run a few good candidates
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(["FastGelu", "SkipLayerNormalization", "LayerNormalization", "Add"]),
)
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(
["FastGelu", "EmbedLayerNormalization", "SkipLayerNormalization", "LayerNormalization", "Add"]
),
)
if __name__ == "__main__":
args = parse_arguments()
setup_logger(args.verbose)
if args.test_cases < 100 or args.runs < 20 or args.test_cases * args.runs < 10000:
logger.warning(
"Not enough test cases or runs to get stable results or test significance. "
"Recommend test_cases >= 100, runs >= 20, test_cases * runs >= 10000."
)
if os.path.exists(args.csv) and not args.skip_test:
if not args.overwrite:
raise RuntimeError(
f"Output file {args.csv} existed. Please remove the file, or use either --skip_test or --overwrite."
)
else:
logger.info("Remove existing file %s since --overwrite is specified", args.csv)
os.remove(args.csv)
task = ParityTask(args.test_cases, args.runs, args.csv)
if not args.skip_test:
run_parity(task, args)
try:
rows = load_results_from_csv(task.csv_path)
except Exception:
logger.exception(f"Failed to load csv {task.csv_path}")
rows = task.results
logger.info("Start running significance tests...")
summary_csv = task.csv_path.replace(".csv", ".stats.csv")
run_significance_test(rows, summary_csv)

@@ -0,0 +1,501 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps evaluation of GPT-2 model.
import logging
import math
import os
import statistics
import timeit
import numpy
import torch
from benchmark_helper import Precision
from gpt2_helper import Gpt2Helper, Gpt2Inputs
logger = logging.getLogger(__name__)
class Gpt2Metric:
def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
assert top_k > 1 and top_k <= 100
self.baseline = baseline_name
self.treatment = treatment_name
self.name: str = f"{treatment_name} vs {baseline_name}"
self.top_k = top_k
self.top_1_error: int = 0
self.top_k_error: int = 0
self.total_samples: int = 0
self.max_logits_diff: float = 0 # for non-empty past state
self.max_logits_diff_no_past: float = 0 # for empty past state
self.batch_top1_error: torch.FloatTensor = None # top 1 error for current batch
self.batch_topk_error: torch.FloatTensor = None # top k error for current batch
self.seq_len_latency = {}
def print(self):
if self.baseline != self.treatment:
print("---")
print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
if self.total_samples > 0:
top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
print(
f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
)
print("Max logits diffs:")
print(f"\twith past = {self.max_logits_diff:.6f}")
print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
else:
print(f"Metrics for {self.treatment} (baseline):")
if self.seq_len_latency:
print("Past sequence length range and average latency:")
total = 0
count = 0
for key in sorted(self.seq_len_latency.keys()):
average = statistics.mean(self.seq_len_latency[key]) * 1000.0
if key == 0:
print(f"\t{key}: \t{average:.2f} ms")
else:
print(f"\t[{2**key}, {2 ** (key + 1) - 1}]:\t{average:.2f} ms")
total += average * len(self.seq_len_latency[key])
count += len(self.seq_len_latency[key])
print(f"Average Latency: {total / count:.2f} ms")
def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
diff = (baseline_logits - treatment_logits).abs().max()
if is_empty_past:
self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
else:
self.max_logits_diff = max(self.max_logits_diff, diff)
return diff
def start_batch(self, batch_size: int):
self.total_samples += batch_size
self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)
def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)
max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
if verbose:
print(f"Max logits diffs of {self.name}: {max_diff}")
def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
if not torch.all(torch.eq(baseline_topk, treatment_topk)):
if top_k == 1:
if verbose:
print(f"Generated tokens not matched for {self.name}")
self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
else:
if verbose:
print(
f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
)
self.batch_topk_error |= (
torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
)
def end_batch(self):
self.top_1_error += self.batch_top1_error.sum()
self.top_k_error += self.batch_topk_error.sum()
def add_latency(self, past_seq_len, latency):
key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
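# Latency is bucketed by power-of-two past length: key 0 is the empty past, and key k covers
# past_seq_len in [2**(k-1), 2**k - 1] (e.g. past_seq_len=5 -> key=3, bucket [4, 7]).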
if key not in self.seq_len_latency:
self.seq_len_latency[key] = []
self.seq_len_latency[key].append(latency)
class Gpt2Tester:
def __init__(
self,
input_ids,
position_ids,
attention_mask,
num_attention_heads,
hidden_size,
num_layer,
device,
is_fp16=False,
top_k=20,
top_k_required_order=False,
):
self.batch_size = input_ids.shape[0]
self.input_length = input_ids.shape[1]
self.n_layer = num_layer
self.input_ids = input_ids
self.position_ids = position_ids
self.attention_mask = attention_mask
self.has_position_ids = position_ids is not None
self.has_attention_mask = attention_mask is not None
# Empty past state for first inference
self.past = []
past_shape = [
2,
self.batch_size,
num_attention_heads,
0,
hidden_size // num_attention_heads,
]
for _i in range(num_layer):
empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
self.past.append(empty_past.to(device))
self.logits = None
self.top_1_tokens = None
self.top_k_tokens = None
self.top_k = top_k
self.top_k_required_order = top_k_required_order
def get_inputs(self) -> Gpt2Inputs:
return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)
def save_test_data(self, session, output, save_test_data_dir, test_case_id):
from onnx import numpy_helper
path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
if os.path.exists(path):
print(f"Directory {path} existed. Skip saving test data")
return
os.makedirs(path, exist_ok=True)
def add_tensor(input_tensors, torch_tensor, name):
input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))
input_tensors = []
add_tensor(input_tensors, self.input_ids, "input_ids")
if self.has_position_ids:
add_tensor(input_tensors, self.position_ids, "position_ids")
if self.has_attention_mask:
add_tensor(input_tensors, self.attention_mask, "attention_mask")
for i in range(self.n_layer):
add_tensor(input_tensors, self.past[i], "past_" + str(i))
for i, tensor in enumerate(input_tensors):
with open(os.path.join(path, f"input_{i}.pb"), "wb") as f:
f.write(tensor.SerializeToString())
output_names = [output.name for output in session.get_outputs()]
for i, _name in enumerate(output_names):
tensor = numpy_helper.from_array(
output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
)
with open(os.path.join(path, f"output_{i}.pb"), "wb") as f:
f.write(tensor.SerializeToString())
print(f"Test data saved to directory {path}")
def update(self, output, step, device):
"""
Update the inputs for next inference.
"""
self.logits = (
torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
)
self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)
self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)
if self.has_position_ids:
self.position_ids = (
torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
)
if self.has_attention_mask:
self.attention_mask = torch.cat(
[
self.attention_mask,
torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
],
1,
).to(device)
self.past = []
if isinstance(output[1], tuple): # past in torch output is tuple
self.past = list(output[1])
else:
for i in range(self.n_layer):
past_i = (
torch.from_numpy(output[i + 1])
if isinstance(output[i + 1], numpy.ndarray)
else output[i + 1].clone().detach()
)
self.past.append(past_i.to(device))
def diff(self, baseline):
"""
Compare inputs and logits output.
"""
print("start diff...")
if self.logits is not None:
max_io_diff = (self.logits - baseline.logits).abs().max()
if max_io_diff > 1e-4:
print(f"Max logits difference is too large: {max_io_diff}")
if not torch.all(self.input_ids == baseline.input_ids):
print("Input_ids is different", self.input_ids, baseline.input_ids)
if self.has_position_ids:
if not torch.all(self.position_ids == baseline.position_ids):
print(
"position_ids is different",
self.position_ids,
baseline.position_ids,
)
if self.has_attention_mask:
if not torch.all(self.attention_mask == baseline.attention_mask):
print(
"attention_mask is different",
self.attention_mask,
baseline.attention_mask,
)
assert len(self.past) == len(baseline.past)
for i, past_i in enumerate(self.past):
assert past_i.shape == baseline.past[i].shape
if past_i.nelement() > 0:
max_past_diff = (past_i - baseline.past[i]).abs().max()
if max_past_diff > 1e-4:
print(f"max_past_diff[{i}]={max_past_diff}")
@staticmethod
def predict_next_token(logits, top_k=1, required_order=False):
"""
Get the top k tokens based on logits.
"""
# logits has shape (batch_size, seq_len, vocab_size)
# last token logits has shape (batch_size, vocab_size)
lastTokenLogits = logits[:, -1] # noqa: N806
if top_k == 1:
generatedTokens = torch.argmax(lastTokenLogits, 1, True) # noqa: N806
return generatedTokens
else:
topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
if not required_order:
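# Sort the top-k token ids so that comparisons ignore the ranking order within the top k.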
sorted_topk, _ = topk.sort()
return sorted_topk
return topk
@staticmethod
def diff_present(onnx_output, onnx_io_output, n_layer):
"""
Compare the present-state outputs from two ONNX Runtime runs.
"""
present_diff_max = []
for i in range(n_layer):
onnx_present_i = (
torch.from_numpy(onnx_output[i + 1])
if isinstance(onnx_output[i + 1], numpy.ndarray)
else onnx_output[i + 1]
)
onnx_io_present_i = (
torch.from_numpy(onnx_io_output[i + 1])
if isinstance(onnx_io_output[i + 1], numpy.ndarray)
else onnx_io_output[i + 1]
)
max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
present_diff_max.append(max_diff)
print(f"present_diff_max={present_diff_max}")
@staticmethod
def is_quantized_onnx_model(onnx_model_path):
"""
Returns True if the ONNX model is quantized.
"""
from onnx import load
model = load(onnx_model_path)
from onnxruntime.quantization.quantize import __producer__ as quantize_producer
return model.producer_name == quantize_producer
@staticmethod
def test_generation(
session,
model,
device,
test_inputs,
precision=Precision.FLOAT32,
model_class="Gpt2LMHeadModel",
top_k=20,
top_k_no_order=True,
max_steps=24,
max_inputs=0,
verbose=False,
save_test_data=0,
save_test_data_dir=".",
):
"""
Test generation using greedy search (without sampling) to compare the PyTorch and ONNX models.
It will print top 1 and top k errors on the given test inputs.
"""
print(
f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
)
n_layer = model.config.n_layer
n_head = model.config.n_head
n_embd = model.config.n_embd
eos_token_id = model.config.eos_token_id
test_data_saved = 0
is_float16 = precision == Precision.FLOAT16
if is_float16:
assert "float16" in session.get_outputs()[0].type
# We still use the fp32 torch model as the baseline when the onnx model is fp16
model.eval().to(device)
# Allocate initial buffers for IO Binding of ONNX Runtime. The buffer size will automatically increase later.
init_output_shapes = Gpt2Helper.get_output_shapes(
batch_size=4,
past_sequence_length=128,
sequence_length=32,
config=model.config,
model_class=model_class,
)
output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)
baseline_name = "Torch"
treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)
for i, inputs in enumerate(test_inputs):
if max_inputs > 0 and i == max_inputs:
break
if i % 10 == 0:
print(f"{i}")
input_ids = inputs["input_ids"]
position_ids = inputs.get("position_ids", None)
attention_mask = inputs.get("attention_mask", None)
onnx_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
is_float16,
top_k,
not top_k_no_order,
)
onnx_io_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
is_float16,
top_k,
not top_k_no_order,
)
torch_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
False,
top_k,
not top_k_no_order,
) # Torch model baseline is fp32
batch_size = torch_runner.batch_size
onnx_metric.start_batch(batch_size)
onnx_io_metric.start_batch(batch_size)
with torch.no_grad():
done = torch.zeros(batch_size, dtype=torch.bool)
for step in range(max_steps):
seq_len = list(onnx_runner.input_ids.size())[1]
past_seq_len = list(onnx_runner.past[0].size())[3]
start_time = timeit.default_timer()
pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
torch_runner.update(pytorch_output, step, device)
onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
session, onnx_runner.get_inputs(), total_runs=1
)
onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
onnx_runner.update(onnx_output, step, device)
output_shapes = Gpt2Helper.get_output_shapes(
batch_size,
past_seq_len,
seq_len,
model.config,
model_class=model_class,
)
Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)
(
onnx_io_output,
avg_latency_ms,
) = Gpt2Helper.onnxruntime_inference_with_binded_io(
session,
onnx_io_runner.get_inputs(),
output_buffers,
output_shapes,
total_runs=1,
return_numpy=False,
include_copy_output_latency=True,
)
onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
if test_data_saved < save_test_data:
onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
test_data_saved += 1
onnx_io_runner.update(onnx_io_output, step, device)
if verbose:
onnx_runner.diff(onnx_io_runner)
Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)
print("Top 1 tokens:")
print("\tTorch", torch_runner.top_1_tokens)
print("\tONNX", onnx_runner.top_1_tokens)
print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)
onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)
done = done | (torch_runner.top_1_tokens == eos_token_id).any()
if torch.all(done):
break
onnx_metric.end_batch()
onnx_io_metric.end_batch()
torch_metric.print()
onnx_metric.print()
onnx_io_metric.print()

View File

@ -0,0 +1,146 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps debug parity issues between the fp16 and fp32 versions of the same ONNX model
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON
import math
import multiprocessing
import os
from pathlib import Path
import numpy
import torch
from benchmark_helper import create_onnxruntime_session
from gpt2_helper import Gpt2Helper
from onnx import TensorProto, numpy_helper
NON_ZERO_VALUE = str(1)
ZERO_VALUE = str(0)
def environ_setting_nodes(node_name_filter=None, node_type_filter=None):
# Dump input and output data (not just shapes) by default
os.environ["ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA"] = ZERO_VALUE
os.environ["ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA"] = NON_ZERO_VALUE
os.environ["ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA"] = NON_ZERO_VALUE
if node_name_filter is not None:
os.environ["ORT_DEBUG_NODE_IO_NAME_FILTER"] = node_name_filter
elif node_type_filter is not None:
os.environ["ORT_DEBUG_NODE_IO_OP_TYPE_FILTER"] = node_type_filter
else:
os.environ["ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK"] = NON_ZERO_VALUE
def environ_setting_paths(output_path):
# Dump values to files by default
os.environ["ORT_DEBUG_NODE_IO_DUMP_DATA_DESTINATION"] = "files"
os.environ["ORT_DEBUG_NODE_IO_OUTPUT_DIR"] = output_path
def environ_reset():
for flag in [
"ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA",
"ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA",
"ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA",
"ORT_DEBUG_NODE_IO_NAME_FILTER",
"ORT_DEBUG_NODE_IO_OP_TYPE_FILTER",
"ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES",
"ORT_DEBUG_NODE_IO_OUTPUT_DIR",
"ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK",
]:
if flag in os.environ:
del os.environ[flag]
def inference(model_path, dummy_inputs, outputs_path, use_gpu):
environ_reset()
environ_setting_nodes()
environ_setting_paths(outputs_path)
session = create_onnxruntime_session(model_path, use_gpu, enable_all_optimization=False)
Gpt2Helper.onnxruntime_inference(session, dummy_inputs)
def generate_outputs_files(model_path, dummy_inputs, outputs_path, use_gpu):
dir_path = Path(outputs_path)
if dir_path.exists() and dir_path.is_dir():
import shutil
shutil.rmtree(outputs_path)
dir_path.mkdir(parents=True, exist_ok=True)
process = multiprocessing.Process(target=inference, args=(model_path, dummy_inputs, outputs_path, use_gpu))
process.start()
process.join()
def post_processing(outputs_path, outputs_path_other):
# Compare outputs with e.g. fp16 and fp32
record = {}
if_close = {}
import glob
for filename in glob.glob(os.path.join(outputs_path, "*.tensorproto")):
filename_other = os.path.join(outputs_path_other, Path(filename).name)
if not os.path.exists(filename_other):
continue
with open(filename, "rb") as f:
tensor = TensorProto()
tensor.ParseFromString(f.read())
array = numpy_helper.to_array(tensor)
with open(filename_other, "rb") as f: # noqa: PLW2901
tensor_other = TensorProto()
tensor_other.ParseFromString(f.read())
array_other = numpy_helper.to_array(tensor_other)
if array_other.size == 0:
continue
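# Mean relative difference between the two dumps; the small epsilon avoids division by zero.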
diff = numpy.average(numpy.abs(array_other - array) / (numpy.abs(array_other) + 1e-6))
if math.isnan(diff):
continue
record[Path(filename).name.split(".")[0]] = diff
if_close[Path(filename).name.split(".")[0]] = numpy.allclose(array, array_other, rtol=1e-04, atol=1e-04)
results = ["Node\tDiff\tClose"]
for k, v in sorted(record.items(), key=lambda x: x[1], reverse=True):
results.append(f"{k}\t{v}\t{if_close[k]}")
for line in results:
print(line)
if __name__ == "__main__":
# The example below shows how to use this helper to investigate parity issues between GPT-2 fp32 and fp16 ONNX models
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON !!
multiprocessing.set_start_method("spawn")
# Generate Inputs
sequence_length = 8
past_sequence_length = 8
batch_size = 5
dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs(
batch_size,
past_sequence_length,
sequence_length,
12,
768,
12,
50257,
device=torch.device("cpu"),
float16=True,
)
dummy_inputs_fp32 = dummy_inputs_fp16.to_fp32()
# Get GPT-2 model from huggingface using convert_to_onnx.py
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu")
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu")
# Specify the directory to dump the node's I/O
outputs_path_fp32_gpu = "./fp32_gpu"
outputs_path_fp16_gpu = "./fp16_gpu"
generate_outputs_files("./gpt2_fp32.onnx", dummy_inputs_fp32, outputs_path_fp32_gpu, use_gpu=True)
generate_outputs_files("./gpt2_fp16.onnx", dummy_inputs_fp16, outputs_path_fp16_gpu, use_gpu=True)
# Compare each node's I/O values and sort them by average relative difference
post_processing(outputs_path_fp16_gpu, outputs_path_fp32_gpu)

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

View File

@ -0,0 +1,703 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time
import numpy as np
import onnx
import psutil
import torch
from benchmark_helper import measure_memory, setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import onnxruntime as ort
logger = logging.getLogger(__name__)
# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
def get_ort_model_inputs_len(args, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
return 0
if args.benchmark_type == "hf-ort":
try:
# New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
return len(model.inputs_names)
except Exception:
# Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
return len(model.decoder.input_names)
return len(model.get_inputs())
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None
# For past_present_share_buffer:
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
# Set max_seq_len to config value for other models
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
elif args.benchmark_type in {"hf-ort"}:
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
# Using split models in Optimum (e.g. created by Optimum export)
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
else:
# Using merged model in Optimum (e.g. created by convert_to_onnx export)
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
elif args.benchmark_type == "ort-convert-to-onnx":
# Microsoft export from convert_to_onnx
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
return init_inputs, iter_inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
# 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
# 3) Benchmark LLaMA-2 from local download of model
# 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
# 5) Benchmark LLaMA-2 from convert_to_onnx
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_cache=True,
cache_dir=args.cache_dir,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export or convert_to_onnx.py export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
decoder_file_name = None
decoder_with_past_file_name = None
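# Pick the decoder and decoder_with_past ONNX files by name; a merged model
# ("decoder_merged_model*" or "model.onnx") serves both prompt and token generation.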
for filename in os.listdir(args.hf_ort_dir_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
if "decoder_model" in filename or filename == "model.onnx":
decoder_file_name = filename
if "decoder_with_past_model" in filename:
decoder_with_past_file_name = filename
if "decoder_merged_model" in filename:
decoder_file_name = filename
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
args.hf_ort_dir_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
use_merged=(True if decoder_file_name == "model.onnx" else None),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path.format(args.rank),
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(inputs)
logger.info(outputs)
input_sync = lambda *kwargs: ( # noqa: E731
args.io_binding.synchronize_inputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: (
torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None
)
) # no-op function
output_sync = lambda *kwargs: ( # noqa: E731
args.io_binding.synchronize_outputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: (
torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None
)
) # no-op function
for _ in warmup_range:
input_sync()
fn(inputs)
output_sync()
# Benchmark
total_time = 0
bench_range = (
range(args.num_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
input_sync()
start_time = time.time()
fn(inputs)
output_sync()
end_time = time.time()
total_time += end_time - start_time
# Print a newline after trange so that metrics are printed on new lines instead of on the progress bar line
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")
latency = total_time / args.num_runs
throughput = args.batch_size / latency
if args.rank == 0:
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
if args.rank == 0:
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
# Inference steps to measure
def get_logits(inputs):
# Inference pass without decoding
outputs = model(**inputs)
return outputs
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_pred_ids(inputs):
# # Inference pass with predicted token ids generation
# predicted_ids = model.generate(**inputs)
# return predicted_ids
# def gen_and_dec(inputs):
# # Inference pass with generation and decoding
# predicted_ids = get_pred_ids(inputs)
# transcription = []
# for bs in range(args.batch_size):
# for rs in range(args.num_return_sequences):
# transcription.append(
# args.tokenizer.batch_decode(
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
# )[0]
# )
# return transcription
generate_fn = get_logits
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
time_fn(args, generate_fn, init_inputs)
measure_fn(args, generate_fn, init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
time_fn(args, generate_fn, iter_inputs)
measure_fn(args, generate_fn, iter_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
return inputs, kv_cache_ortvalues
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}
if args.profile:
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args(rank=0):
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=[
"hf-pt-eager",
"hf-pt-compile",
"hf-ort",
"ort-msft",
"ort-convert-to-onnx",
],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
)
# Args for choosing the model
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-dir-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
default="32 64 128 256 512",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Args for decoding logic
parser.add_argument("--max-length", type=int, default=32)
parser.add_argument("--num-return-sequences", type=int, default=1)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--cache-dir",
type=str,
required=True,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
args.device = "cuda"
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert (
len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
), "Please provide only one (batch_size, sequence_length) combination for profiling"
return args
def main():
rank = get_rank()
world_size = get_size()
args = get_args(rank)
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
args.rank = rank
args.world_size = world_size
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
config = AutoConfig.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
setattr(args, "config", config) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
# Get model and model info
model = get_model(args)
ort_model_inputs_len = get_ort_model_inputs_len(args, model)
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
else:
setattr(args, "use_buffer_share", False) # noqa: B010
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
if args.rank == 0:
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
run_inference(args, init_inputs, iter_inputs, model)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,492 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
import logging
import os
import subprocess
import torch
from benchmark_helper import setup_logger
from metrics import BenchmarkRecord
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--batch-sizes",
type=str,
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
type=str,
default="8 16 32 64 128 256 512",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=1000,
)
parser.add_argument(
"--hf-pt-eager",
default=False,
action="store_true",
help="Benchmark in PyTorch without `torch.compile`",
)
parser.add_argument(
"--hf-pt-compile",
default=False,
action="store_true",
help="Benchmark in PyTorch with `torch.compile`",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-msft-model-path",
type=str,
default="",
help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx",
)
parser.add_argument(
"--ort-convert-to-onnx-model-path",
type=str,
default="",
help="Path to ONNX model from convert_to_onnx",
)
parser.add_argument(
"--cache-dir",
type=str,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=10,
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument(
"--log-folder",
type=str,
default=None,
help="Path to folder to save logs and results",
)
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}_{args.precision}"
if not args.log_folder:
args.log_folder = log_folder_name
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
def process_log_file(device_id, log_file, base_results):
entries = []
batch_size, sequence_length, step = None, None, None
latency_s, latency_ms, throughput, memory = None, None, None, None
batch_pattern = "Batch Size: "
sequence_pattern = "Sequence Length: "
prompt_step_pattern = "to get past_key_values"
per_token_step_pattern = "with past_key_values"
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
memory_pattern = "peak="
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
if batch_pattern in line:
batch_size = int(line[len(batch_pattern) :])
elif sequence_pattern in line:
sequence_length = int(line[len(sequence_pattern) :])
elif prompt_step_pattern in line:
step = "prompt"
elif per_token_step_pattern in line:
step = "per-token"
elif latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
latency_ms = latency_s * 1000
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
batch_size,
sequence_length,
step,
latency_s,
latency_ms,
throughput,
memory,
]
entries.append(entry)
return entries
def save_results(results, filename):
import pandas as pd
df = pd.DataFrame(
results,
columns=[
"Warmup Runs",
"Measured Runs",
"Model Name",
"Engine",
"Precision",
"Device",
"Batch Size",
"Sequence Length",
"Step",
"Latency (s)",
"Latency (ms)",
"Throughput (tps)",
"Memory (GB)",
],
)
# Set column types
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
df["Measured Runs"] = df["Measured Runs"].astype("int")
df["Batch Size"] = df["Batch Size"].astype("int")
df["Sequence Length"] = df["Sequence Length"].astype("int")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
df["Throughput (tps)"] = df["Throughput (tps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
# Get the installed ONNX Runtime package name and version
import pkg_resources
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(
[
f"{i.key}=={i.version}"
for i in installed_packages
if i.key in ["ort-nightly-gpu", "ort-nightly", "onnxruntime", "onnxruntime-gpu"]
]
)
ort_pkg_name = ""
ort_pkg_version = ""
if installed_packages_list:
ort_pkg_name = installed_packages_list[0].split("==")[0]
ort_pkg_version = installed_packages_list[0].split("==")[1]
# Save results to csv with standard format
records = []
for _, row in df.iterrows():
if row["Engine"] in ["optimum-ort", "onnxruntime"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version
)
elif row["Engine"] in ["pytorch-eager", "pytorch-compile"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "pytorch", row["Device"], torch.__name__, torch.__version__
)
else:
record = BenchmarkRecord(row["Model Name"], row["Precision"], row["Engine"], row["Device"], "", "")
record.config.warmup_runs = row["Warmup Runs"]
record.config.measured_runs = row["Measured Runs"]
record.config.batch_size = row["Batch Size"]
record.config.seq_length = row["Sequence Length"]
record.config.customized["measure_step"] = row["Step"]
record.config.customized["engine"] = row["Engine"]
record.metrics.customized["latency_s_mean"] = row["Latency (s)"]
record.metrics.latency_ms_mean = row["Latency (ms)"]
record.metrics.customized["throughput_tps"] = row["Throughput (tps)"]
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
records.append(record)
BenchmarkRecord.save_as_csv(filename, records)
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [args.warmup_runs, args.num_runs, args.model_name, engine, args.precision, args.device]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
all_results = []
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-eager",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-eager")
all_results.extend(results)
# Benchmark PyTorch with torch.compile
if args.hf_pt_compile:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-compile",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-compile")
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_dir_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-ort",
"--hf-ort-dir-path",
args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "optimum-ort")
all_results.extend(results)
# Benchmark Microsoft model in ONNX Runtime
if args.ort_msft_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-msft",
"--ort-model-path",
args.ort_msft_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark Microsoft model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "ort-msft")
all_results.extend(results)
# Benchmark convert_to_onnx model in ONNX Runtime
if args.ort_convert_to_onnx_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-convert-to-onnx",
"--ort-model-path",
args.ort_convert_to_onnx_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")
all_results.extend(results)
csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,606 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
#
# Prerequisites:
# 1) Install `huggingface-cli`:
#
# $ pip install huggingface_hub
#
# 2) Authenticate with Hugging Face's CLI:
#
# $ huggingface-cli login
#
# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
#
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes
from __future__ import annotations
import argparse
import datetime
import gc
import itertools
import json
import logging
import os
import textwrap
import time
import numpy as np
import pandas as pd
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_model(args: argparse.Namespace):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
model = None
if args.onnx_precision == "int4" and args.device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
try:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
except Exception as e:
# When flash_attention or sdpa doesn't support a model, it throws an exception.
# Rather than stopping the process, fall back to eager mode.
print("Trying to load the model in eager mode: ", e)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="eager",
).to(args.target_device)
model.eval()
if args.benchmark_type == "pt-compile":
model = torch.compile(model)
else:
sess_options = ort.SessionOptions()
ep = (
("CUDAExecutionProvider", {"device_id": args.device_id})
if args.device == "cuda"
else "CPUExecutionProvider"
)
model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])
return model
def run_inference(args, model, runs, inputs, outputs):
if args.benchmark_type == "pt-compile":
with torch.no_grad():
outputs = model(**inputs)
# Synchronize inputs
io_binding = None
if args.benchmark_type in {"pt-eager", "pt-compile"}:
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
io_binding.synchronize_inputs()
# Run inference
start = time.perf_counter()
for _ in range(runs):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
with torch.no_grad():
outputs = model(**inputs)
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end = time.perf_counter()
avg = (end - start) / runs
return avg, outputs
def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
clear_cache()
inputs, outputs = get_initial_inputs_and_outputs(
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
)
_, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
return inputs, outputs
def clear_cache():
gc.collect()
torch.cuda.empty_cache()
def save_results(results, filename, gen_length):
df = pd.DataFrame(
results,
columns=[
"Batch Size",
"Prompt Length",
"Prompt Processing Latency (ms)",
"Prompt Processing Throughput (tps)",
"Sampling Latency (ms)",
"Sampling Throughput (tps)",
"First Token Generated Latency (ms)",
"First Token Generated Throughput (tps)",
f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
f"Average Latency of First {gen_length} Tokens Generated (ms)",
f"Average Throughput of First {gen_length} Tokens Generated (tps)",
"Wall-Clock Latency (s)",
"Wall-Clock Throughput (tps)",
],
)
df.to_csv(filename, index=False)
logger.info(f"Results saved in {filename}!")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=["pt-eager", "pt-compile", "ort"],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=False,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a",
"--auth",
default=False,
action="store_true",
help="Use Hugging Face authentication token to access model",
)
parser.add_argument(
"-t",
"--trust",
default=False,
action="store_true",
help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
)
parser.add_argument(
"-c",
"--cache-dir",
type=str,
default=os.path.join(".", "model_cache"),
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
)
parser.add_argument(
"--hf-dir-path",
type=str,
default="",
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
)
parser.add_argument(
"-o",
"--onnx-model-path",
required=False,
help="Path to ONNX model",
)
parser.add_argument(
"-f",
"--prompts-file",
required=True,
default=os.path.join(".", "models", "llama", "prompts.json"),
help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
)
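# A prompts file is a flat JSON object mapping tokenized prompt lengths to prompt strings, e.g.
# (hypothetical contents): {"16": "a prompt that tokenizes to roughly 16 tokens", "64": "a longer prompt ..."}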
parser.add_argument(
"--use_buffer_share",
default=False,
action="store_true",
help="Use when GroupQueryAttention (GQA) is in ONNX model",
)
parser.add_argument(
"--anomaly-filtering",
default=False,
action="store_true",
help="Use this flag to filter anomaly accelerator times for tokens generated. \
This may give more accurate latency and throughput metrics for tokens generated. \
Wall-clock metrics are still reported with anomaly times though.",
)
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--prompt-lengths",
default="16 64 256 1024",
)
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"-g",
"--generation-length",
type=int,
default=256,
help="Number of new tokens to generate",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=100)
parser.add_argument("--seed", type=int, default=2)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "ort":
assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.prompt_lengths = args.prompt_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
setattr(args, "onnx_precision", args.precision) # noqa: B010
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
engine = "ort" if args.benchmark_type == "ort" else "pt"
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "torch_dtype", torch_dtype) # noqa: B010
setattr(args, "engine", engine) # noqa: B010
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.use_buffer_share = args.use_buffer_share and engine == "ort"
return args
def main():
args = get_args()
setup_logger(False)
logger.info(args.__dict__)
# Get prompts and prompt sizes
size_to_prompt = None
with open(args.prompts_file) as f:
size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})
# Get config, tokenizer, and model
config = AutoConfig.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
tokenizer = AutoTokenizer.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
model = get_model(args)
all_csv_metrics = []
for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
batch_size, prompt_length = int(batch_size), int(prompt_length) # noqa: PLW2901
logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
clear_cache()
max_length = prompt_length + args.generation_length
if prompt_length not in size_to_prompt:
raise NotImplementedError(
textwrap.dedent(
f"""
A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
If {prompt_length} > the length of the tokenized prompt stored under that key, the benchmark E2E tool will repeat the prompt's first token until the lengths match.
If {prompt_length} < the length of the tokenized prompt stored under that key, the benchmark E2E tool will automatically trim the prompt so that the lengths match.
2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
"""
)
)
prompt = [size_to_prompt[prompt_length]] * batch_size
csv_metrics = [batch_size, prompt_length]
try:
# Measure prompt processing
logger.info("Measuring prompt processing...")
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
# Calculate prompt metrics
accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
logger.info(
f"Average Throughput of Prompt Processing: {batch_size * (prompt_length / accelerator_prompt_latency_s)} tps"
)
csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])
# Measure token generation
logger.info("Measuring token generation...")
clear_cache()
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
all_token_ids = inputs["input_ids"].clone()
current_length = all_token_ids.shape[-1]
num_heads = config.num_key_value_heads
head_size = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)
# 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
accelerator_times = []
sampling_times = [] # cost to sample after each model run
wall_clock_start_time = time.perf_counter()
while current_length <= max_length:
# Run inference
accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
accelerator_times.append(accelerator_time_latency_s)
# Sample with argmax (greedy search)
sampling_start_time = time.perf_counter()
if outputs["logits"].shape[1] > 1:
prompt_end_indices = inputs["attention_mask"].sum(1) - 1
idxs = (
prompt_end_indices.unsqueeze(dim=1)
.repeat(1, config.vocab_size)
.view(batch_size, 1, config.vocab_size)
)
next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
else:
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
sampling_end_time = time.perf_counter()
sampling_times.append(sampling_end_time - sampling_start_time)
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
current_length += 1
# Update inputs for next inference run
inputs["input_ids"] = tokens_to_add
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
)
if "position_ids" in inputs:
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
outputs["logits"].zero_()
# Update KV caches for next inference run
if args.engine == "pt":
# Update KV caches for PyTorch
inputs["past_key_values"] = outputs["past_key_values"]
elif not args.use_buffer_share:
# Update KV caches for ONNX Runtime if buffer sharing is not used
for i in range(config.num_hidden_layers):
inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]
new_sequence_length = inputs["attention_mask"].shape[1]
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
present_value = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
outputs.update(
{
f"present.{i}.key": present_key.contiguous(),
f"present.{i}.value": present_value.contiguous(),
}
)
wall_clock_end_time = time.perf_counter()
# Filter out any anomaly accelerator times (e.g. for `torch.compile`)
accelerator_times.pop(0) # Remove prompt processing time
if args.anomaly_filtering:
anomaly_threshold_factor = 10
min_time_s = min(accelerator_times)
orig_size = len(accelerator_times)
accelerator_times = list(
filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
)
new_size = len(accelerator_times)
logger.info(
f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
)
#######################################################
# Calculate sampling and first token generated metrics
#######################################################
# Calculate sampling metrics
avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
avg_sampling_latency_ms = avg_sampling_latency_s * 1000
avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")
# Calculate first token generated metrics
first_token_latency_s = accelerator_times[0]
first_token_latency_ms = first_token_latency_s * 1000
first_token_thrpt = batch_size * (1 / first_token_latency_s)
logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")
####################################################
# Calculate first `halfway` token generated metrics
####################################################
halfway = args.generation_length // 2
halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
halfway_token_latency_ms = halfway_token_latency_s * 1000
halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")
#########################################
# Calculate all tokens generated metrics
#########################################
all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
all_token_latency_ms = all_token_latency_s * 1000
all_token_thrpt = batch_size * (1 / all_token_latency_s)
logger.info(
f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
)
logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")
###############################
# Calculate wall clock metrics
###############################
wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
logger.info(
f"Wall-Clock Throughput: {batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)} tps"
)
# Add metrics to CSV
logger.info("Adding results to CSV")
csv_metrics.extend(
[
avg_sampling_latency_ms,
avg_sampling_thrpt,
first_token_latency_ms,
first_token_thrpt,
halfway_token_latency_ms,
halfway_token_thrpt,
all_token_latency_ms,
all_token_thrpt,
wall_clock_latency_s,
wall_clock_thrpt,
]
)
all_csv_metrics.append(csv_metrics)
except Exception as e:
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_csv_metrics, filename, args.generation_length)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@ -0,0 +1,57 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import torch.distributed as dist
def init_dist():
if "LOCAL_RANK" in os.environ:
int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
else:
# no process-group initialization is needed for a single process
pass
def _get_comm():
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
return comm
except ImportError:
return None
def get_rank():
comm = _get_comm()
return comm.Get_rank() if comm is not None else 0
def get_size():
comm = _get_comm()
return comm.Get_size() if comm is not None else 1
def barrier():
comm = _get_comm()
if comm is not None:
comm.Barrier()
def print_out(*args):
if get_rank() == 0:
print(*args)
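# A typical (hypothetical) usage pattern: call init_dist() once at process startup, guard logging or file
# writes with `if get_rank() == 0:` (or use print_out), and synchronize ranks with barrier() where needed.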

View File

@ -0,0 +1,503 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer
from onnxruntime import InferenceSession, OrtValue
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if use_past_kv:
# Shape: (batch_size, 1)
position_ids = position_ids[:, -1].unsqueeze(-1)
# Shape: (batch_size, sequence_length)
return position_ids
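# For example (hypothetical values): attention_mask [[0, 1, 1, 1]] with use_past_kv=False gives
# position_ids [[1, 0, 1, 2]], since cumsum - 1 yields [[-1, 0, 1, 2]] and padded positions are set to 1.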
# Inputs for first pass to get initial past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)
def get_sample_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
if not return_dict:
# For export
return (input_ids, attention_mask, position_ids)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return inputs
# Inputs for subsequent passes with past_key_values
# input_ids: (batch_size, 1)
# attention_mask: (batch_size, past_sequence_length + 1)
# position_ids: (batch_size, 1)
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
def get_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for all passes with past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, past_sequence_length + sequence_length)
# position_ids: (batch_size, sequence_length)
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
def get_merged_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_buffer_share: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
if use_buffer_share:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
config: AutoConfig,
batch_size: int,
past_seq_len: int,
seq_len: int,
max_seq_len: int,
use_fp16: bool,
use_buffer_share: bool,
split_kv: bool,
):
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
if not split_kv:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
"k_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"v_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"pos": np.array(past_seq_len, dtype=np.int64),
}
else:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
np.int32
),
"pos": np.array(past_seq_len, dtype=np.int64),
}
for i in range(config.num_hidden_layers):
ort_inputs.update(
{
f"k_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
f"v_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
}
)
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
num_heads = config.num_key_value_heads // world_size
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
return past_kv
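# For example, a 2-layer past_key_values list is flattened (as NumPy arrays) into the keys
# past_key_values.0.key, past_key_values.0.value, past_key_values.1.key, past_key_values.1.value.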
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
):
ort_inputs = {}
for k, v in pt_inputs.items():
if isinstance(v, np.ndarray):
ort_inputs[k] = v
elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v))
else:
ort_inputs[k] = v.detach().cpu().numpy()
# Reshape KV caches if using past-present-share-buffer
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
for k, v in ort_inputs.items():
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
# Copy v (BxSxPxH) into new_v (BxSxMxH)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = new_v
return ort_inputs
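# For example (hypothetical shapes): with past_seq_len=8 and max_seq_len=2048, a cache entry of shape
# (1, 32, 8, 128) is copied into a zero-initialized buffer of shape (1, 32, 2048, 128), with the original
# values occupying the first 8 positions along the sequence axis.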
# Verify ONNX Runtime inputs with model
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(ort_inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
print(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
del ort_inputs[unnecessary_input]
return ort_inputs
# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
model: InferenceSession,
ort_inputs: dict,
device: str,
device_id: int,
use_buffer_share: bool,
kv_cache_ortvalues: dict,
):
io_binding = model.io_binding()
model_inputs = set(map(lambda i: i.name, model.get_inputs()))
for k, v in ort_inputs.items():
# Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
# GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
# but `position_ids` is used as a PyTorch model input
if k not in model_inputs:
continue
# Bind OrtValue inputs to device
if use_buffer_share and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
kv_cache_ortvalues[k] = v_device
else:
kv_cache_ortvalues[k].update_inplace(v)
io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
else:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
for output in model.get_outputs():
name = output.name
if use_buffer_share and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
else:
io_binding.bind_output(name, device_type=device, device_id=device_id)
return io_binding, kv_cache_ortvalues
# Add IO bindings for execution providers using PyTorch tensors
# Use when you need to run inference many times
def add_io_bindings_as_tensors(
model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
device = None
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float16": np.float16,
"torch.float32": np.float32,
}
# Bind inputs/outputs to IO binding
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
# Bind KV cache outputs to KV cache inputs
v = (
inputs[name.replace("present", "past_key_values")]
if use_buffer_share and "present" in name
else outputs[name]
)
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
# Get actual inputs when using real data (instead of sample data) and initialize outputs
def get_initial_inputs_and_outputs(
config: AutoConfig,
tokenizer: AutoTokenizer,
requested_length: int,
prompt: list[str],
device: torch.device,
use_fp16: bool,
use_buffer_share: bool,
engine: str,
):
tokenizer.pad_token = tokenizer.eos_token
encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
torch_dtype = torch.float16 if use_fp16 else torch.float32
# input_ids: pad token id is 0
# attention_mask: pad token id is 0
# position_ids: pad token id is 1
input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Check if tokenized prompt length matches the requested prompt length
tokenized_length = input_ids.shape[-1]
if tokenized_length > requested_length:
# Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids = input_ids[:, :requested_length]
attention_mask = attention_mask[:, :requested_length]
position_ids = get_position_ids(attention_mask, use_past_kv=False)
elif tokenized_length < requested_length:
# Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
for _ in range(requested_length - tokenized_length):
input_ids = torch.hstack((input_ids_first_col, input_ids))
attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
tokenized_length = input_ids.shape[-1]
assert tokenized_length == requested_length
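# For example (hypothetical lengths): a prompt that tokenizes to 5 tokens with requested_length=8 has its
# first token (and corresponding attention-mask value) prepended 3 times, giving shape (batch_size, 8).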
# Create inputs
inputs = {
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
}
if engine != "ort":
inputs["past_key_values"] = []
# Get shape of KV cache inputs
batch_size, sequence_length = input_ids.shape
max_sequence_length = config.max_position_embeddings
num_heads = config.num_key_value_heads
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
# Create KV cache inputs
for i in range(config.num_hidden_layers):
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
if engine == "ort":
inputs.update(
{
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous(),
}
)
else:
inputs["past_key_values"].append((past_key, past_value))
outputs = None
if engine == "ort":
# Create outputs
logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
outputs = {"logits": logits.contiguous()}
if not use_buffer_share:
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
present_value = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
outputs.update(
{f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
)
return inputs, outputs

View File

@ -0,0 +1,309 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
import time
import numpy as np
import torch
from benchmark_helper import setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
convert_inputs_for_ort,
get_merged_sample_with_past_kv_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from llama_torch import setup_torch_model
from transformers import AutoConfig
import onnxruntime as ort
logger = logging.getLogger("")
def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
max_sequence_length = config.max_position_embeddings
return past_sequence_length, curr_sequence_length, max_sequence_length
def get_inputs(args: argparse.Namespace, config: AutoConfig):
# Dummy values for parity
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
if args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
seq_len=sequence_length,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
return_dict=True,
world_size=world_size,
)
elif args.use_past_kv:
inputs = get_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
world_size=world_size,
)
else:
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
return inputs
def verify_parity(
args: argparse.Namespace,
location: str,
use_auth_token: bool,
kv_cache_ortvalues: dict,
pytorch_model: None | torch.nn.Module = None,
config: None | AutoConfig = None,
):
# If running on a machine with less than 36GB of GPU memory, unload the PyTorch LLaMA model promptly to free GPU memory for ORT.
py_model = pytorch_model
if py_model is None:
config, py_model = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
inputs = get_inputs(args, config)
# Run inference with PyTorch
if args.execution_provider != "cpu":
torch.cuda.synchronize()
start_time = time.time()
pt_outputs = py_model(**inputs).logits.detach().cpu().numpy()
if args.execution_provider != "cpu":
torch.cuda.synchronize()
end_time = time.time()
logger.info(f"PyTorch took {end_time - start_time} s")
if args.small_gpu and py_model is not None:
del py_model
torch.cuda.empty_cache()
# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
inputs = convert_inputs_for_ort(
inputs,
use_buffer_share=args.use_buffer_share,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
)
ep = f"{args.execution_provider.upper()}ExecutionProvider"
if ep == "CUDAExecutionProvider":
ep = (ep, {"device_id": args.rank})
ort_model = ort.InferenceSession(
args.onnx_model_path,
sess_options=ort.SessionOptions(),
providers=[ep],
)
inputs = verify_ort_inputs(ort_model, inputs)
# Add IO bindings for non-CPU execution providers
if args.execution_provider != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
ort_model,
ort_inputs=inputs,
device=args.execution_provider,
device_id=int(args.rank),
use_buffer_share=args.use_buffer_share,
kv_cache_ortvalues=kv_cache_ortvalues,
)
io_binding.synchronize_inputs()
start_time = time.time()
ort_model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end_time = time.time()
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
del ort_model
else:
start_time = time.time()
ort_outputs = ort_model.run(None, inputs)
end_time = time.time()
ort_outputs = ort_outputs[0] # Get logits
logger.info(f"ONNX Runtime took {end_time - start_time} s")
# Compare PyTorch and ONNX Runtime accuracy
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
return kv_cache_ortvalues
def get_args(argv: list[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=False,
help="Model name in Hugging Face",
)
parser.add_argument(
"-t",
"--torch_model_directory",
required=False,
default=os.path.join("."),
help="Path to folder containing PyTorch model and associated files if saved on disk",
)
parser.add_argument(
"-o",
"--onnx_model_path",
required=True,
default=os.path.join("."),
help="Path to ONNX model (with external data files saved in the same folder as the model)",
)
parser.add_argument(
"-ep",
"--execution_provider",
required=False,
default="cpu",
choices=["cpu", "cuda", "rocm"],
help="Execution provider to verify parity with",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print verbose logs",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"-p",
"--use_past_kv",
action="store_true",
help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
)
parser.set_defaults(use_past_kv=False)
parser.add_argument(
"-g",
"--use_buffer_share",
action="store_true",
help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
)
parser.set_defaults(use_buffer_share=False)
parser.add_argument(
"--merged",
action="store_true",
help="Use merged model (i.e. decoder_merged_model.onnx).",
)
parser.set_defaults(merged=False)
parser.add_argument(
"-fp",
"--precision",
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision of model",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./model_cache",
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
# The argument is used for CI mainly, because the CI machine has 24G GPU memory at most.
parser.add_argument(
"--small_gpu",
action="store_true",
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32"
if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
else "fp16"
)
return args
def main(argv: list[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")
rank = get_rank()
# Load model and config
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.rank = rank
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
kv_cache_ortvalues = {}
if not args.merged:
verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
else:
config = llama = None
if not args.small_gpu:
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
# Verify prompt processing in merged model (decoder_model.onnx)
args.use_past_kv = False
kv_cache_ortvalues = verify_parity(
args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
)
# Verify token generation in merged model (decoder_with_past_model.onnx)
args.use_past_kv = True
verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
if __name__ == "__main__":
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)
main()
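# Example invocation (a minimal sketch; the script filename and file paths are placeholders, not prescribed here):
#    $ python llama_parity.py -m meta-llama/Llama-2-7b-hf -o ./llama2-7b-fp16/model.onnx -ep cuda -fp fp16 --merged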

View File

@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import torch
from dist_settings import barrier, get_rank, get_size
from transformers import AutoConfig, AutoModelForCausalLM
logger = logging.getLogger("")
def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
world_size = get_size()
logger.info(f"world_size: {world_size}")
rank = get_rank()
barrier()
if not os.path.exists(args.cache_dir):
os.makedirs(args.cache_dir, exist_ok=True)
for i in range(world_size):
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(
location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=auth,
trust_remote_code=auth,
config=l_config,
torch_dtype=torch_dtype,
cache_dir=args.cache_dir,
)
if world_size > 1:
llama.parallel_model()
if device:
llama.to(device)
llama.eval()
llama.requires_grad_(False)
barrier()
return l_config, llama

View File

@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session
from datasets import load_dataset
from llama_inputs import get_position_ids
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer
class QuantKVDataLoader:
def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
self.batch_size = 1
self.pad_max = args.pad_max
tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
dataset = load_dataset(args.smooth_quant_dataset, split="train")
dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
self.dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=self.collate_batch,
)
self.decoder_model = (
create_onnxruntime_session(
onnx_model_path,
args.execution_provider != "cpu", # use_gpu
provider=args.execution_provider,
verbose=args.verbose,
)
if onnx_model_path
else None
)
def collate_batch(self, batch):
input_ids_batched = []
attention_mask_batched = []
position_ids_batched = []
labels = []
for text in batch:
# Set inputs for model
input_ids = text["input_ids"]
attention_mask = torch.ones(len(input_ids))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
label = len(input_ids) - 1
# Pad input data because all model inputs must have same shape
pad_len = self.pad_max - input_ids.shape[0]
input_ids = pad(input_ids, (0, pad_len), value=1)
attention_mask = pad(attention_mask, (0, pad_len), value=0)
position_ids = pad(position_ids, (0, pad_len), value=0)
input_ids_batched.append(input_ids)
attention_mask_batched.append(attention_mask)
position_ids_batched.append(position_ids)
labels.append(label)
input_ids_batched = torch.vstack(input_ids_batched)
attention_mask_batched = torch.vstack(attention_mask_batched)
position_ids_batched = torch.vstack(position_ids_batched)
labels = torch.tensor(labels)
return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
def __iter__(self):
try:
for (input_ids, attention_mask, position_ids), labels in self.dataloader:
# Inputs for decoder_model.onnx
inputs = {
"input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
"attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
"position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
}
label = labels.detach().cpu().numpy()
if self.decoder_model is not None:
# Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
outputs = self.decoder_model.run(None, inputs)
for i in range(int((len(outputs) - 1) / 2)):
inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
past_sequence_length = inputs["past_key_values.0.key"].shape[2]
inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
inputs["position_ids"] = (
get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
)
# Yield (inputs, label) tuple for Intel's Neural Compressor:
# https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
yield (inputs, label)
except StopIteration:
return
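# Typical (hypothetical) usage: construct QuantKVDataLoader(args) to produce calibration inputs for
# decoder_model.onnx, or QuantKVDataLoader(args, onnx_model_path=...) to also produce past-KV calibration
# inputs for decoder_with_past_model.onnx; iterating the loader yields (inputs, label) pairs for the quantizer.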

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

View File

@ -0,0 +1,821 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
#
# This script benchmarks the latency or peak memory usage of Longformer model inference.
# Please run convert_to_onnx.py to get onnx model before running benchmark.
#
# It has been tested with Python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.11.0, transformers 4.18.0, and CUDA 11.3, set up as follows:
# conda create -n gpu_env python=3.8
# conda activate gpu_env
# pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# pip3 install onnx transformers onnxruntime-gpu numpy sympy coloredlogs psutil py3nvml
# python benchmark_longformer.py
#
# When no parameters are given, pre-defined tests run on the longformer-base-4096 model.
# Benchmark the latency:
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 512 1024 2048 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp16.onnx -t 100
#
# Benchmark GPU peak memory:
# export ORT_LONGFORMER_COMPACT_MEMORY=0
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
# export ORT_LONGFORMER_COMPACT_MEMORY=1
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
#
# By default, compact memory kernel is enabled. To disable it, set environment variable ORT_LONGFORMER_COMPACT_MEMORY=0.
import argparse
import csv
import logging
import math
import os
import re
import sys
import timeit
import traceback
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import Any, Dict, List
import benchmark_helper
import numpy as np
import torch
from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper, LongformerInputs
from transformers import LongformerModel
import onnxruntime
logger = logging.getLogger("")
def test_torch_latency(
device,
model,
model_name,
batch_sizes,
sequence_lengths,
global_lengths,
test_times,
num_threads,
) -> List[Dict[str, Any]]:
if num_threads > 0:
torch.set_num_threads(num_threads)
results = []
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in global_lengths:
logger.info(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}")
inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
input_list = inputs.to_list()
_ = model(*input_list)
runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1) # noqa: B023
result = {
"engine": "torch", # TODO: test torchscript
"version": torch.__version__,
"device": "cuda",
"optimizer": "",
"precision": "fp32",
"io_binding": "",
"model_name": model_name,
"description": model_name + " [torch]",
"inputs": 3,
"threads": num_threads,
"batch_size": batch_size,
"sequence_length": sequence_length,
"global_length": global_length,
"datetime": str(datetime.now()),
"memory": "NA",
"diff_max": 0,
"diff_90_percentile": 0,
"diff_95_percentile": 0,
"diff_99_percentile": 0,
"use_compact_memory": "NA",
}
result.update(benchmark_helper.get_latency_result(runtimes, batch_size))
logger.info("%s", result)
results.append(result)
return results
def test_parity(device, model, ort_session, batch_size, sequence_length, global_length, verbose=True):
parameters = f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}"
logger.info(f"Comparing Torch and ORT outputs for {parameters}...")
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
ort_inputs = dummy_inputs.get_ort_inputs()
ort_outputs = ort_session.run(None, ort_inputs)
input_list = dummy_inputs.to_list()
torch_outputs = model(*input_list)
max_diff = np.amax(torch_outputs[0].cpu().numpy() - ort_outputs[0])
logger.info(f"last_state max diff = {max_diff}")
if verbose and (math.isnan(max_diff) or max_diff > 0.001):
print("torch last_state:", torch_outputs[0])
print("ort last_state:", ort_outputs[0])
return float(max_diff)
def test_ort_latency(
device,
model,
model_name,
description,
ort_session,
batch_sizes,
sequence_lengths,
global_lengths,
test_times,
num_threads,
optimizer=False,
precision="fp32",
disable_io_binding=False,
verbose=True,
use_compact_memory=False,
use_half4=False,
disable_parity=False,
) -> List[Dict[str, Any]]:
results = []
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in global_lengths:
assert (
global_length <= model.config.attention_window[0]
), "Limitation of current implementation: number of global token <= attention_window"
logger.info(
f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
f"optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..."
)
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
# Run OnnxRuntime
ort_inputs = dummy_inputs.get_ort_inputs()
if verbose:
print(ort_inputs)
# run one query for warm up
ort_outputs = ort_session.run(None, ort_inputs)
result_template = {
"model_name": model_name,
"description": description,
"inputs": 3,
"engine": "OnnxRuntime",
"version": str(onnxruntime.__version__),
"device": "cuda",
"precision": str(precision),
"optimizer": int(optimizer),
"threads": int(num_threads),
"batch_size": int(batch_size),
"sequence_length": int(sequence_length),
"global_length": int(global_length),
"test_times": int(test_times),
"datetime": str(datetime.now()),
"memory": "",
"diff_max": None,
"diff_90_percentile": None,
"diff_95_percentile": None,
"diff_99_percentile": None,
"use_compact_memory": use_compact_memory,
"use_half4": use_half4,
}
if not disable_io_binding:
max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
max_pooler_size = max(batch_sizes) * max(sequence_lengths)
result = benchmark_helper.inference_ort_with_io_binding(
ort_session,
ort_inputs,
result_template=result_template,
repeat_times=test_times,
ort_output_names=["last_state", "pooler"],
ort_outputs=ort_outputs,
output_buffers=[],
output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
batch_size=batch_size,
device=device,
data_type=np.longlong, # input data type
)
else:
result = benchmark_helper.inference_ort(
ort_session,
ort_inputs,
result_template=result_template,
repeat_times=test_times,
batch_size=batch_size,
)
# measure result difference between PyTorch and OnnxRuntime
if not disable_parity:
diff_results = [
test_parity(
device,
model,
ort_session,
batch_size,
sequence_length,
global_length,
verbose,
)
for _ in range(test_times)
]
result["diff_max"] = max(diff_results)
result["diff_90_percentile"] = np.percentile(diff_results, 90)
result["diff_95_percentile"] = np.percentile(diff_results, 95)
result["diff_99_percentile"] = np.percentile(diff_results, 99)
results.append(result)
return results
def test_ort_memory(
device,
onnx_model_path,
batch_size,
sequence_length,
global_length,
test_times,
num_threads,
) -> Dict[str, Any]:
logger.info(
f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, "
f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}"
)
def inference():
# Update Arena strategy so that we can measure the minimum memory required
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
session = benchmark_helper.create_onnxruntime_session(
onnx_model_path,
use_gpu=True,
enable_all_optimization=True,
num_threads=num_threads,
provider_options=provider_options,
)
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
ort_inputs = dummy_inputs.get_ort_inputs()
for _ in range(test_times):
_ = session.run(None, ort_inputs)
memory_used = benchmark_helper.measure_memory(is_gpu=True, func=inference)
return {
"onnx_model": onnx_model_path,
"batch_size": batch_size,
"sequence_length": sequence_length,
"global_length": global_length,
"test_times": test_times,
"num_threads": num_threads,
"memory": memory_used,
}
def load_torch_model(model_name, device):
torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name)
model = LongformerModel.from_pretrained(torch_model_name_or_dir)
model.to(device)
return model
def find_onnx_model(model_name, onnx_dir="."):
# Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model
onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx")
optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx")
optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx")
if os.path.isfile(optimized_fp16_model):
onnx_model_path = optimized_fp16_model
elif os.path.isfile(optimized_fp32_model):
onnx_model_path = optimized_fp32_model
return onnx_model_path
def test_memory(args, device) -> Dict[str, Any]:
if len(args.batch_sizes) > 1:
raise RuntimeError("For memory test, only one batch_size (-b) is allowed.")
if len(args.sequence_lengths) > 1:
raise RuntimeError("For memory test, only one sequence_length (-s) is allowed.")
if len(args.global_lengths) > 1:
raise RuntimeError("For memory test, only one global_length (-g) is allowed.")
model_name = args.model
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
torch.cuda.empty_cache()
return test_ort_memory(
device,
onnx_model_path,
args.batch_sizes[0],
args.sequence_lengths[0],
args.global_lengths[0],
args.test_times,
args.num_threads,
)
def test_ort(args, device) -> List[Dict[str, Any]]:
model_name = args.model
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx") # noqa: PIE810
precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"
model = load_torch_model(model_name, device)
num_threads = args.num_threads
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
session = benchmark_helper.create_onnxruntime_session(
onnx_model_path,
use_gpu=True,
enable_all_optimization=True,
num_threads=num_threads,
provider_options=provider_options,
)
if session is None:
raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}")
use_compact_memory = os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "1") == "1"
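# Note: ORT_LONGFORMER_COMPACT_MEMORY is set by run_tests() below; "0" exercises the non-compact memory
# kernel path, which is reflected in the description suffix added next.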
description = onnx_model_path
if not use_compact_memory:
description += "[non_compact_memory]"
if args.use_half4:
description += "[half4]" if precision == "fp16" else "[float4]"
else:
description += "[half2]" if precision == "fp16" else "[float4]"
return test_ort_latency(
device,
model,
model_name,
description,
session,
args.batch_sizes,
args.sequence_lengths,
args.global_lengths,
args.test_times,
num_threads,
optimized,
precision,
args.disable_io_binding,
args.verbose,
use_compact_memory,
args.use_half4,
args.disable_parity,
)
def test_torch(args, device) -> List[Dict[str, Any]]:
model = load_torch_model(args.model, device)
return test_torch_latency(
device,
model,
args.model,
args.batch_sizes,
args.sequence_lengths,
args.global_lengths,
args.test_times,
args.num_threads,
)
def test_latency(args, device) -> List[Dict[str, Any]]:
if args.engine == "onnxruntime":
return test_ort(args, device)
return test_torch(args, device)
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="longformer-base-4096",
help="Checkpoint directory or pre-trained model names in the list: "
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
)
parser.add_argument(
"-e",
"--engine",
required=False,
type=str,
default="onnxruntime",
choices=["onnxruntime", "torch"],
help="Engine to benchmark.",
)
parser.add_argument(
"-t",
"--test_times",
required=False,
default=1000,
type=int,
help="Number of repeat times to get average inference latency.",
)
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
# If --export_padding is not used when exporting the onnx model, there is no padding logic in the ONNX model,
# and you will need to pad the inputs yourself before running the onnx model.
# Here, we only test sequence lengths that are multiples of the attention window size.
parser.add_argument(
"-s",
"--sequence_lengths",
nargs="+",
type=int,
default=[512, 1024, 2048, 4096],
help="Sequence lengths. It could have multiple values in latency test."
"If --export_padding is not used, sequence length shall be multiple of window size.",
)
parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path")
parser.add_argument(
"-g",
"--global_lengths",
nargs="+",
type=int,
default=[0],
help="Number of global tokens. It could have multiple values in latency test.",
)
parser.add_argument(
"-n",
"--num_threads",
required=False,
type=int,
default=0,
help="Threads to use.",
)
parser.add_argument(
"--disable_io_binding",
required=False,
action="store_true",
help="Do not use IO Binding.",
)
parser.add_argument(
"--memory",
required=False,
action="store_true",
help="Test memory usage instead of latency.",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.")
parser.set_defaults(verbose=False)
parser.add_argument("--use_half4", required=False, action="store_true", help="Use half4 kernel.")
parser.set_defaults(use_half4=False)
parser.add_argument("--disable_parity", required=False, action="store_true", help="Do not run parity test.")
parser.set_defaults(disable_parity=False)
args = parser.parse_args(argv)
return args
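# Example invocation (a hedged sketch; the script name, ONNX path and values below are illustrative,
# not taken from this file):
#   python benchmark_longformer.py -m longformer-base-4096 --onnx longformer-base-4096_f1_fp16.onnx \
#       -b 1 -s 512 1024 -g 16 -t 100
# All flags are defined in parse_arguments() above; point --onnx to a model exported by the conversion script.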
def output_details(results, csv_filename):
latency_results = [result for result in results if "average_latency_ms" in result]
if len(latency_results) == 0:
print("No latency results for output.")
return
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"engine",
"version",
"device",
"precision",
"optimizer",
"io_binding",
"model_name",
"inputs",
"threads",
"datetime",
"test_times",
"description",
"batch_size",
"sequence_length",
"global_length",
"use_compact_memory",
"use_half4",
"diff_max",
"diff_90_percentile",
"diff_95_percentile",
"diff_99_percentile",
"memory",
"QPS",
"average_latency_ms",
"latency_variance",
"latency_90_percentile",
"latency_95_percentile",
"latency_99_percentile",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for result in latency_results:
print(result)
csv_writer.writerow(result)
csv_file.flush()
print(f"Detail results are saved to csv file: {csv_filename}")
def run(args) -> List[Dict[str, Any]]:
torch.set_grad_enabled(False)
# set random seed manually to get deterministic results
benchmark_helper.set_random_seed(123)
# Currently, the longformer attention operator could only run in GPU (no CPU implementation yet).
device = torch.device("cuda:0")
if args.memory:
return [test_memory(args, device)] # Convert to List so that return type is same as test_latency
return test_latency(args, device)
def launch_test(arguments) -> List[Dict[str, Any]]:
if not torch.cuda.is_available():
raise RuntimeError("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
with ProcessPoolExecutor() as executor:
results = list(executor.map(run, [arguments]))
assert len(results) == 1
return results[0]
def run_tests(
use_compact_memory=True,
run_torch=False,
run_memory=True,
use_io_binding=True,
use_fp16=True,
use_merged_qkv_weights=True,
use_half4=True,
batch_size=1,
):
compact_memory = "1" if use_compact_memory else "0"
os.environ["ORT_LONGFORMER_COMPACT_MEMORY"] = compact_memory
logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")
os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0")) # noqa: G001
results = []
test_times = 1000
sequence_lengths = [4096, 2048, 1024, 512]
batch_sizes = [batch_size]
for model_name in ["longformer-base-4096"]:
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in [16]:
if run_torch:
engine_name = "torch"
args = parse_arguments(
f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} "
f"-t {test_times} -m {model_name}".split(" ")
)
results += run(args)
engine_name = "onnxruntime"
file_format = 1 if use_merged_qkv_weights else 0
onnx_path = (
f"{model_name}_f{file_format}_fp16.onnx"
if use_fp16
else f"{model_name}_f{file_format}_fp32.onnx"
)
if not os.path.exists(onnx_path):
raise RuntimeError(f"onnx file not exists:{onnx_path}")
arguments = (
f"-e {engine_name} --onnx {onnx_path} "
f"-b {batch_size} -s {sequence_length} -g {global_length} -m {model_name}"
)
if not use_io_binding:
arguments += " --disable_io_binding"
if use_half4:
arguments += " --use_half4"
# Disable parity test to avoid out of memory for large batch size
if batch_size >= 4:
arguments += " --disable_parity"
memory_results = None
try:
if run_memory:
args = parse_arguments(f"{arguments} -t 10 --memory".split(" "))
memory_results = launch_test(args)
args = parse_arguments(f"{arguments} -t {test_times}".split(" "))
latency_results = launch_test(args)
except KeyboardInterrupt as exc:
raise RuntimeError("Keyboard Interrupted") from exc
except Exception:
traceback.print_exc()
continue
if len(latency_results) == 1:
latency_results[0]["memory"] = memory_results[0]["memory"] if memory_results else "N/A"
else:
raise RuntimeError("length of latency_results should be 1")
logger.info("%s", latency_results)
results += latency_results
return results
def output_summary(results, csv_filename, data_field="average_latency_ms"):
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
header_names = [
"model_name",
"precision",
"engine",
"version",
"global_length",
"use_compact_memory",
"use_half4",
"description",
]
description_list = list({result["description"] for result in results})
description_list.sort()
batch_sizes = list({result["batch_size"] for result in results})
batch_sizes.sort()
sequence_lengths = list({result["sequence_length"] for result in results})
sequence_lengths.sort()
data_names = []
for sequence_length in sequence_lengths:
for batch_size in batch_sizes:
data_names.append(f"b{batch_size}_s{sequence_length}")
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
csv_writer.writeheader()
for description in description_list:
row = {}
sum_latency = {}
sum_latency.update({k: 0 for k in data_names})
count_latency = {}
count_latency.update({k: 0 for k in data_names})
for result in results:
if result["description"] == description and result[data_field]:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
else:
for k in header_names:
if row[k] != headers[k]:
raise RuntimeError("Description shall be unique")
batch_size = result["batch_size"]
sequence_length = result["sequence_length"]
key = f"b{batch_size}_s{sequence_length}"
try:
latency = float(result[data_field])
except ValueError:
continue
sum_latency[key] += latency
count_latency[key] += 1
if row:
for key in data_names:
if key in count_latency and count_latency[key] > 0:
row[key] = sum_latency[key] / count_latency[key]
else:
row[key] = ""
csv_writer.writerow(row)
csv_file.flush()
def run_experiments(use_fp16, batch_size, is_baseline=False):
"""Run experiments to compare different algorithms on one batch size"""
test_results = run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=True,
use_half4=False,
batch_size=batch_size,
)
if is_baseline:
return test_results
if use_fp16:
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=True,
use_half4=True,
batch_size=batch_size,
)
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=False,
use_half4=True,
batch_size=batch_size,
)
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=False,
use_half4=False,
batch_size=batch_size,
)
return test_results
def main():
torch.multiprocessing.set_start_method("spawn")
args = parse_arguments()
benchmark_helper.setup_logger(args.verbose)
if len(sys.argv) > 1:
test_results = launch_test(args)
time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
csv_filename = f"benchmark_detail_{time_stamp}.csv"
output_details(test_results, csv_filename)
return
gpu_list = benchmark_helper.get_gpu_info()
logger.info("GPU info: %s", gpu_list)
fp16_batch_sizes = [16, 8, 4, 2, 1]
fp32_batch_sizes = [4, 2, 1]
if gpu_list and gpu_list[0]["total"] >= 32 * 1024 * 1024 * 1024: # 32 GB
fp16_batch_sizes = [64, 32, 16, 8, 4, 2, 1]
fp32_batch_sizes = [16, 8, 4, 2, 1]
gpu_name = re.sub(r"(?u)[^-\w.]", "_", gpu_list[0]["name"]) if gpu_list else "gpu"
is_baseline = os.environ.get("ORT_LONGFORMER_BASELINE", "0") == "1"
experiment_name = f"longformer_base_{gpu_name}" + ("_baseline" if is_baseline else "")
logger.info(
f"experiment_name={experiment_name}, fp16_batch_sizes={fp16_batch_sizes}, fp32_batch_sizes={fp32_batch_sizes}"
)
total_runs = 1
all_results = []
for _ in range(total_runs):
for batch_size in fp16_batch_sizes:
fp16_results = run_experiments(use_fp16=True, batch_size=batch_size, is_baseline=is_baseline)
output_details(fp16_results, "longformer_base_fp16.csv")
all_results += fp16_results
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
all_results = []
for _ in range(total_runs):
for batch_size in fp32_batch_sizes:
fp32_results = run_experiments(use_fp16=False, batch_size=batch_size, is_baseline=is_baseline)
output_details(fp32_results, "longformer_base_fp32.csv")
all_results += fp32_results
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,413 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script converts Longformer model from huggingface transformers 4.0 or later to ONNX.
# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
#
# Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
# Then run "python setup.py install" in ./torch_extensions directory. If your python version is not 3.8, you will need
# update this script with correct name of longformer_attention.cpython-*.so (search TODO below).
#
# It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
# Warning: Using PyTorch 1.10 or a newer version might cause issues during export, but those versions are fine for benchmarking.
#
# Example commands to export longformer base model in Linux:
# conda create -n longformer python=3.8
# conda activate longformer
# python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
# python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
# cd ./torch_extensions
# rm -rf build
# python setup.py install
# cd ..
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
#
# GPU is not needed for this script. You can run it on CPU. For --optimize_onnx, you can use either the onnxruntime or onnxruntime-gpu package.
#
# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or newer version.
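# After conversion and optimization, the exported model can be benchmarked, for example (a hedged sketch;
# the benchmark script name and values are illustrative):
#   python benchmark_longformer.py --model longformer-base-4096 --onnx longformer-base-4096_f1_fp16.onnx -b 1 -s 512 -g 16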
import argparse
import inspect
from pathlib import Path
import torch
import transformers
from longformer_helper import PRETRAINED_LONGFORMER_MODELS
from onnx import load_model
from onnx_model_bert import BertOnnxModel
from packaging import version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
from torch_onnx_export_helper import torch_onnx_export
from transformers import LongformerModel, LongformerSelfAttention
# Supports format 0 or 1
weight_bias_format = 0
@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
def my_longformer_attention(
g,
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads,
window,
):
return g.op(
"com.microsoft::LongformerAttention",
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads_i=num_heads,
window_i=window,
)
# namespace is onnxruntime which is registered in longformer_attention.cpp
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)
# TODO: search the build directory to find the correct output filename of "python setup.py install" when the python version is not 3.8
torch.ops.load_library(
r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
)
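# A hedged sketch for other Python versions (an illustration, not part of the original flow): the build
# directory name encodes the interpreter version, so the library could be located dynamically, e.g.
#   so_path = next(Path("./torch_extensions/build").rglob("longformer_attention*.so"))
#   torch.ops.load_library(str(so_path))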
def parse_arguments():
"""Parse arguments
Returns:
args: Namespace
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="longformer-base-4096",
help="Checkpoint directory or pre-trained model names in the list: "
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
)
parser.add_argument(
"--export_padding",
required=False,
action="store_true",
help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.",
)
parser.set_defaults(export_padding=False)
parser.add_argument(
"--no_merge_qkv",
required=False,
action="store_true",
help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
)
parser.set_defaults(no_merge_qkv=False)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model.",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument(
"-p",
"--precision",
required=False,
type=str,
default="fp32",
choices=["fp32", "fp16"],
help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
)
args = parser.parse_args()
return args
# Create a dummy input for ONNX export.
def get_dummy_inputs(config, export_padding, device):
# When the sequence length is a multiple of the window size, there is no padding logic in the ONNX graph
sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]
# Create dummy inputs
input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
attention_mask[:, sequence_length - 1] = 0 # last token is masked
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
global_attention_mask[:, 0] = 1 # first token is global token
return input_ids, attention_mask, global_attention_mask
# A new function to replace LongformerSelfAttention.forward
# For transformers 4.0.0
def my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
):
global_mask = is_index_global_attn.int()
# The following check is based on the dummy inputs (only the first token is global).
assert (
len(global_mask.shape) == 2
and global_mask.shape[0] == 1
and global_mask.count_nonzero().item() == 1
and global_mask.tolist()[0][0] == 1
)
input_mask = is_index_masked.float()
# TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
# Another way to generate the input mask: input_mask = torch.masked_fill(attention_mask, is_index_global_attn, 0.0)
# TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
# TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask
# The following check is based on the dummy inputs (only the last token is masked).
assert (
len(input_mask.shape) == 2
and input_mask.shape[0] == 1
and input_mask.count_nonzero().item() == 1
and input_mask.tolist()[0][-1] == -10000.0
)
weight = torch.stack(
(
self.query.weight.transpose(0, 1),
self.key.weight.transpose(0, 1),
self.value.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
# shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)
global_weight = torch.stack(
(
self.query_global.weight.transpose(0, 1),
self.key_global.weight.transpose(0, 1),
self.value_global.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)
if weight_bias_format == 1:
bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
bias = bias.reshape(3 * self.embed_dim)
global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
global_bias = global_bias.reshape(3 * self.embed_dim)
else:
bias = torch.stack(
(self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
)
bias = bias.reshape(5 * self.embed_dim)
global_bias = self.query_global.bias
global_bias = global_bias.reshape(1 * self.embed_dim)
attn_output = torch.ops.onnxruntime.LongformerAttention(
hidden_states,
weight,
bias,
input_mask,
global_weight,
global_bias,
global_mask,
self.num_heads,
self.one_sided_attn_window_size,
)
assert attn_output.size() == hidden_states.size(), "Unexpected size"
outputs = (attn_output,)
return outputs
# For transformers 4.3.0
def my_longformer_self_attention_forward_4_3(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
# For transformers 4.3.2 or later versions
def my_longformer_self_attention_forward_4_3_2(
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
assert layer_head_mask is None
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
"""Export longformer model to ONNX
Args:
model (LongformerModel): longformer model
onnx_model_path (str): output onnx path
export_padding (bool): whether to export padding logic to ONNX so that the input sequence can be any length.
Raises:
RuntimeError: This tool requires transformers 4.0.0 or later.
RuntimeError: LongformerSelfAttention.forward arguments are different.
"""
input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
model.config, export_padding, device=torch.device("cpu")
)
_ = model(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
)
if version.parse(transformers.__version__) < version.parse("4.0.0"):
raise RuntimeError("This tool requires transformers 4.0.0 or later.")
# Here we replace LongformerSelfAttention.forward using our implementation for exporting ONNX model
key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
args_to_func = {
"self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
}
if key not in args_to_func:
print(
"Current arguments",
inspect.getfullargspec(LongformerSelfAttention.forward).args,
)
raise RuntimeError(
"LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
)
# Store for restoring later
original_forward = LongformerSelfAttention.forward
LongformerSelfAttention.forward = args_to_func[key]
example_inputs = (input_ids, attention_mask, global_attention_mask)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
example_inputs,
onnx_model_path,
opset_version=12,
input_names=["input_ids", "attention_mask", "global_attention_mask"],
output_names=["last_state", "pooler"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"global_attention_mask": {0: "batch_size", 1: "sequence_length"},
"last_state": {0: "batch_size", 1: "sequence_length"},
"pooler": {0: "batch_size", 1: "sequence_length"},
},
custom_opsets={"com.microsoft": 1},
)
print(f"ONNX model exported to {onnx_model_path}")
# Restore original implementation:
LongformerSelfAttention.forward = original_forward
def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
"""Optimize longformer onnx model
Args:
onnx_model_path (str): path of original ONNX model.
fp32_model_path (str): path of optimized fp32 model.
fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None.
"""
model = load_model(onnx_model_path, format=None, load_external_data=True)
optimizer = BertOnnxModel(model)
optimizer.optimize()
use_external_data_format = False
if fp32_model_path:
optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
print(f"optimized fp32 model saved to {fp32_model_path}")
if fp16_model_path:
optimizer.convert_float_to_float16(keep_io_types=True)
optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
print(f"optimized fp16 model saved to {fp16_model_path}")
def main(args):
model_name = args.model
onnx_model_path = model_name + ".onnx"
global weight_bias_format # noqa: PLW0603
weight_bias_format = 0 if args.no_merge_qkv else 1
model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])
export_longformer(model, onnx_model_path, args.export_padding)
if args.optimize_onnx or args.precision != "fp32":
fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx"
fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None
optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
if __name__ == "__main__":
args = parse_arguments()
main(args)

View File

@ -0,0 +1,347 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
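# Example usage (a hedged sketch; the script name and model path are illustrative):
#   python generate_test_data.py --model longformer-base-4096_f1_fp16.onnx --batch_size 1 --sequence_length 512 --global_tokens 16 --samples 10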
import argparse
import os
import random
from pathlib import Path
import numpy as np
from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data
from onnx import ModelProto, TensorProto
from onnx_model import OnnxModel
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
parser.add_argument(
"--output_dir",
required=False,
type=str,
default=None,
help="output test data path. If not specified, .",
)
parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
parser.add_argument(
"--sequence_length",
required=False,
type=int,
default=128,
help="maximum sequence length of input",
)
parser.add_argument(
"-a",
"--average_sequence_length",
default=-1,
type=int,
help="average sequence length excluding padding",
)
parser.add_argument(
"-r",
"--random_sequence_length",
required=False,
action="store_true",
help="use uniform random instead of fixed sequence length",
)
parser.set_defaults(random_sequence_length=False)
parser.add_argument(
"--global_tokens",
required=False,
type=int,
default=10,
help="number of global tokens",
)
parser.add_argument(
"--input_ids_name",
required=False,
type=str,
default=None,
help="input name for input ids",
)
parser.add_argument(
"--input_mask_name",
required=False,
type=str,
default=None,
help="input name for attention mask",
)
parser.add_argument(
"--global_mask_name",
required=False,
type=str,
default=None,
help="input name for global attention mask",
)
parser.add_argument(
"--samples",
required=False,
type=int,
default=1,
help="number of test cases to be generated",
)
parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="print verbose information",
)
parser.set_defaults(verbose=False)
args = parser.parse_args()
return args
def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
"""
Get graph inputs for longformer model.
"""
model = ModelProto()
with open(onnx_file, "rb") as f:
model.ParseFromString(f.read())
onnx_model = OnnxModel(model)
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
if input_ids_name is not None:
input_ids = onnx_model.find_graph_input(input_ids_name)
if input_ids is None:
raise ValueError(f"Graph does not have input named {input_ids_name}")
input_mask = None
if input_mask_name:
input_mask = onnx_model.find_graph_input(input_mask_name)
if input_mask is None:
raise ValueError(f"Graph does not have input named {input_mask_name}")
global_mask = None
if global_mask_name:
global_mask = onnx_model.find_graph_input(global_mask_name)
if global_mask is None:
raise ValueError(f"Graph does not have input named {global_mask_name}")
expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
if len(graph_inputs) != expected_inputs:
raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
return input_ids, input_mask, global_mask
if len(graph_inputs) != 3:
raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
# Try to guess the inputs based on naming.
input_ids = None
input_mask = None
global_mask = None
for input in graph_inputs:
input_name_lower = input.name.lower()
if "global" in input_name_lower:
global_mask = input
elif "mask" in input_name_lower:
input_mask = input
else:
input_ids = input
if input_ids and input_mask and global_mask:
return input_ids, input_mask, global_mask
raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
"""
Generate fake data for the global attention mask graph input.
Args:
global_mask (TensorProto): graph input of the global attention mask tensor.
batch_size (int): batch size.
sequence_length (int): sequence length.
num_global_tokens (int): number of global tokens (the first tokens get mask value 1).
Returns:
data (np.array): the data for the input tensor
"""
data_type = global_mask.type.tensor_type.elem_type
assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
if num_global_tokens > 0:
assert num_global_tokens <= sequence_length
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
temp = np.ones((batch_size, num_global_tokens), dtype=np.int32)
data[: temp.shape[0], : temp.shape[1]] = temp
else:
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if data_type == TensorProto.FLOAT:
data = np.float32(data)
elif data_type == TensorProto.INT64:
data = np.int64(data)
return data
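# Illustrative result derived from the logic above (not from the source): with batch_size=1,
# sequence_length=4 and num_global_tokens=2, the generated mask is [[1, 1, 0, 0]].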
def fake_test_data(
batch_size,
sequence_length,
test_cases,
dictionary_size,
verbose,
random_seed,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
"""
Generate fake input data for test.
"""
assert input_ids is not None
np.random.seed(random_seed)
random.seed(random_seed)
all_inputs = []
for _ in range(test_cases):
input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
inputs = {input_ids.name: input_1}
if input_mask:
inputs[input_mask.name] = fake_input_mask_data(
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
)
if global_mask:
inputs[global_mask.name] = fake_global_mask_data(
global_mask, batch_size, sequence_length, num_global_tokens
)
if verbose and len(all_inputs) == 0:
print("Example inputs", inputs)
all_inputs.append(inputs)
return all_inputs
def generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
dictionary_size = 10000
all_inputs = fake_test_data(
batch_size,
sequence_length,
test_cases,
dictionary_size,
verbose,
seed,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
)
if len(all_inputs) != test_cases:
print("Failed to create test data for test.")
return all_inputs
def create_longformer_test_data(
model,
output_dir,
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids_name,
input_mask_name,
global_mask_name,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
all_inputs = generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
)
for i, inputs in enumerate(all_inputs):
output_test_data(output_dir, i, inputs)
def main():
args = parse_arguments()
output_dir = args.output_dir
if output_dir is None:
# Default output directory is a sub-directory under the directory of model.
output_dir = os.path.join(
Path(args.model).parent,
f"b{args.batch_size}_s{args.sequence_length}_g{args.global_tokens}",
)
if not os.path.exists(output_dir):
# create the output directory if it does not exist
path = Path(output_dir)
path.mkdir(parents=True, exist_ok=True)
else:
print("Directory exists. Test data files will be overwritten.")
if args.average_sequence_length <= 0:
args.average_sequence_length = args.sequence_length
create_longformer_test_data(
args.model,
output_dir,
args.batch_size,
args.sequence_length,
args.samples,
args.seed,
args.verbose,
args.input_ids_name,
args.input_mask_name,
args.global_mask_name,
args.global_tokens,
args.average_sequence_length,
args.random_sequence_length,
)
print("Test data is saved to directory:", output_dir)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,77 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps create dummy inputs for the Longformer model.
import logging
from typing import Dict, List, Tuple, Union
import numpy
import torch
logger = logging.getLogger(__name__)
PRETRAINED_LONGFORMER_MODELS = {
"longformer-base-4096": "allenai/longformer-base-4096",
"longformer-large-4096": "allenai/longformer-large-4096",
"longformer-random-tiny": "patrickvonplaten/longformer-random-tiny", # A tiny model for debugging
}
class LongformerInputs:
def __init__(self, input_ids, attention_mask, global_attention_mask):
self.input_ids: torch.LongTensor = input_ids
self.attention_mask: Union[torch.FloatTensor, torch.HalfTensor] = attention_mask
self.global_attention_mask: Union[torch.FloatTensor, torch.HalfTensor] = global_attention_mask
def to_list(self) -> List:
return [v for v in [self.input_ids, self.attention_mask, self.global_attention_mask] if v is not None]
def to_tuple(self) -> Tuple:
return tuple(v for v in self.to_list())
def get_ort_inputs(self) -> Dict:
return {
"input_ids": numpy.ascontiguousarray(self.input_ids.cpu().numpy()),
"attention_mask": numpy.ascontiguousarray(self.attention_mask.cpu().numpy()),
"global_attention_mask": numpy.ascontiguousarray(self.global_attention_mask.cpu().numpy()),
}
class LongformerHelper:
"""A helper class for Longformer model conversion, inference and verification."""
@staticmethod
def get_dummy_inputs(
batch_size: int,
sequence_length: int,
num_global_tokens: int,
device: torch.device,
vocab_size: int = 100,
) -> LongformerInputs:
"""Create random inputs for Longformer model.
Returns torch tensors of input_ids, attention_mask and global_attention_mask tensors.
"""
input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.long,
device=device,
)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
global_token_index = list(range(num_global_tokens))
global_attention_mask[:, global_token_index] = 1
return LongformerInputs(input_ids, attention_mask, global_attention_mask)
@staticmethod
def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> Dict[str, List[int]]:
"""Returns a dictionary with output name as key, and shape as value."""
return {
"last_state": [batch_size, sequence_length, hidden_size],
"pooler": [batch_size, sequence_length],
}
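# Minimal usage sketch (an illustration, assuming a CUDA device is available):
#   dummy = LongformerHelper.get_dummy_inputs(batch_size=1, sequence_length=512, num_global_tokens=16, device=torch.device("cuda"))
#   ort_inputs = dummy.get_ort_inputs()  # dict of numpy arrays ready for an onnxruntime InferenceSession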

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

View File

@ -0,0 +1,576 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
from pathlib import Path
import onnx
import torch
from benchmark_helper import Precision
from fusion_options import AttentionOpType
from onnx_model import OnnxModel
from transformers import AutoConfig, AutoModelForCausalLM
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
class ConvertPhi2ToONNX:
def __init__(
self,
device: torch.device,
model_class: str = "microsoft/phi-2",
cache_dir: str = "./cache",
):
self.model_class = model_class
self.device = device
self.cache_dir = cache_dir
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
self.phi_model = None
self.batch_size = 2
self.sequence_length = 8
self.attn_op_type = None
self.precision = None
self.block_size = 16
self.accuracy_level = None
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
self.block_size = block_size
self.accuracy_level = accuracy_level
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
self.attn_op_type = attn_op_type
self.precision = precision
def erase_onnx_model(self, onnx_path: str) -> None:
assert onnx_path.endswith(".onnx")
if not os.path.exists(onnx_path):
return
model = onnx.load_model(onnx_path, load_external_data=False)
onnx_data_path = None
for initializer in model.graph.initializer:
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
onnx_data_path = "./" + initializer.external_data[0].value
break
logging.info(f"Erasing {onnx_path}...")
os.remove(onnx_path)
if onnx_data_path is not None:
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
logging.info(f"Erasing {onnx_data_path}...")
os.remove(onnx_data_path)
def get_phi2_torch_model(self):
logging.info("Loading phi2 torch model...")
if self.phi_model is not None:
return
self.phi_model = AutoModelForCausalLM.from_pretrained(
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
)
self.phi_model.eval()
self.phi_model.to(self.device)
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
input_ids = torch.randint(
low=0,
high=self.phi_config.vocab_size,
size=(batch_size, sequence_length),
dtype=torch.int64,
device=self.device,
)
self.get_phi2_torch_model()
torch_inputs = self.phi_model.prepare_inputs_for_generation(
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
)
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
def dynamo_export(self, onnx_path: str):
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
from torch._dynamo import config
config.capture_scalar_outputs = True
logging.info("Exporting Phi2 torch model to ONNX...")
torch.onnx.dynamo_export(
self.phi_model,
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
).save(onnx_path)
onnx.checker.check_model(onnx_path)
onnx.shape_inference.infer_shapes_path(onnx_path)
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
from fusion_options import FusionOptions
from optimizer import optimize_model
optimization_options = FusionOptions("phi")
optimization_options.set_attention_op_type(self.attn_op_type)
optimizer = optimize_model(
onnx_path,
model_type="phi",
num_heads=self.phi_config.num_attention_heads,
hidden_size=self.phi_config.hidden_size,
opt_level=0,
optimization_options=optimization_options,
only_onnxruntime=False,
)
fused_op_count = optimizer.get_fused_operator_statistics()
if optimizer.is_fully_optimized(fused_op_count):
logging.info("Model is fully optimized.")
else:
logging.info("Model is not fully optimized.")
if self.precision == Precision.FLOAT32:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
if (
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
# We keep the last three Attention layers in float32 or bfloat16 to avoid overflow.
node_block_list = (
[
"Attention_29",
"Attention_30",
"Attention_31",
]
if self.attn_op_type != AttentionOpType.PagedAttention
else []
) # TODO: temp setting for paged attention
logging.info("Converting onnx model to float16/bfloat16...")
optimizer.convert_float_to_float16(
keep_io_types=False,
node_block_list=node_block_list,
use_symbolic_shape_infer=True,
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
)
logging.info("Converting onnx model to float16/bfloat16 done.")
if self.precision == Precision.FLOAT16:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
else:
assert self.precision == Precision.INT4
quant = MatMul4BitsQuantizer(
model=optimizer.model,
block_size=self.block_size,
is_symmetric=True,
accuracy_level=self.accuracy_level,
)
quant.process()
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
# This function currently only works for phi2 model
def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str):
onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True))
from onnx import TensorProto, helper
graph = onnx_model.graph()
new_inputs = []
for vi in graph.input:
if "attention_mask" in vi.name:
vi_seqlen_k = helper.make_tensor_value_info(
"seqlens_k",
elem_type=TensorProto.INT32,
shape=["batch_size"],
)
vi_total_seq_len = helper.make_tensor_value_info(
"total_sequence_length",
elem_type=TensorProto.INT32,
shape=[1],
)
new_inputs.extend([vi_seqlen_k, vi_total_seq_len])
else:
new_inputs.append(vi)
graph.ClearField("input")
graph.input.extend(new_inputs)
gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention")
gqa = gqas[0]
seqlens_path = onnx_model.match_parent_path(
gqa,
["Cast", "Sub", "ReduceSum", "Cast"],
[5, 0, 0, 0],
)
if seqlens_path is None:
raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.")
total_seq_len_path = onnx_model.match_parent_path(
gqa,
["Cast", "Gather", "Shape"],
[6, 0, 0],
)
if total_seq_len_path is None:
raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.")
onnx_model.remove_nodes(seqlens_path)
onnx_model.remove_nodes(total_seq_len_path)
for gqa in gqas:
gqa.input[5] = "seqlens_k"
gqa.input[6] = "total_sequence_length"
onnx_model.save(onnx_model.model, out_onnx_path, save_as_external_data=True)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--fp32_cpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for CPU",
)
parser.add_argument(
"--int4_cpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for CPU",
)
parser.add_argument(
"--fp32_gpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--int4_gpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu_sm8x",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--int4_gpu_sm8x",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--fp16_vllm",
required=False,
action="store_true",
help="Generate fp16 ONNX model for ORT VLLM",
)
parser.add_argument(
"--int4_vllm",
required=False,
action="store_true",
help="Generate int4 ONNX model for ORT VLLM",
)
parser.add_argument(
"--use_cuda_graph",
required=False,
action="store_true",
help="Use CUDA Graph in decoding process",
)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX models",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./cache",
help="The cache directory for the pytorch model",
)
parser.add_argument(
"--device_id",
required=False,
type=int,
default=0,
help="The device id for the pytorch model",
)
parser.add_argument(
"--run_example",
required=False,
action="store_true",
help="Run ORT inference example",
)
parser.add_argument(
"--run_benchmark",
required=False,
action="store_true",
help="Run ORT benchmark",
)
parser.add_argument(
"--skip_export",
required=False,
action="store_true",
help="Skip exporting ONNX model",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="phi2_onnx_models",
)
parser.add_argument(
"--block_size",
required=False,
default=16,
type=int,
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
)
parser.add_argument(
"--int4_accuracy_level",
required=False,
type=int,
help="Accuracy level of the 4-bit quantized MatMul computation. "
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
)
args = parser.parse_args()
return args
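# Example invocations (a hedged sketch; the script name and flag combinations are illustrative):
#   python convert_to_onnx.py --fp16_gpu_sm8x --run_example
#   python convert_to_onnx.py --int4_gpu_sm8x --block_size 16 --run_benchmark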
def main():
args = parse_arguments()
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
if not args.skip_export:
if not os.path.exists(original_onnx_path) or args.overwrite:
converter.dynamo_export(original_onnx_path)
model_type_to_args = {
"fp32_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
),
"int4_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
),
"fp32_gpu": (
AttentionOpType.Attention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
),
"fp16_gpu": (
AttentionOpType.Attention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
),
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
"fp16_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
),
"int4_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
),
"fp16_vllm": (
AttentionOpType.PagedAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
),
"int4_vllm": (
AttentionOpType.PagedAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
),
}
if not args.skip_export:
from multiprocessing import Process
def run_optimize_phi2_onnx(
converter: ConvertPhi2ToONNX,
original_onnx_path: str,
attention_type: AttentionOpType,
precision: Precision,
optimized_onnx_path: str,
):
converter.init_attn_type_and_precision(attention_type, precision)
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
if args.use_cuda_graph:
assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x
converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path)
processes = []
if args.fp32_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
)
)
if args.int4_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
)
)
if args.fp32_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
)
)
if args.fp16_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
)
)
if args.int4_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
)
)
if args.fp16_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
)
)
if args.int4_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
)
)
if args.fp16_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
)
)
if args.int4_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
)
)
[p.start() for p in processes]
[p.join() for p in processes]
if args.run_example or args.run_benchmark:
from inference_example import run_phi2
if args.fp16_gpu_sm8x:
logging.info("Running fp16_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu_sm8x:
logging.info("Running int4_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.fp32_gpu:
logging.info("Running fp32_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp32_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
use_fp16=False,
run_benchmark=args.run_benchmark,
)
if args.fp16_gpu:
logging.info("Running fp16_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu:
logging.info("Running int4_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,414 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import numpy as np
import torch
from transformers import AutoTokenizer
import onnxruntime as ort
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float32": np.float32,
"torch.float16": np.float16,
}
def cuda_memcpy(dst, src):
from cuda import cudart
cudart.cudaMemcpy(
dst.data_ptr(),
src.data_ptr(),
src.element_size() * src.nelement(),
cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
)
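# Note: the import above relies on NVIDIA's cuda-python package (an external dependency, assumed to be
# installed separately); the call performs a raw device-to-device copy between the two tensors' buffers.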
class ORTGenerator:
def __init__(self, decoder_path):
self.onnx_decoder_path = decoder_path
self.num_heads = 32
self.head_size = 80
self.num_layers = 32
self.max_sequence_length = 2048
self.device_id = 0
self.use_cuda_graph = False
self.use_traced_inputs = False
self.static_inputs_map = {}
def append_static_inputs(self, batch_size):
# Only use this function with GQA and with use_cuda_graph=True
if batch_size in self.static_inputs_map:
return
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda", self.device_id)
static_io = {}
static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
for i in range(self.num_layers):
cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
self.static_inputs_map[batch_size] = static_io
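# These preallocated device buffers keep stable addresses across decoding steps, which is what CUDA Graph
# capture/replay requires; get_initial_inputs_and_outputs below picks them up when use_traced_inputs is set.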
def get_initial_inputs_and_outputs(self, encodings_dict):
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
batch_size, sequence_length = input_ids.shape
self.use_traced_inputs = (
self.use_cuda_graph
and (batch_size in self.static_inputs_map)
and self.use_buffer_share
and not self.packed_kv
)
step = (
torch.tensor([0], device=self.device, dtype=torch.int64)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["step"]
)
seqlens_k = (
torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["seqlens_k"]
)
cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
total_seq_length = (
torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["total_sequence_length"]
)
total_seq_length[0] = sequence_length
inputs = {
"input_ids": input_ids.contiguous(),
"attention_mask": attention_mask.contiguous(),
}
if self.use_step:
inputs["step"] = step.contiguous()
if self.use_cuda_graph:
inputs["seqlens_k"] = seqlens_k.contiguous()
inputs["total_sequence_length"] = total_seq_length.contiguous()
del inputs["attention_mask"]
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
past_shape = (
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, past_seq_length, self.head_size)
)
if not self.use_traced_inputs:
for i in range(self.num_layers):
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
(
inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
if not self.packed_kv
else inputs.update({f"past_{i}": past.contiguous()})
)
else:
for i in range(self.num_layers):
inputs.update(
{
f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
}
)
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
outputs = {"logits": logits.contiguous()}
if not self.use_buffer_share:
present_shape = (
(2, batch_size, self.num_heads, sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
return inputs, outputs
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
io_binding = model.io_binding()
device = None
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
if self.use_buffer_share and "present" in name:
v = inputs[name.replace("present", "past")]
io_binding.bind_output(
name=name,
device_type=v.device.type,
device_id=v.device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
def create_session(
self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
):
self.device_id = device_id
sess_options = ort.SessionOptions()
sess_options.log_verbosity_level = 4
sess_options.log_severity_level = 4
self.use_cuda_graph = use_cuda_graph
ep = (
("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
if self.device_id >= 0
else "CPUExecutionProvider"
)
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
self.ro = ort.RunOptions()
self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
self.use_fp16 = use_fp16
self.use_buffer_share = use_buffer_share
self.packed_kv = packed_kv
self.use_step = use_step
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
self.tokenizer.pad_token = "[PAD]"
def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
all_token_ids = inputs["input_ids"].clone()
batch_size, sequence_length = all_token_ids.shape
current_length = sequence_length
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
if benchmark:
import time
latency = []
prompt_run = True
while current_length < max_length:
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
if benchmark:
start = time.time()
io_binding.synchronize_inputs()
if prompt_run:
if self.use_cuda_graph:
# Disable CUDA graph for the prompt run
self.ro.add_run_config_entry("gpu_graph_id", "-1")
self.sess.run_with_iobinding(io_binding, self.ro)
if self.use_cuda_graph:
# Enable CUDA graph for the decoding run
self.ro.add_run_config_entry(
"gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
)
prompt_run = False
else:
self.sess.run_with_iobinding(io_binding, self.ro)
io_binding.synchronize_outputs()
if benchmark:
end = time.time()
latency.append(end - start)
# Sample with argmax (greedy search)
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
            has_eos = has_eos | (next_tokens == self.tokenizer.eos_token_id)
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
# Return early if all batch entries have reached EOS token id
if torch.all(has_eos):
break
# Update inputs for next inference run
current_length += 1
inputs["input_ids"] = tokens_to_add.to(torch.int32)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
if self.use_step:
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
inputs["step"] = self.static_inputs_map[batch_size]["step"]
if self.use_cuda_graph:
previous_seqlens_k = inputs["seqlens_k"]
inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
inputs["total_sequence_length"][0] = current_length
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
else:
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
).to(torch.int32)
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
if self.use_traced_inputs:
outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
outputs["logits"].zero_()
if not self.use_buffer_share:
for i in range(self.num_layers):
if not self.packed_kv:
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
else:
inputs[f"past_{i}"] = outputs[f"present_{i}"]
new_sequence_length = inputs["attention_mask"].shape[1]
present_shape = (
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{
f"present_key_{i}": present.contiguous(),
f"present_value_{i}": present.clone().contiguous(),
}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
if benchmark:
print(
f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
)
print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
return
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
def generate(self, prompt, max_length, cuda_graph_annotation):
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
batch_size, sequence_length = prompt_shape
max_length = sequence_length + token_num
encodings_dict = {}
encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
# Warm up run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
# Benchmark run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
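# Build an ORT session for the exported phi-2 decoder and either generate text for a sample
# prompt or run a simple decoder latency benchmark.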
def run_phi2(
onnx_model_path,
use_buffer_share,
device_id,
packed_kv=False,
use_fp16=True,
use_step=False,
use_cuda_graph=False,
run_benchmark=False,
):
generator = ORTGenerator(onnx_model_path)
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
def simple_run(prompt):
example_batch_size = len(prompt)
if use_cuda_graph:
generator.append_static_inputs(batch_size=example_batch_size)
texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
for i in range(len(texts)):
print("Prompt: ", prompt[i])
print("Texts: ", texts[i])
prompt = [
'''```python
def print_prime(n):
"""
Print all primes between 1 and n
"""'''
]
if not run_benchmark:
simple_run(prompt)
# Run simple benchmark. Time the decoder only.
if run_benchmark:
token_num = 32
for batch_size in [1, 2, 4, 8]:
generator.append_static_inputs(batch_size)
for sequence_length in [16, 512]:
prompt_shape = (batch_size, sequence_length)
generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
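# Example invocation (a minimal sketch; the decoder path and flag values below are illustrative
# assumptions and not part of the original script):
if __name__ == "__main__":
    run_phi2(
        onnx_model_path="phi2_decoder_fp16.onnx",  # hypothetical path to the exported ONNX decoder
        use_buffer_share=True,
        device_id=0,
        packed_kv=False,
        use_fp16=True,
        use_step=False,
        use_cuda_graph=False,
        run_benchmark=False,
    )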

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

View File

@ -0,0 +1,426 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import importlib.util
import time
from statistics import mean
import torch
from demo_utils import PipelineInfo
from diffusers import (
AutoencoderKL,
ControlNetModel,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
StableDiffusionXLControlNetPipeline,
)
from engine_builder import EngineType, get_engine_paths
from pipeline_stable_diffusion import StableDiffusionPipeline
"""
Benchmark script for SDXL-Turbo with control net for engines like PyTorch or Stable Fast.
Setup for Stable Fast (see https://github.com/chengzeyi/stable-fast/blob/main/README.md for more info):
git clone https://github.com/chengzeyi/stable-fast.git
cd stable-fast
git submodule update --init
pip3 install torch torchvision torchaudio ninja
pip3 install -e '.[dev,xformers,triton,transformers,diffusers]' -v
sudo apt install libgoogle-perftools-dev
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so
"""
def get_canny_image():
import cv2
import numpy as np
from PIL import Image
# Test Image can be downloaded from https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png
image = Image.open("input_image_vermeer.png").convert("RGB")
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
return Image.fromarray(image)
def compile_stable_fast(pipeline, enable_cuda_graph=True):
from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile
config = CompilationConfig.Default()
if importlib.util.find_spec("xformers") is not None:
config.enable_xformers = True
if importlib.util.find_spec("triton") is not None:
config.enable_triton = True
config.enable_cuda_graph = enable_cuda_graph
pipeline = compile(pipeline, config)
return pipeline
def compile_torch(pipeline, use_nhwc=False):
if use_nhwc:
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
if hasattr(pipeline, "controlnet"):
if use_nhwc:
pipeline.controlnet.to(memory_format=torch.channels_last)
pipeline.controlnet = torch.compile(pipeline.controlnet, mode="reduce-overhead", fullgraph=True)
return pipeline
def load_pipeline(name, engine, use_control_net=False, use_nhwc=False, enable_cuda_graph=True):
gc.collect()
torch.cuda.empty_cache()
before_memory = torch.cuda.memory_allocated()
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(name, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
if use_control_net:
assert "xl" in name
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
name,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
else:
pipeline = DiffusionPipeline.from_pretrained(
name,
vae=vae,
scheduler=scheduler,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
pipeline.safety_checker = None
gc.collect()
after_memory = torch.cuda.memory_allocated()
print(f"Loaded model with {after_memory - before_memory} bytes allocated")
if engine == "stable_fast":
pipeline = compile_stable_fast(pipeline, enable_cuda_graph=enable_cuda_graph)
elif engine == "torch":
pipeline = compile_torch(pipeline, use_nhwc=use_nhwc)
pipeline.set_progress_bar_config(disable=True)
return pipeline
def get_prompt():
return "little cute gremlin wearing a jacket, cinematic, vivid colors, intricate masterpiece, golden ratio, highly detailed"
def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_graph=True, work_dir="."):
version = PipelineInfo.supported_models()[name]
guidance_scale = 0.0
pipeline_info = PipelineInfo(
version,
use_vae=True,
use_fp16_vae=True,
do_classifier_free_guidance=(guidance_scale > 1.0),
controlnet=["canny"] if use_control_net else [],
)
engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT
onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type
)
pipeline = StableDiffusionPipeline(
pipeline_info,
scheduler="EulerA",
max_batch_size=32,
use_cuda_graph=enable_cuda_graph,
framework_model_dir=framework_model_dir,
output_dir=output_dir,
engine_type=engine_type,
)
pipeline.backend.build_engines(
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
device_id=torch.cuda.current_device(),
)
return pipeline
def test_ort_cuda(
pipeline,
batch_size=1,
steps=4,
control_image=None,
warmup_runs=3,
test_runs=10,
seed=123,
verbose=False,
image_height=512,
image_width=512,
):
if batch_size > 4 and pipeline.pipeline_info.version == "xl-1.0":
pipeline.backend.enable_vae_slicing()
pipeline.load_resources(image_height, image_width, batch_size)
warmup_prompt = "warm up"
for _ in range(warmup_runs):
images, _ = pipeline.run(
[warmup_prompt] * batch_size,
[""] * batch_size,
image_height=image_height,
image_width=image_width,
denoising_steps=steps,
guidance=0.0,
seed=seed,
controlnet_images=[control_image],
controlnet_scales=torch.FloatTensor([0.5]),
output_type="image",
)
assert len(images) == batch_size
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
prompt = get_prompt()
latency_list = []
images = None
for _ in range(test_runs):
torch.cuda.synchronize()
start_time = time.perf_counter()
images, _ = pipeline.run(
[prompt] * batch_size,
[""] * batch_size,
image_height=image_height,
image_width=image_width,
denoising_steps=steps,
guidance=0.0,
seed=seed,
controlnet_images=[control_image],
controlnet_scales=torch.FloatTensor([0.5]),
output_type="pil",
)
torch.cuda.synchronize()
seconds = time.perf_counter() - start_time
latency_list.append(seconds)
if verbose:
print(latency_list)
return images, latency_list
def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, test_runs=10, seed=123, verbose=False):
control_net_args = {}
if hasattr(pipeline, "controlnet"):
control_net_args = {
"image": control_image,
"controlnet_conditioning_scale": 0.5,
}
warmup_prompt = "warm up"
for _ in range(warmup_runs):
images = pipeline(
prompt=warmup_prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=0.0,
**control_net_args,
).images
assert len(images) == batch_size
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
prompt = get_prompt()
latency_list = []
images = None
for _ in range(test_runs):
torch.cuda.synchronize()
start_time = time.perf_counter()
images = pipeline(
prompt=prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=0.0,
generator=generator,
**control_net_args,
).images
torch.cuda.synchronize()
seconds = time.perf_counter() - start_time
latency_list.append(seconds)
if verbose:
print(latency_list)
return images, latency_list
def arguments():
import argparse
parser = argparse.ArgumentParser(description="Benchmark Stable Diffusion pipeline (optional control net for SDXL)")
parser.add_argument(
"--engine",
type=str,
default="torch",
choices=["torch", "stable_fast", "ort_cuda", "ort_trt"],
help="Backend engine: torch, stable_fast or ort_cuda",
)
parser.add_argument(
"--name",
type=str,
choices=list(PipelineInfo.supported_models().keys()),
default="stabilityai/sdxl-turbo",
help="Stable diffusion model name. Default is stabilityai/sdxl-turbo",
)
parser.add_argument(
"--work-dir",
type=str,
default=".",
help="working directory for ort_cuda or ort_trt",
)
parser.add_argument(
"--use_control_net",
action="store_true",
help="Use control net diffusers/controlnet-canny-sdxl-1.0",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--steps",
type=int,
default=1,
help="Denoising steps",
)
parser.add_argument(
"--warmup_runs",
type=int,
default=3,
help="Number of warmup runs before measurement",
)
parser.add_argument(
"--use_nhwc",
action="store_true",
help="use channel last format for torch compile",
)
parser.add_argument(
"--enable_cuda_graph",
action="store_true",
help="enable cuda graph for stable fast",
)
parser.add_argument(
"--verbose",
action="store_true",
help="print more information",
)
args = parser.parse_args()
return args
def main():
args = arguments()
with torch.no_grad():
if args.engine == "ort_cuda":
pipeline = load_ort_cuda_pipeline(
args.name,
args.engine,
use_control_net=args.use_control_net,
enable_cuda_graph=args.enable_cuda_graph,
work_dir=args.work_dir,
)
else:
pipeline = load_pipeline(
args.name,
args.engine,
use_control_net=args.use_control_net,
use_nhwc=args.use_nhwc,
enable_cuda_graph=args.enable_cuda_graph,
)
canny_image = get_canny_image()
if args.engine == "ort_cuda":
images, latency_list = test_ort_cuda(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
elif args.engine == "stable_fast":
from sfast.utils.compute_precision import low_compute_precision
with low_compute_precision():
images, latency_list = test(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
else:
images, latency_list = test(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
# Save the first output image to inspect the result.
if images:
images[0].save(
f"{args.engine}_{args.name.replace('/', '_')}_{args.batch_size}_{args.steps}_c{int(args.use_control_net)}.png"
)
result = {
"engine": args.engine,
"batch_size": args.batch_size,
"steps": args.steps,
"control_net": args.use_control_net,
"nhwc": args.use_nhwc,
"enable_cuda_graph": args.enable_cuda_graph,
"average_latency_in_ms": mean(latency_list) * 1000,
}
print(result)
main()

View File

@ -0,0 +1,102 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import coloredlogs
from cuda import cudart
from demo_utils import (
add_controlnet_arguments,
arg_parser,
get_metadata,
load_pipelines,
parse_arguments,
process_controlnet_arguments,
repeat_prompt,
)
def main(args):
controlnet_images, controlnet_scale = process_controlnet_arguments(args)
pipeline, refiner = load_pipelines(args)
assert refiner is None
prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
pipeline.load_resources(args.height, args.width, batch_size)
def run_inference(warmup=False):
return pipeline.run(
prompt,
negative_prompt,
args.height,
args.width,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scale,
show_latency=not warmup,
output_type="pil",
deterministic=args.deterministic,
)
if not args.disable_cuda_graph:
# inference once to get cuda graph
_, _ = run_inference(warmup=True)
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_inference(warmup=True)
print("[I] Running StableDiffusion pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
images, perf_data = run_inference(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()
metadata = get_metadata(args, False)
metadata.update(pipeline.metadata())
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
pipeline.save_images(images, prompt, negative_prompt, metadata)
pipeline.teardown()
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
parser = arg_parser("Options for Stable Diffusion Demo")
add_controlnet_arguments(parser)
args = parse_arguments(is_xl=False, parser=parser)
if args.user_compute_stream:
import torch
s = torch.cuda.Stream()
with torch.cuda.stream(s):
main(args)
else:
main(args)
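# Example usage (illustrative; assumes this demo is saved as demo_txt2img.py):
#   python demo_txt2img.py "an astronaut riding a horse on mars"
#   python demo_txt2img.py "a photo of a cat" --controlnet-type canny --controlnet-image input.png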

View File

@ -0,0 +1,268 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import coloredlogs
from cuda import cudart
from demo_utils import (
add_controlnet_arguments,
arg_parser,
get_metadata,
load_pipelines,
parse_arguments,
process_controlnet_arguments,
repeat_prompt,
)
def run_pipelines(
args, base, refiner, prompt, negative_prompt, controlnet_image=None, controlnet_scale=None, is_warm_up=False
):
image_height = args.height
image_width = args.width
batch_size = len(prompt)
base.load_resources(image_height, image_width, batch_size)
if refiner:
refiner.load_resources(image_height, image_width, batch_size)
def run_base_and_refiner(warmup=False):
images, base_perf = base.run(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
controlnet_images=controlnet_image,
controlnet_scales=controlnet_scale,
show_latency=not warmup,
output_type="latent" if refiner else "pil",
)
if refiner is None:
return images, base_perf
# Use same seed in base and refiner.
seed = base.get_current_seed()
images, refiner_perf = refiner.run(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=args.refiner_denoising_steps,
image=images,
strength=args.strength,
guidance=args.refiner_guidance,
seed=seed,
show_latency=not warmup,
)
perf_data = None
if base_perf and refiner_perf:
perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]}
perf_data.update({"base." + key: val for key, val in base_perf.items()})
perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()})
return images, perf_data
if not args.disable_cuda_graph:
# inference once to get cuda graph
_, _ = run_base_and_refiner(warmup=True)
if args.num_warmup_runs > 0:
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_base_and_refiner(warmup=True)
if is_warm_up:
return
print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
images, perf_data = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()
if refiner:
print("|----------------|--------------|")
print("| {:^14} | {:>9.2f} ms |".format("e2e", perf_data["latency"]))
print("|----------------|--------------|")
metadata = get_metadata(args, True)
metadata.update({"base." + key: val for key, val in base.metadata().items()})
if refiner:
metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()})
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
(refiner or base).save_images(images, prompt, negative_prompt, metadata)
def run_demo(args):
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
controlnet_image, controlnet_scale = process_controlnet_arguments(args)
prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
base, refiner = load_pipelines(args, batch_size)
run_pipelines(args, base, refiner, prompt, negative_prompt, controlnet_image, controlnet_scale)
base.teardown()
if refiner:
refiner.teardown()
def run_dynamic_shape_demo(args):
"""
Run demo of generating images with different settings with ORT CUDA provider.
Try "python demo_txt2img_xl.py --max-cuda-graphs 3 --user-compute-stream" to see the effect of multiple CUDA graphs.
"""
args.engine = "ORT_CUDA"
base, refiner = load_pipelines(args, 1)
prompts = [
"starry night over Golden Gate Bridge by van gogh",
"beautiful photograph of Mt. Fuji during cherry blossom",
"little cute gremlin sitting on a bed, cinematic",
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
"beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
"An astronaut riding a rainbow unicorn, cinematic, dramatic",
"close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm",
]
# batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength
configs = [
(1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3),
(1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3),
(1, 1216, 832, "EulerA", 16, prompts[2], 1716921396712843, 5.0, "EulerA", 10, 0.3),
(1, 1344, 768, "EulerA", 24, prompts[3], 123698071912362, 5.0, "EulerA", 20, 0.3),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3),
]
    # When testing LCM, the refiner is disabled, so the refiner settings are not used.
if args.lcm:
configs = [
(1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3),
(1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3),
]
# Warm up each combination of (batch size, height, width) once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 1
for batch_size, height, width, _, _, _, _, _, _, _, _ in configs:
args.batch_size = batch_size
args.height = height
args.width = width
print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)
# Run pipeline on a list of prompts.
args.num_warmup_runs = 0
for (
batch_size,
height,
width,
scheduler,
steps,
example_prompt,
seed,
guidance,
refiner_scheduler,
refiner_denoising_steps,
strength,
) in configs:
args.prompt = [example_prompt]
args.batch_size = batch_size
args.height = height
args.width = width
args.scheduler = scheduler
args.denoising_steps = steps
args.seed = seed
args.guidance = guidance
args.refiner_scheduler = refiner_scheduler
args.refiner_denoising_steps = refiner_denoising_steps
args.strength = strength
base.set_scheduler(scheduler)
if refiner:
refiner.set_scheduler(refiner_scheduler)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
base.teardown()
if refiner:
refiner.teardown()
def run_turbo_demo(args):
"""Run demo of generating images with test prompts with ORT CUDA provider."""
args.engine = "ORT_CUDA"
base, refiner = load_pipelines(args, 1)
from datasets import load_dataset
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
num_rows = dataset["test"].num_rows
batch_size = args.batch_size
num_batch = int(num_rows / batch_size)
args.batch_size = 1
for i in range(num_batch):
args.prompt = [dataset["test"][i]["Prompt"] for i in range(i * batch_size, (i + 1) * batch_size)]
base.set_scheduler(args.scheduler)
if refiner:
refiner.set_scheduler(args.refiner_scheduler)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
base.teardown()
if refiner:
refiner.teardown()
def main(args):
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
if no_prompt:
if args.version == "xl-turbo":
run_turbo_demo(args)
else:
run_dynamic_shape_demo(args)
else:
run_demo(args)
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
parser = arg_parser("Options for Stable Diffusion XL Demo")
add_controlnet_arguments(parser)
args = parse_arguments(is_xl=True, parser=parser)
if args.user_compute_stream:
import torch
s = torch.cuda.Stream()
with torch.cuda.stream(s):
main(args)
else:
main(args)
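# Example usage (illustrative command lines for demo_txt2img_xl.py, the script name referenced in the
# run_dynamic_shape_demo docstring above):
#   python demo_txt2img_xl.py "starry night over Golden Gate Bridge by van gogh"
#   python demo_txt2img_xl.py --enable-refiner "little cute gremlin sitting on a bed, cinematic"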

View File

@ -0,0 +1,778 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import argparse
import os
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Any, Dict, List, Optional
import controlnet_aux
import cv2
import numpy as np
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_paths, get_engine_type
from PIL import Image
from pipeline_stable_diffusion import StableDiffusionPipeline
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
pass
def arg_parser(description: str):
return argparse.ArgumentParser(
description=description,
formatter_class=RawTextArgumentDefaultsHelpFormatter,
)
def set_default_arguments(args):
# set default value for some arguments if not provided
if args.height is None:
args.height = PipelineInfo.default_resolution(args.version)
if args.width is None:
args.width = PipelineInfo.default_resolution(args.version)
is_lcm = (args.version == "xl-1.0" and args.lcm) or "lcm" in args.lora_weights
is_turbo = args.version in ["sd-turbo", "xl-turbo"]
if args.denoising_steps is None:
args.denoising_steps = 4 if is_turbo else 8 if is_lcm else (30 if args.version == "xl-1.0" else 50)
if args.scheduler is None:
args.scheduler = "LCM" if (is_lcm or is_turbo) else ("EulerA" if args.version == "xl-1.0" else "DDIM")
if args.guidance is None:
args.guidance = 0.0 if (is_lcm or is_turbo) else (5.0 if args.version == "xl-1.0" else 7.5)
def parse_arguments(is_xl: bool, parser):
engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"]
parser.add_argument(
"-e",
"--engine",
type=str,
default=engines[0],
choices=engines,
help="Backend engine in {engines}. "
"ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
)
supported_versions = PipelineInfo.supported_versions(is_xl)
parser.add_argument(
"-v",
"--version",
type=str,
default="xl-1.0" if is_xl else "1.5",
choices=supported_versions,
help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
)
parser.add_argument(
"-y",
"--height",
type=int,
default=None,
help="Height of image to generate (must be multiple of 8).",
)
parser.add_argument(
"-x", "--width", type=int, default=None, help="Height of image to generate (must be multiple of 8)."
)
parser.add_argument(
"-s",
"--scheduler",
type=str,
default=None,
choices=["DDIM", "EulerA", "UniPC", "LCM"],
help="Scheduler for diffusion process" + " of base" if is_xl else "",
)
parser.add_argument(
"-wd",
"--work-dir",
default=".",
help="Root Directory to store torch or ONNX models, built engines and output images etc.",
)
parser.add_argument(
"-i",
"--engine-dir",
default=None,
help="Root Directory to store built engines or optimized ONNX models etc.",
)
parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")
parser.add_argument(
"-n",
"--negative-prompt",
nargs="*",
default=[""],
help="Optional negative prompt(s) to guide the image generation.",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=1,
choices=[1, 2, 4, 8, 16],
help="Number of times to repeat the prompt (batch size multiplier).",
)
parser.add_argument(
"-d",
"--denoising-steps",
type=int,
default=None,
help="Number of denoising steps" + (" in base." if is_xl else "."),
)
parser.add_argument(
"-g",
"--guidance",
type=float,
default=None,
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
)
parser.add_argument(
"-ls", "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)"
)
parser.add_argument("-lw", "--lora-weights", type=str, default="", help="LoRA weights to apply in the base model")
if is_xl:
parser.add_argument(
"--lcm",
action="store_true",
help="Use fine-tuned latent consistency model to replace the UNet in base.",
)
parser.add_argument(
"-rs",
"--refiner-scheduler",
type=str,
default="EulerA",
choices=["DDIM", "EulerA", "UniPC"],
help="Scheduler for diffusion process of refiner.",
)
parser.add_argument(
"-rg",
"--refiner-guidance",
type=float,
default=5.0,
help="Guidance scale used in refiner.",
)
parser.add_argument(
"-rd",
"--refiner-denoising-steps",
type=int,
default=30,
help="Number of denoising steps in refiner. Note that actual steps is refiner_denoising_steps * strength.",
)
parser.add_argument(
"--strength",
type=float,
default=0.3,
help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
)
parser.add_argument(
"-r",
"--enable-refiner",
action="store_true",
help="Enable SDXL refiner to refine image from base pipeline.",
)
# ONNX export
parser.add_argument(
"--onnx-opset",
type=int,
default=None,
choices=range(14, 18),
help="Select ONNX opset version to target for exported models.",
)
# Engine build options.
parser.add_argument(
"-db",
"--build-dynamic-batch",
action="store_true",
help="Build TensorRT engines to support dynamic batch size.",
)
parser.add_argument(
"-ds",
"--build-dynamic-shape",
action="store_true",
help="Build TensorRT engines to support dynamic image sizes.",
)
parser.add_argument("--max-batch-size", type=int, default=None, choices=[1, 2, 4, 8, 16, 32], help="Max batch size")
# Inference related options
parser.add_argument(
"-nw", "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance."
)
parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.")
parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.")
parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.")
parser.add_argument("--framework-model-dir", default=None, help="framework model directory")
group = parser.add_argument_group("Options for ORT_CUDA engine only")
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")
# TensorRT only options
group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
group.add_argument(
"--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
)
args = parser.parse_args()
set_default_arguments(args)
# Validate image dimensions
if args.height % 64 != 0 or args.width % 64 != 0:
raise ValueError(
f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
)
if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
args.disable_cuda_graph = True
if args.onnx_opset is None:
args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17
if is_xl:
if args.version == "xl-turbo":
if args.lcm:
print("[I] sdxl-turbo cannot use with LCM.")
args.lcm = False
assert args.strength > 0.0 and args.strength < 1.0
        assert not (args.lcm and args.lora_weights), "using both the LCM UNet and LoRA together is not supported"
if args.scheduler == "LCM":
if args.guidance > 2.0:
print("[I] Use --guidance=0.0 (no more than 2.0) when LCM scheduler is used.")
args.guidance = 0.0
if args.denoising_steps > 16:
print("[I] Use --denoising_steps=8 (no more than 16) when LCM scheduler is used.")
args.denoising_steps = 8
print(args)
return args
def max_batch(args):
if args.max_batch_size:
max_batch_size = args.max_batch_size
else:
do_classifier_free_guidance = args.guidance > 1.0
batch_multiplier = 2 if do_classifier_free_guidance else 1
max_batch_size = 32 // batch_multiplier
if args.engine != "ORT_CUDA" and (args.build_dynamic_shape or args.height > 512 or args.width > 512):
max_batch_size = 8 // batch_multiplier
return max_batch_size
def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]:
metadata = {
"command": " ".join(['"' + x + '"' if " " in x else x for x in sys.argv]),
"args.prompt": args.prompt,
"args.negative_prompt": args.negative_prompt,
"args.batch_size": args.batch_size,
"height": args.height,
"width": args.width,
"cuda_graph": not args.disable_cuda_graph,
"vae_slicing": args.enable_vae_slicing,
"engine": args.engine,
}
if args.lora_weights:
metadata["lora_weights"] = args.lora_weights
metadata["lora_scale"] = args.lora_scale
if args.controlnet_type:
metadata["controlnet_type"] = args.controlnet_type
metadata["controlnet_scale"] = args.controlnet_scale
if is_xl and args.enable_refiner:
metadata["base.scheduler"] = args.scheduler
metadata["base.denoising_steps"] = args.denoising_steps
metadata["base.guidance"] = args.guidance
metadata["refiner.strength"] = args.strength
metadata["refiner.scheduler"] = args.refiner_scheduler
metadata["refiner.denoising_steps"] = args.refiner_denoising_steps
metadata["refiner.guidance"] = args.refiner_guidance
else:
metadata["scheduler"] = args.scheduler
metadata["denoising_steps"] = args.denoising_steps
metadata["guidance"] = args.guidance
# Version of installed python packages
packages = ""
for name in [
"onnxruntime-gpu",
"torch",
"tensorrt",
"transformers",
"diffusers",
"onnx",
"onnx-graphsurgeon",
"polygraphy",
"controlnet_aux",
]:
try:
packages += (" " if packages else "") + f"{name}=={version(name)}"
except PackageNotFoundError:
continue
metadata["packages"] = packages
metadata["device"] = torch.cuda.get_device_name()
metadata["torch.version.cuda"] = torch.version.cuda
return metadata
def repeat_prompt(args):
if not isinstance(args.prompt, list):
raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
prompt = args.prompt * args.batch_size
if not isinstance(args.negative_prompt, list):
raise ValueError(
f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
)
if len(args.negative_prompt) == 1:
negative_prompt = args.negative_prompt * len(prompt)
else:
negative_prompt = args.negative_prompt
return prompt, negative_prompt
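# Create a StableDiffusionPipeline for a single model (base or refiner) and build or load its
# engines for the selected backend (ORT_CUDA, ORT_TRT, TRT, or TORCH).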
def initialize_pipeline(
version="xl-turbo",
is_refiner: bool = False,
is_inpaint: bool = False,
engine_type=EngineType.ORT_CUDA,
work_dir: str = ".",
engine_dir=None,
onnx_opset: int = 17,
scheduler="EulerA",
height=512,
width=512,
nvtx_profile=False,
use_cuda_graph=True,
build_dynamic_batch=False,
build_dynamic_shape=False,
min_image_size: int = 512,
max_image_size: int = 1024,
max_batch_size: int = 16,
opt_batch_size: int = 1,
build_all_tactics: bool = False,
do_classifier_free_guidance: bool = False,
lcm: bool = False,
controlnet=None,
lora_weights=None,
lora_scale: float = 1.0,
use_fp16_vae: bool = True,
use_vae: bool = True,
framework_model_dir: Optional[str] = None,
max_cuda_graphs: int = 1,
):
pipeline_info = PipelineInfo(
version,
is_refiner=is_refiner,
is_inpaint=is_inpaint,
use_vae=use_vae,
min_image_size=min_image_size,
max_image_size=max_image_size,
use_fp16_vae=use_fp16_vae,
use_lcm=lcm,
do_classifier_free_guidance=do_classifier_free_guidance,
controlnet=controlnet,
lora_weights=lora_weights,
lora_scale=lora_scale,
)
input_engine_dir = engine_dir
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir
)
pipeline = StableDiffusionPipeline(
pipeline_info,
scheduler=scheduler,
output_dir=output_dir,
verbose=False,
nvtx_profile=nvtx_profile,
max_batch_size=max_batch_size,
use_cuda_graph=use_cuda_graph,
framework_model_dir=framework_model_dir,
engine_type=engine_type,
)
import_engine_dir = None
if input_engine_dir:
if not os.path.exists(input_engine_dir):
raise RuntimeError(f"--engine_dir directory does not exist: {input_engine_dir}")
# Support importing from optimized diffusers onnx pipeline
if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(input_engine_dir, "model_index.json")):
import_engine_dir = input_engine_dir
else:
engine_dir = input_engine_dir
opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height
opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width
if engine_type == EngineType.ORT_CUDA:
pipeline.backend.build_engines(
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
device_id=torch.cuda.current_device(),
import_engine_dir=import_engine_dir,
max_cuda_graphs=max_cuda_graphs,
)
elif engine_type == EngineType.ORT_TRT:
pipeline.backend.build_engines(
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
opt_batch_size=opt_batch_size,
static_batch=not build_dynamic_batch,
static_image_shape=not build_dynamic_shape,
max_workspace_size=0,
device_id=torch.cuda.current_device(),
timing_cache=timing_cache,
)
elif engine_type == EngineType.TRT:
pipeline.backend.load_engines(
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_batch_size=opt_batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
static_batch=not build_dynamic_batch,
static_shape=not build_dynamic_shape,
enable_all_tactics=build_all_tactics,
timing_cache=timing_cache,
)
elif engine_type == EngineType.TORCH:
pipeline.backend.build_engines(framework_model_dir)
else:
raise RuntimeError("invalid engine type")
return pipeline
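# Build the base pipeline (and the optional SDXL refiner) from parsed arguments, picking an
# image-size range per engine; for TRT the engines are activated with shared device memory.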
def load_pipelines(args, batch_size=None):
engine_type = get_engine_type(args.engine)
# Register TensorRT plugins
if engine_type == EngineType.TRT:
from trt_utilities import init_trt_plugins
init_trt_plugins()
max_batch_size = max_batch(args)
if batch_size is None:
assert isinstance(args.prompt, list)
batch_size = len(args.prompt) * args.batch_size
if batch_size > max_batch_size:
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")
# For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
# Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance.
# This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024).
if args.version == "xl-turbo":
min_image_size = 512
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
elif args.version == "xl-1.0":
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
else:
        # This range can cover commonly used shapes of landscape 512x768, portrait 768x512, or square 512x512 and 768x768.
min_image_size = 512 if args.engine != "ORT_CUDA" else 256
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
params = {
"version": args.version,
"is_refiner": False,
"is_inpaint": False,
"engine_type": engine_type,
"work_dir": args.work_dir,
"engine_dir": args.engine_dir,
"onnx_opset": args.onnx_opset,
"scheduler": args.scheduler,
"height": args.height,
"width": args.width,
"nvtx_profile": args.nvtx_profile,
"use_cuda_graph": not args.disable_cuda_graph,
"build_dynamic_batch": args.build_dynamic_batch,
"build_dynamic_shape": args.build_dynamic_shape,
"min_image_size": min_image_size,
"max_image_size": max_image_size,
"max_batch_size": max_batch_size,
"opt_batch_size": 1 if args.build_dynamic_batch else batch_size,
"build_all_tactics": args.build_all_tactics,
"do_classifier_free_guidance": args.guidance > 1.0,
"controlnet": args.controlnet_type,
"lora_weights": args.lora_weights,
"lora_scale": args.lora_scale,
"use_fp16_vae": "xl" in args.version,
"use_vae": True,
"framework_model_dir": args.framework_model_dir,
"max_cuda_graphs": args.max_cuda_graphs,
}
if "xl" in args.version:
params["lcm"] = args.lcm
params["use_vae"] = not args.enable_refiner
base = initialize_pipeline(**params)
refiner = None
if "xl" in args.version and args.enable_refiner:
params["version"] = "xl-1.0" # Allow SDXL Turbo to use refiner.
params["is_refiner"] = True
params["scheduler"] = args.refiner_scheduler
params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
params["lcm"] = False
params["controlnet"] = None
params["lora_weights"] = None
params["use_vae"] = True
params["use_fp16_vae"] = True
refiner = initialize_pipeline(**params)
if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
if refiner:
refiner.backend.activate_engines(shared_device_memory)
if engine_type == EngineType.ORT_CUDA:
enable_vae_slicing = args.enable_vae_slicing
if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024):
print(
"Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024."
)
enable_vae_slicing = True
if enable_vae_slicing:
(refiner or base).backend.enable_vae_slicing()
return base, refiner
def get_depth_image(image):
"""
Create depth map for SDXL depth control net.
"""
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
with torch.no_grad(), torch.autocast("cuda"):
depth_map = depth_estimator(image).predicted_depth
# The depth map is 384x384 by default, here we interpolate to the default output size.
# Note that it will be resized to output image size later. May change the size here to avoid interpolate twice.
depth_map = torch.nn.functional.interpolate(
depth_map.unsqueeze(1),
size=(1024, 1024),
mode="bicubic",
align_corners=False,
)
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
image = torch.cat([depth_map] * 3, dim=1)
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
return image
def get_canny_image(image) -> Image.Image:
"""
Create canny image for SDXL control net.
"""
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
return image
def process_controlnet_images_xl(args) -> List[Image.Image]:
"""
Process control image for SDXL control net.
"""
assert len(args.controlnet_image) == 1
image = Image.open(args.controlnet_image[0]).convert("RGB")
controlnet_images = []
if args.controlnet_type[0] == "canny":
controlnet_images.append(get_canny_image(image))
elif args.controlnet_type[0] == "depth":
controlnet_images.append(get_depth_image(image))
else:
raise ValueError(f"This controlnet type is not supported for SDXL or Turbo: {args.controlnet_type}.")
return controlnet_images
def add_controlnet_arguments(parser, is_xl: bool = False):
"""
Add control net related arguments.
"""
group = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).")
group.add_argument(
"-ci",
"--controlnet-image",
nargs="*",
type=str,
default=[],
help="Path to the input regular RGB image/images for controlnet",
)
group.add_argument(
"-ct",
"--controlnet-type",
nargs="*",
type=str,
default=[],
choices=list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys()),
help="A list of controlnet type",
)
group.add_argument(
"-cs",
"--controlnet-scale",
nargs="*",
type=float,
default=[],
help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.5 for SDXL, or 1.0 for SD 1.5",
)
def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width):
"""
Process control images of control net v1.1 for Stable Diffusion 1.5.
"""
control_image = None
shape = (height, width)
image = image.convert("RGB")
if controlnet_type == "canny":
canny_image = controlnet_aux.CannyDetector()(image)
control_image = canny_image.resize(shape)
elif controlnet_type == "normalbae":
normal_image = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = normal_image.resize(shape)
elif controlnet_type == "depth":
depth_image = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = depth_image.resize(shape)
elif controlnet_type == "mlsd":
mlsd_image = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = mlsd_image.resize(shape)
elif controlnet_type == "openpose":
openpose_image = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = openpose_image.resize(shape)
elif controlnet_type == "scribble":
scribble_image = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(image, scribble=True)
control_image = scribble_image.resize(shape)
elif controlnet_type == "seg":
seg_image = controlnet_aux.SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")(
image
)
control_image = seg_image.resize(shape)
else:
raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}")
return control_image
def process_controlnet_arguments(args):
"""
Process control net arguments, and returns a list of control images and a tensor of control net scales.
"""
assert isinstance(args.controlnet_type, list)
assert isinstance(args.controlnet_scale, list)
assert isinstance(args.controlnet_image, list)
if len(args.controlnet_image) != len(args.controlnet_type):
raise ValueError(
f"Numbers of controlnet_image {len(args.controlnet_image)} should be equal to number of controlnet_type {len(args.controlnet_type)}."
)
if len(args.controlnet_type) == 0:
return None, None
if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]:
raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")
is_xl = "xl" in args.version
if is_xl and len(args.controlnet_type) > 1:
raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")
if len(args.controlnet_scale) == 0:
args.controlnet_scale = [0.5 if is_xl else 1.0] * len(args.controlnet_type)
elif len(args.controlnet_type) != len(args.controlnet_scale):
raise ValueError(
f"Numbers of controlnet_type {len(args.controlnet_type)} should be equal to number of controlnet_scale {len(args.controlnet_scale)}."
)
# Convert controlnet scales to tensor
controlnet_scale = torch.FloatTensor(args.controlnet_scale)
if is_xl:
images = process_controlnet_images_xl(args)
else:
images = []
for i, image in enumerate(args.controlnet_image):
images.append(process_controlnet_image(args.controlnet_type[i], Image.open(image), args.height, args.width))
return images, controlnet_scale
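# Example ControlNet flags consumed by the functions above (illustrative values):
#   --controlnet-type canny --controlnet-image input_image.png --controlnet-scale 0.5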

View File

@ -0,0 +1,296 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import hashlib
import os
from enum import Enum
from typing import Optional
import torch
from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL
class EngineType(Enum):
ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider
ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider
TRT = 2 # TensorRT
TORCH = 3 # PyTorch
def get_engine_type(name: str) -> EngineType:
name_to_type = {
"ORT_CUDA": EngineType.ORT_CUDA,
"ORT_TRT": EngineType.ORT_TRT,
"TRT": EngineType.TRT,
"TORCH": EngineType.TORCH,
}
return name_to_type[name]
class EngineBuilder:
def __init__(
self,
engine_type: EngineType,
pipeline_info: PipelineInfo,
device="cuda",
max_batch_size=16,
use_cuda_graph=False,
):
"""
Initializes the Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
device (str | torch.device):
device to run engine
max_batch_size (int):
Maximum batch size for dynamic batch engine.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
self.engine_type = engine_type
self.pipeline_info = pipeline_info
self.max_batch_size = max_batch_size
self.use_cuda_graph = use_cuda_graph
self.device = torch.device(device)
self.torch_device = torch.device(device, torch.cuda.current_device())
self.stages = pipeline_info.stages()
self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH
self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()
self.models = {}
self.engines = {}
self.torch_models = {}
self.use_vae_slicing = False
self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
def enable_vae_slicing(self):
self.use_vae_slicing = True
def disable_torch_spda(self):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
delattr(torch.nn.functional, "scaled_dot_product_attention")
def enable_torch_spda(self):
if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa:
torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa
def teardown(self):
for engine in self.engines.values():
del engine
self.engines = {}
def get_diffusers_module_name(self, model_name):
name_mapping = {
"clip": "text_encoder",
"clip2": "text_encoder_2",
"unet": "unet",
"unetxl": "unet",
"vae": "vae_decoder",
}
return name_mapping.get(model_name, model_name)
def get_cached_model_name(self, model_name):
model_name = self.get_diffusers_module_name(model_name)
is_unet = model_name == "unet"
hash_source = []
if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights:
if self.pipeline_info.lora_weights in [
"latent-consistency/lcm-lora-sdxl",
"latent-consistency/lcm-lora-sdv1-5",
]:
if is_unet:
model_name = "unet_lcm-lora"
else:
model_name = model_name + "_lora"
hash_source.append(self.pipeline_info.lora_weights)
# TODO(tianleiwu): save custom model to a directory named by its original model.
if is_unet and self.pipeline_info.custom_unet():
model_name = model_name + "_lcm"
if model_name in ["unet"] and self.pipeline_info.controlnet:
model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet)
if hash_source:
model_name += "_" + hashlib.md5("\t".join(hash_source).encode("utf-8")).hexdigest()[:8]
# TODO: When we support original VAE, we shall save custom VAE to another directory.
if self.pipeline_info.is_inpaint():
model_name += "_inpaint"
return model_name
def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True):
engine_name = self.engine_type.name.lower()
if engine_name != "ort_cuda" and not suffix:
suffix = f".{engine_name}" if opt else ""
directory_name = self.get_cached_model_name(model_name) + suffix
onnx_model_dir = os.path.join(root_dir, directory_name)
if create:
os.makedirs(onnx_model_dir, exist_ok=True)
return onnx_model_dir
def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""):
onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix)
return os.path.join(onnx_model_dir, "model.onnx")
def get_engine_path(self, engine_dir, model_name, profile_id):
return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id)
def load_pipeline_with_lora(self):
"""Load text encoders and UNet with diffusers pipeline"""
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
self.pipeline_info.name(),
variant="fp16",
torch_dtype=torch.float16,
)
pipeline.load_lora_weights(self.pipeline_info.lora_weights)
pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale)
del pipeline.vae
pipeline.vae = None
return pipeline
def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir):
if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline:
if model_name == "clip":
model = pipeline.text_encoder
pipeline.text_encoder = None
elif model_name == "clip2":
model = pipeline.text_encoder_2
pipeline.text_encoder_2 = None
else:
model = pipeline.unet
pipeline.unet = None
else:
model = model_obj.load_model(framework_model_dir)
return model.to(self.torch_device)
def load_models(self, framework_model_dir: str):
# For TRT or ORT_TRT, we will export fp16 torch model for UNet and VAE
# For ORT_CUDA, we export fp32 model first, then optimize to fp16.
export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT]
if "clip" in self.stages:
self.models["clip"] = CLIP(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
clip_skip=0,
)
if "clip2" in self.stages:
self.models["clip2"] = CLIPWithProj(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
clip_skip=0,
)
if "unet" in self.stages:
self.models["unet"] = UNet(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
fp16=export_fp16,
max_batch_size=self.max_batch_size,
unet_dim=(9 if self.pipeline_info.is_inpaint() else 4),
)
if "unetxl" in self.stages:
self.models["unetxl"] = UNetXL(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
fp16=export_fp16,
max_batch_size=self.max_batch_size,
unet_dim=4,
time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6),
)
# VAE Decoder
if "vae" in self.stages:
self.models["vae"] = VAE(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
fp16=export_fp16,
custom_fp16_vae=self.custom_fp16_vae,
)
if self.vae_torch_fallback:
self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir)
def load_resources(self, image_height, image_width, batch_size):
if self.engine_type == EngineType.TORCH:
return
# Allocate buffers for I/O bindings
for model_name, obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size
self.engines[model_name].allocate_buffers(
shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device
)
def _vae_decode(self, latents):
if self.engine_type == EngineType.TORCH:
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
latents = latents.to(dtype=torch.float32)
images = self.engines["vae"](latents)["sample"]
else:
images = self.engines["vae"](latents)["sample"]
elif self.vae_torch_fallback:
if not self.custom_fp16_vae:
latents = latents.to(dtype=torch.float32)
self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32)
images = self.torch_models["vae"](latents)["sample"]
else:
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"]
else:
images = self.run_engine("vae", {"latent": latents})["images"]
return images
def vae_decode(self, latents):
if self.use_vae_slicing:
# The output tensors point to the same buffer, so clone each slice to avoid it being overwritten.
decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)]
return torch.cat(decoded_slices)
return self._vae_decode(latents)
def get_engine_paths(
work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: Optional[str] = None
):
root_dir = work_dir or "."
short_name = pipeline_info.short_name()
# When both ORT_CUDA and ORT_TRT/TRT are used, we create a subdirectory for each engine type since
# ORT_CUDA needs an fp32 torch model, while ORT_TRT/TRT use an fp16 torch model.
onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx")
engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine")
output_dir = os.path.join(root_dir, engine_type.name, short_name, "output")
timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache")
# Shared among ORT_CUDA, ORT_TRT and TRT engines; use load_model(..., always_download_fp16=True)
# so that the shared model is always fp16.
if framework_model_dir is None:
framework_model_dir = os.path.join(root_dir, "torch_model")
return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache
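# A minimal sketch (illustrative only) of how the helpers above fit together. It assumes
# a PipelineInfo instance has already been constructed elsewhere; only the directory
# layout logic is exercised here.
def _example_engine_paths(pipeline_info: PipelineInfo):
    engine_type = get_engine_type("ORT_CUDA")
    onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
        work_dir="./work", pipeline_info=pipeline_info, engine_type=engine_type
    )
    # Expected layout: ./work/ORT_CUDA/<short_name>/{onnx,engine,output}, a shared
    # ./work/ORT_CUDA/timing_cache, and ./work/torch_model for framework checkpoints.
    return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache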

View File

@ -0,0 +1,388 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import logging
import os
from typing import Dict, List, Optional
import onnx
import torch
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from packaging import version
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
from onnxruntime.transformers.onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class OrtCudaEngine:
def __init__(
self,
onnx_path,
device_id: int = 0,
enable_cuda_graph: bool = False,
disable_optimization: bool = False,
max_cuda_graphs: int = 1,
):
self.onnx_path = onnx_path
self.provider = "CUDAExecutionProvider"
self.stream = torch.cuda.current_stream().cuda_stream
self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
session_options = ort.SessionOptions()
# When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
if disable_optimization:
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
logger.info("creating CUDA EP session for %s", onnx_path)
ort_session = ort.InferenceSession(
onnx_path,
session_options,
providers=[
(self.provider, self.provider_options),
"CPUExecutionProvider",
],
)
logger.info("created CUDA EP session for %s", onnx_path)
device = torch.device("cuda", device_id)
self.enable_cuda_graph = enable_cuda_graph
# Support multiple CUDA graphs for different input shapes.
# For the clip2 model, which has CUDA graph disabled, max_cuda_graphs is set to 0 here.
self.gpu_binding_manager = GpuBindingManager(
ort_session=ort_session,
device=device,
stream=self.stream,
max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
)
self.current_gpu_binding = None
def metadata(self, name: str):
data = {}
if self.current_gpu_binding is not None:
if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
return data
def infer(self, feed_dict: Dict[str, torch.Tensor]):
return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)
def allocate_buffers(self, shape_dict, device):
self.current_gpu_binding = self.gpu_binding_manager.get_binding(
shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
)
class _ModelConfig:
"""
Configuration of one model (like CLIP, UNet, etc.) for ONNX export and optimization with the CUDA provider.
For example, if you want to use fp32 in layer normalization, set the following:
force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"]
"""
def __init__(
self,
onnx_opset_version: int,
use_cuda_graph: bool,
fp16: bool = True,
force_fp32_ops: Optional[List[str]] = None,
optimize_by_ort: bool = True,
):
self.onnx_opset_version = onnx_opset_version
self.use_cuda_graph = use_cuda_graph
self.fp16 = fp16
self.force_fp32_ops = force_fp32_ops
self.optimize_by_ort = optimize_by_ort
class OrtCudaEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime CUDA ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.ORT_CUDA,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.model_config = {}
def _configure(
self,
model_name: str,
onnx_opset_version: int,
use_cuda_graph: bool,
fp16: bool = True,
force_fp32_ops: Optional[List[str]] = None,
optimize_by_ort: bool = True,
):
self.model_config[model_name] = _ModelConfig(
onnx_opset_version,
use_cuda_graph,
fp16=fp16,
force_fp32_ops=force_fp32_ops,
optimize_by_ort=optimize_by_ort,
)
def configure_xl(self, onnx_opset_version: int):
self._configure(
"clip",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
self._configure(
"clip2",
onnx_opset_version=onnx_opset_version, # TODO: ArgMax-12 is not implemented in CUDA
use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph
)
self._configure(
"unetxl",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
self._configure(
"vae",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
def optimized_onnx_path(self, engine_dir, model_name):
suffix = "" if self.model_config[model_name].fp16 else ".fp32"
return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix)
def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
"""Import optimized onnx models for diffusers from Olive or optimize_pipeline tools.
Args:
diffusers_onnx_dir (str): optimized onnx directory of Olive
engine_dir (str): the directory to store imported onnx
"""
if version.parse(ort.__version__) < version.parse("1.17.0"):
print("Skip importing since onnxruntime-gpu version < 1.17.0.")
return
for model_name, model_obj in self.models.items():
onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name)
if not os.path.exists(onnx_import_path):
print(f"{onnx_import_path} not existed. Skip importing.")
continue
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if os.path.exists(onnx_opt_path):
print(f"{onnx_opt_path} existed. Skip importing.")
continue
if model_name == "vae" and self.pipeline_info.is_xl():
print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.")
continue
model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True))
if model_name in ["clip", "clip2"]:
hidden_states_per_layer = []
for output in model.graph().output:
if output.name.startswith("hidden_states."):
hidden_states_per_layer.append(output.name)
if hidden_states_per_layer:
kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip]
model.rename_graph_output(kept_hidden_states, "hidden_states")
model.rename_graph_output(
"last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings"
)
model.prune_graph(
["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"]
)
if model_name == "clip2":
model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)
model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
elif model_name in ["unet", "unetxl"]:
model.rename_graph_output("out_sample", "latent")
model.save_model_to_file(onnx_opt_path, use_external_data_format=True)
del model
continue
def build_engines(
self,
engine_dir: str,
framework_model_dir: str,
onnx_dir: str,
tmp_dir: Optional[str] = None,
onnx_opset_version: int = 17,
device_id: int = 0,
save_fp32_intermediate_model: bool = False,
import_engine_dir: Optional[str] = None,
max_cuda_graphs: int = 1,
):
self.torch_device = torch.device("cuda", device_id)
self.load_models(framework_model_dir)
if not os.path.isdir(engine_dir):
os.makedirs(engine_dir)
if not os.path.isdir(onnx_dir):
os.makedirs(onnx_dir)
# Add default configuration if missing
if self.pipeline_info.is_xl():
self.configure_xl(onnx_opset_version)
for model_name in self.models:
if model_name not in self.model_config:
self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph)
# Import Engine
if import_engine_dir:
if self.pipeline_info.is_xl():
self.import_diffusers_engine(import_engine_dir, engine_dir)
else:
print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}")
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name in self.models:
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
print("----")
logger.info("Exporting model: %s", onnx_path)
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
model = model.to(torch.float32)
with torch.inference_mode():
# For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
# Export model with sample of batch size 1, image size 512 x 512
inputs = model_obj.get_sample_input(1, 512, 512)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=self.model_config[model_name].onnx_opset_version,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
logger.info("Found cached model: %s", onnx_path)
# Generate fp32 optimized model.
# If the final target is an fp16 model, we also save the fp32 optimized model so that it is easy to tune
# the fp16 conversion. That can save a lot of development time.
use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
onnx_fp32_path = onnx_path
if use_fp32_intermediate:
onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32")
if not os.path.exists(onnx_fp32_path):
print("------")
logger.info("Generating optimized model: %s", onnx_fp32_path)
model_obj.optimize_ort(
onnx_path,
onnx_fp32_path,
to_fp16=False,
fp32_op_list=self.model_config[model_name].force_fp32_ops,
optimize_by_ort=self.model_config[model_name].optimize_by_ort,
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False),
)
else:
logger.info("Found cached optimized model: %s", onnx_fp32_path)
# Generate the final optimized model.
if not os.path.exists(onnx_opt_path):
print("------")
logger.info("Generating optimized model: %s", onnx_opt_path)
# When there is an fp32 intermediate optimized model, this step just converts the model from fp32 to fp16.
optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort
model_obj.optimize_ort(
onnx_fp32_path,
onnx_opt_path,
to_fp16=self.model_config[model_name].fp16,
fp32_op_list=self.model_config[model_name].force_fp32_ops,
optimize_by_ort=optimize_by_ort,
optimize_by_fusion=not use_fp32_intermediate,
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False),
)
else:
logger.info("Found cached optimized model: %s", onnx_opt_path)
self.enable_torch_spda()
built_engines = {}
for model_name in self.models:
if model_name == "vae" and self.vae_torch_fallback:
continue
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
use_cuda_graph = self.model_config[model_name].use_cuda_graph
engine = OrtCudaEngine(
onnx_opt_path,
device_id=device_id,
enable_cuda_graph=use_cuda_graph,
disable_optimization=False,
max_cuda_graphs=max_cuda_graphs,
)
logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
built_engines[model_name] = engine
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict)
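# A minimal end-to-end sketch (illustrative only, not part of the original file): it
# assumes a PipelineInfo instance and local working directories; the feed_dict contents
# for run_engine depend on the model and are only hinted at in the comment below.
def _example_build_ort_cuda(pipeline_info: PipelineInfo):
    builder = OrtCudaEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=True)
    builder.build_engines(
        engine_dir="./engines/ORT_CUDA",
        framework_model_dir="./torch_model",
        onnx_dir="./onnx",
        onnx_opset_version=17,
        device_id=0,
    )
    builder.load_resources(image_height=512, image_width=512, batch_size=1)
    # Inference then goes through run_engine, e.g.
    # builder.run_engine("unet", {"sample": ..., "timestep": ..., "encoder_hidden_states": ...})
    return builder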

View File

@ -0,0 +1,288 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import logging
import os
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from packaging import version
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession
logger = logging.getLogger(__name__)
class OrtTensorrtEngine(CudaSession):
def __init__(
self,
engine_path,
device_id,
onnx_path,
fp16,
input_profile,
workspace_size,
enable_cuda_graph,
timing_cache_path=None,
):
self.engine_path = engine_path
self.ort_trt_provider_options = self.get_tensorrt_provider_options(
input_profile,
workspace_size,
fp16,
device_id,
enable_cuda_graph,
timing_cache_path=timing_cache_path,
)
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
logger.info("creating TRT EP session for %s", onnx_path)
ort_session = ort.InferenceSession(
onnx_path,
session_options,
providers=[
("TensorrtExecutionProvider", self.ort_trt_provider_options),
],
)
logger.info("created TRT EP session for %s", onnx_path)
device = torch.device("cuda", device_id)
super().__init__(ort_session, device, enable_cuda_graph)
def get_tensorrt_provider_options(
self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None
):
trt_ep_options = {
"device_id": device_id,
"trt_fp16_enable": fp16,
"trt_engine_cache_enable": True,
"trt_timing_cache_enable": True,
"trt_detailed_build_log": True,
"trt_engine_cache_path": self.engine_path,
}
if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None:
trt_ep_options["trt_timing_cache_path"] = timing_cache_path
if enable_cuda_graph:
trt_ep_options["trt_cuda_graph_enable"] = True
if workspace_size > 0:
trt_ep_options["trt_max_workspace_size"] = workspace_size
if input_profile:
min_shapes = []
max_shapes = []
opt_shapes = []
for name, profile in input_profile.items():
assert isinstance(profile, list) and len(profile) == 3
min_shape = profile[0]
opt_shape = profile[1]
max_shape = profile[2]
assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape)
min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape]))
opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape]))
max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape]))
trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes)
trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes)
trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes)
logger.info("trt_ep_options=%s", trt_ep_options)
return trt_ep_options
def allocate_buffers(self, shape_dict, device):
super().allocate_buffers(shape_dict)
class OrtTensorrtEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.ORT_TRT,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
def has_engine_file(self, engine_path):
if os.path.isdir(engine_path):
children = os.scandir(engine_path)
for entry in children:
if entry.is_file() and entry.name.endswith(".engine"):
return True
return False
def get_work_space_size(self, model_name, max_workspace_size):
gibibyte = 2**30
workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size
if workspace_size == 0:
_, free_mem, _ = cudart.cudaMemGetInfo()
# The following logic is adapted from the TensorRT demo diffusion.
if free_mem > 6 * gibibyte:
workspace_size = free_mem - 4 * gibibyte
return workspace_size
def build_engines(
self,
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_image_height,
opt_image_width,
opt_batch_size=1,
static_batch=False,
static_image_shape=True,
max_workspace_size=0,
device_id=0,
timing_cache=None,
):
self.torch_device = torch.device("cuda", device_id)
self.load_models(framework_model_dir)
if not os.path.isdir(engine_dir):
os.makedirs(engine_dir)
if not os.path.isdir(onnx_dir):
os.makedirs(onnx_dir)
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name, model_obj in self.models.items():
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not self.has_engine_file(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not self.has_engine_file(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
logger.info(f"Exporting model: {onnx_path}")
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
with torch.inference_mode(), torch.autocast("cuda"):
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=onnx_opset,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
logger.info("Found cached model: %s", onnx_path)
# Optimize onnx
if not os.path.exists(onnx_opt_path):
logger.info("Generating optimizing model: %s", onnx_opt_path)
model_obj.optimize_trt(onnx_path, onnx_opt_path)
else:
logger.info("Found cached optimized model: %s", onnx_opt_path)
self.enable_torch_spda()
built_engines = {}
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not self.has_engine_file(engine_path):
logger.info(
"Building TensorRT engine for %s from %s to %s. It can take a while to complete...",
model_name,
onnx_opt_path,
engine_path,
)
else:
logger.info("Reuse cached TensorRT engine in directory %s", engine_path)
input_profile = model_obj.get_input_profile(
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch=static_batch,
static_image_shape=static_image_shape,
)
engine = OrtTensorrtEngine(
engine_path,
device_id,
onnx_opt_path,
fp16=True,
input_profile=input_profile,
workspace_size=self.get_work_space_size(model_name, max_workspace_size),
enable_cuda_graph=self.use_cuda_graph,
timing_cache_path=timing_cache,
)
built_engines[model_name] = engine
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict)
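# A small sketch (illustrative only) of the shape-string format assembled in
# get_tensorrt_provider_options above: each input maps to [min, opt, max] shapes and is
# serialized as comma-separated "name:AxBxCxD" entries for the TensorRT EP.
def _example_trt_profile_strings():
    input_profile = {"sample": [[1, 4, 64, 64], [2, 4, 64, 64], [4, 4, 64, 64]]}
    min_shapes = ",".join(f"{name}:" + "x".join(str(x) for x in dims[0]) for name, dims in input_profile.items())
    # min_shapes == "sample:1x4x64x64"; the opt and max strings are built the same way.
    return min_shapes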

View File

@ -0,0 +1,395 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import gc
import os
import pathlib
from collections import OrderedDict
import numpy as np
import tensorrt as trt
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
CreateConfig,
ModifyNetworkOutputs,
Profile,
engine_from_bytes,
engine_from_network,
network_from_onnx_path,
save_engine,
)
# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
}
def _cuda_assert(cuda_ret):
err = cuda_ret[0]
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(
f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
)
if len(cuda_ret) > 1:
return cuda_ret[1]
return None
class TensorrtEngine:
def __init__(
self,
engine_path,
):
self.engine_path = engine_path
self.engine = None
self.context = None
self.buffers = OrderedDict()
self.tensors = OrderedDict()
self.cuda_graph_instance = None
def __del__(self):
del self.engine
del self.context
del self.buffers
del self.tensors
def build(
self,
onnx_path,
fp16,
input_profile=None,
enable_all_tactics=False,
timing_cache=None,
update_output_names=None,
):
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
if input_profile:
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
config_kwargs = {}
if not enable_all_tactics:
config_kwargs["tactic_sources"] = []
network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
if update_output_names:
print(f"Updating network outputs to {update_output_names}")
network = ModifyNetworkOutputs(network, update_output_names)
engine = engine_from_network(
network,
config=CreateConfig(
fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
def load(self):
print(f"Loading TensorRT engine: {self.engine_path}")
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
def activate(self, reuse_device_memory=None):
if reuse_device_memory:
self.context = self.engine.create_execution_context_without_device_memory()
self.context.device_memory = reuse_device_memory
else:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for idx in range(self.engine.num_io_tensors):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[binding] = tensor
def infer(self, feed_dict, stream, use_cuda_graph=False):
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
if use_cuda_graph:
if self.cuda_graph_instance is not None:
_cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
_cuda_assert(cudart.cudaStreamSynchronize(stream))
else:
# do inference before CUDA graph capture
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
# capture cuda graph
_cuda_assert(
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
)
self.context.execute_async_v3(stream)
self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream))
from cuda import nvrtc
result, major, minor = nvrtc.nvrtcVersion()
assert result == nvrtc.nvrtcResult(0)
if major < 12:
self.cuda_graph_instance = _cuda_assert(
cudart.cudaGraphInstantiate(self.graph, b"", 0)
) # cuda < 12
else:
self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12
else:
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
return self.tensors
class TensorrtEngineBuilder(EngineBuilder):
"""
Helper class to hide the details of the TensorRT engine from the pipeline.
"""
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the TensorRT Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.TRT,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.stream = None
self.shared_device_memory = None
def load_resources(self, image_height, image_width, batch_size):
super().load_resources(image_height, image_width, batch_size)
self.stream = _cuda_assert(cudart.cudaStreamCreate())
def teardown(self):
super().teardown()
if self.shared_device_memory:
cudart.cudaFree(self.shared_device_memory)
cudart.cudaStreamDestroy(self.stream)
del self.stream
def load_engines(
self,
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch=False,
static_shape=True,
enable_all_tactics=False,
timing_cache=None,
):
"""
Build and load engines for TensorRT accelerated inference.
Export ONNX models first, if applicable.
Args:
engine_dir (str):
Directory to write the TensorRT engines.
framework_model_dir (str):
Directory to write the framework model ckpt.
onnx_dir (str):
Directory to write the ONNX models.
onnx_opset (int):
ONNX opset version to export the models.
opt_batch_size (int):
Batch size to optimize for during engine building.
opt_image_height (int):
Image height to optimize for during engine building. Must be a multiple of 8.
opt_image_width (int):
Image width to optimize for during engine building. Must be a multiple of 8.
static_batch (bool):
Build engine only for specified opt_batch_size.
static_shape (bool):
Build engine only for specified opt_image_height & opt_image_width. Default = True.
enable_all_tactics (bool):
Enable all tactic sources during TensorRT engine builds.
timing_cache (str):
Path to the timing cache to accelerate build or None
"""
# Create directory
for directory in [engine_dir, onnx_dir]:
if not os.path.exists(directory):
print(f"[I] Create directory: {directory}")
pathlib.Path(directory).mkdir(parents=True)
self.load_models(framework_model_dir)
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name, model_obj in self.models.items():
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not os.path.exists(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not os.path.exists(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
print(f"Exporting model: {onnx_path}")
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
with torch.inference_mode(), torch.autocast("cuda"):
inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=onnx_opset,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
print(f"Found cached model: {onnx_path}")
# Optimize onnx
if not os.path.exists(onnx_opt_path):
print(f"Generating optimizing model: {onnx_opt_path}")
model_obj.optimize_trt(onnx_path, onnx_opt_path)
else:
print(f"Found cached optimized model: {onnx_opt_path} ")
self.enable_torch_spda()
# Build TensorRT engines
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
engine = TensorrtEngine(engine_path)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(engine.engine_path):
engine.build(
onnx_opt_path,
fp16=True,
input_profile=model_obj.get_input_profile(
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch,
static_shape,
),
enable_all_tactics=enable_all_tactics,
timing_cache=timing_cache,
update_output_names=None,
)
self.engines[model_name] = engine
# Load TensorRT engines
for model_name in self.models:
if model_name == "vae" and self.vae_torch_fallback:
continue
self.engines[model_name].load()
def max_device_memory(self):
max_device_memory = 0
for engine in self.engines.values():
max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
return max_device_memory
def activate_engines(self, shared_device_memory=None):
if shared_device_memory is None:
max_device_memory = self.max_device_memory()
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
self.shared_device_memory = shared_device_memory
# Load and activate TensorRT engines
for engine in self.engines.values():
engine.activate(reuse_device_memory=self.shared_device_memory)
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
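# A minimal sketch (illustrative only, not part of the original file) of the intended
# call order for this TensorRT backend. It assumes a PipelineInfo instance and that the
# directories either exist or will be populated by the ONNX export step.
def _example_build_tensorrt(pipeline_info: PipelineInfo):
    builder = TensorrtEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=True)
    builder.load_engines(
        engine_dir="./engines/TRT",
        framework_model_dir="./torch_model",
        onnx_dir="./onnx",
        onnx_opset=17,
        opt_batch_size=1,
        opt_image_height=512,
        opt_image_width=512,
    )
    builder.activate_engines()  # allocates shared device memory sized by max_device_memory()
    builder.load_resources(image_height=512, image_width=512, batch_size=1)
    return builder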

View File

@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
logger = logging.getLogger(__name__)
class TorchEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the PyTorch Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.TORCH,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.compile_config = {}
if use_cuda_graph:
self.compile_config = {
"clip": {"mode": "reduce-overhead", "dynamic": False},
"clip2": {"mode": "reduce-overhead", "dynamic": False},
"unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
"unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
"vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
}
def build_engines(
self,
framework_model_dir: str,
):
import torch
self.torch_device = torch.device("cuda", torch.cuda.current_device())
self.load_models(framework_model_dir)
pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
built_engines = {}
for model_name, model_obj in self.models.items():
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
model = model.to(device=self.torch_device, dtype=torch.float32)
else:
model = model.to(device=self.torch_device, dtype=torch.float16)
if model_name in self.compile_config:
compile_config = self.compile_config[model_name]
if model_name in ["unet", "unetxl"]:
model.to(memory_format=torch.channels_last)
engine = torch.compile(model, **compile_config)
built_engines[model_name] = engine
else: # eager mode
built_engines[model_name] = model
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
if model_name in ["unet", "unetxl"]:
if "controlnet_images" in feed_dict:
return {"latent": self.engines[model_name](**feed_dict)}
if model_name == "unetxl":
added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
return {
"latent": self.engines[model_name](
feed_dict["sample"],
feed_dict["timestep"],
feed_dict["encoder_hidden_states"],
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
}
return {
"latent": self.engines[model_name](
feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
)[0]
}
if model_name in ["vae_encoder"]:
return {"latent": self.engines[model_name](feed_dict["images"])}
raise RuntimeError(f"Shall not reach here: {model_name}")
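# A minimal sketch (illustrative only): with use_cuda_graph=True the builder above wraps
# each sub-model in torch.compile with "reduce-overhead" mode, which is what enables CUDA
# graph capture for the PyTorch backend. A standalone equivalent for a UNet module:
def _example_torch_compile(unet):
    import torch

    unet = unet.to(memory_format=torch.channels_last)
    return torch.compile(unet, mode="reduce-overhead", fullgraph=True, dynamic=False)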

View File

@ -0,0 +1,350 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
#
# Before running this script, follow README.md to set up the Python environment and convert the stable diffusion checkpoint
# to float32 onnx models.
#
# For example, if the float32 ONNX pipeline is saved to the ./sd-v1-5 directory, you can optimize and convert it to float16
# like the following:
# python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
#
# Note that the optimizations are carried out for the CUDA Execution Provider first; other EPs may not support
# the fused operators. Users can disable operator fusion manually as a workaround.
import argparse
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional
import __init__  # noqa: F401. Workaround to run this script directly
import coloredlogs
import onnx
from fusion_options import FusionOptions
from onnx_model_clip import ClipOnnxModel
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel
from optimizer import optimize_by_onnxruntime, optimize_model
from packaging import version
import onnxruntime
logger = logging.getLogger(__name__)
def has_external_data(onnx_model_path):
original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
for initializer in original_model.graph.initializer:
if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
return True
return False
def _optimize_sd_pipeline(
source_dir: Path,
target_dir: Path,
use_external_data_format: Optional[bool],
float16: bool,
force_fp32_ops: List[str],
enable_runtime_optimization: bool,
args,
):
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
Args:
source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
use_external_data_format (Optional[bool]): use external data format.
float16 (bool): use half precision
force_fp32_ops (List[str]): operators that are forced to run in float32.
enable_runtime_optimization (bool): run graph optimization using ONNX Runtime.
Raises:
RuntimeError: input onnx model does not exist
RuntimeError: output onnx model path already exists
"""
model_type_mapping = {
"unet": "unet",
"vae_encoder": "vae",
"vae_decoder": "vae",
"text_encoder": "clip",
"text_encoder_2": "clip",
"safety_checker": "unet",
}
model_type_class_mapping = {
"unet": UnetOnnxModel,
"vae": VaeOnnxModel,
"clip": ClipOnnxModel,
}
force_fp32_operators = {
"unet": [],
"vae_encoder": [],
"vae_decoder": [],
"text_encoder": [],
"text_encoder_2": [],
"safety_checker": [],
}
is_xl = (source_dir / "text_encoder_2").exists()
if force_fp32_ops:
for fp32_operator in force_fp32_ops:
parts = fp32_operator.split(":")
if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
force_fp32_operators[parts[0]].append(parts[1])
else:
raise ValueError(
f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
)
for name, model_type in model_type_mapping.items():
onnx_model_path = source_dir / name / "model.onnx"
if not os.path.exists(onnx_model_path):
if name != "safety_checker":
logger.info("input onnx model does not exist: %s", onnx_model_path)
# some models are optional, so we do not raise an error here.
continue
# Prepare output directory
optimized_model_path = target_dir / name / "model.onnx"
output_dir = optimized_model_path.parent
output_dir.mkdir(parents=True, exist_ok=True)
if use_external_data_format is None:
use_external_data_format = has_external_data(onnx_model_path)
# Run graph fusion before fp16 conversion; otherwise those patterns cannot be fused later.
logger.info(f"Optimize {onnx_model_path}...")
args.model_type = model_type
fusion_options = FusionOptions.parse(args)
if model_type in ["unet"]:
# Some optimizations (packed QKV and BiasAdd) are not available in v1.14 or older versions.
has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
m = optimize_model(
str(onnx_model_path),
model_type=model_type,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
use_gpu=True,
provider=args.provider,
)
if float16:
# For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
if is_xl and name == "vae_decoder":
logger.info("Skip converting %s to float16 to avoid NaN", name)
else:
logger.info("Convert %s to float16 ...", name)
m.convert_float_to_float16(
keep_io_types=False,
op_block_list=force_fp32_operators[name],
)
if enable_runtime_optimization:
# Use this step to see the final graph that is executed by ONNX Runtime.
with tempfile.TemporaryDirectory() as tmp_dir:
# Save to a temporary file so that we can load it with Onnx Runtime.
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
tmp_model_path = Path(tmp_dir) / "model.onnx"
m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
optimize_by_onnxruntime(
str(tmp_model_path),
use_gpu=True,
provider=args.provider,
optimized_model_path=str(ort_optimized_model_path),
save_as_external_data=use_external_data_format,
)
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
m = model_type_class_mapping[model_type](model)
m.get_operator_statistics()
m.get_fused_operator_statistics()
m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
logger.info("%s is optimized", name)
logger.info("*" * 20)
def _copy_extra_directory(source_dir: Path, target_dir: Path):
"""Copy extra directory that does not have onnx model
Args:
source_dir (Path): source directory
target_dir (Path): target directory
Raises:
RuntimeError: source path does not exist
"""
extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"]
for name in extra_dirs:
source_path = source_dir / name
if not os.path.exists(source_path):
continue
target_path = target_dir / name
shutil.copytree(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
extra_files = ["model_index.json"]
for name in extra_files:
source_path = source_dir / name
if not os.path.exists(source_path):
raise RuntimeError(f"source path does not exist: {source_path}")
target_path = target_dir / name
shutil.copyfile(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
# Some directories are optional
onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"]
for onnx_model_dir in onnx_model_dirs:
source_path = source_dir / onnx_model_dir / "config.json"
target_path = target_dir / onnx_model_dir / "config.json"
if source_path.exists():
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
def optimize_stable_diffusion_pipeline(
input_dir: str,
output_dir: str,
overwrite: bool,
use_external_data_format: Optional[bool],
float16: bool,
enable_runtime_optimization: bool,
args,
):
if os.path.exists(output_dir):
if overwrite:
shutil.rmtree(output_dir, ignore_errors=True)
else:
raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.")
source_dir = Path(input_dir)
target_dir = Path(output_dir)
target_dir.mkdir(parents=True, exist_ok=True)
_copy_extra_directory(source_dir, target_dir)
_optimize_sd_pipeline(
source_dir,
target_dir,
use_external_data_format,
float16,
args.force_fp32_ops,
enable_runtime_optimization,
args,
)
def parse_arguments(argv: Optional[List[str]] = None):
"""Parse arguments
Returns:
Namespace: arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input",
required=True,
type=str,
help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
)
parser.add_argument(
"-o",
"--output",
required=True,
type=str,
help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
)
parser.add_argument(
"--float16",
required=False,
action="store_true",
help="Output models of half or mixed precision.",
)
parser.set_defaults(float16=False)
parser.add_argument(
"--force_fp32_ops",
required=False,
nargs="+",
type=str,
help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
)
parser.add_argument(
"--inspect",
required=False,
action="store_true",
help="Save the optimized graph from Onnx Runtime. "
"This option has no impact on inference performance except it might reduce session creation time.",
)
parser.set_defaults(inspect=False)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite exists files.",
)
parser.set_defaults(overwrite=False)
parser.add_argument(
"-e",
"--use_external_data_format",
required=False,
action="store_true",
help="Onnx model larger than 2GB need to use external data format. "
"If specified, save each onnx model to two files: one for onnx graph, another for weights. "
"If not specified, use same format as original model by default. ",
)
parser.set_defaults(use_external_data_format=None)
parser.add_argument(
"--provider",
required=False,
type=str,
default=None,
help="Execution provider to use.",
)
FusionOptions.add_arguments(parser)
args = parser.parse_args(argv)
return args
def main(argv: Optional[List[str]] = None):
args = parse_arguments(argv)
logger.info("Arguments: %s", str(args))
optimize_stable_diffusion_pipeline(
args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
)
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
main()
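# A minimal programmatic sketch (illustrative only): main() accepts an argv list, so the
# optimizer can also be driven from another script; the paths below are placeholders.
def _example_optimize_pipeline():
    main(["-i", "./sd-v1-5", "-o", "./sd-v1-5-fp16", "--float16"])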

View File

@ -0,0 +1,136 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
ONNX Model Optimizer for Stable Diffusion
"""
import gc
import logging
import os
import shutil
import tempfile
from pathlib import Path
import onnx
from packaging import version
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
logger = logging.getLogger(__name__)
class OrtStableDiffusionOptimizer:
def __init__(self, model_type: str):
assert model_type in ["vae", "unet", "clip"]
self.model_type = model_type
self.model_type_class_mapping = {
"unet": UnetOnnxModel,
"vae": VaeOnnxModel,
"clip": ClipOnnxModel,
}
def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir):
# Save to a temporary file so that we can load it with Onnx Runtime.
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
tmp_model_path = Path(tmp_dir) / "model.onnx"
onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
del onnx_model
gc.collect()
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
optimize_by_onnxruntime(
str(tmp_model_path),
use_gpu=True,
optimized_model_path=str(ort_optimized_model_path),
save_as_external_data=use_external_data_format,
external_data_filename="optimized.onnx_data",
)
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
return self.model_type_class_mapping[self.model_type](model)
def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None):
# Use this step to see the final graph that is executed by ONNX Runtime.
if tmp_dir is None:
with tempfile.TemporaryDirectory() as temp_dir:
return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir)
else:
os.makedirs(tmp_dir, exist_ok=True)
model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir)
shutil.rmtree(tmp_dir)
return model
def optimize(
self,
input_fp32_onnx_path,
optimized_onnx_path,
float16=True,
keep_io_types=False,
fp32_op_list=None,
keep_outputs=None,
optimize_by_ort=True,
optimize_by_fusion=True,
final_target_float16=True,
tmp_dir=None,
):
"""Optimize onnx model using ONNX Runtime transformers optimizer"""
logger.info(f"Optimize {input_fp32_onnx_path}...")
if optimize_by_fusion:
fusion_options = FusionOptions(self.model_type)
# float16=False with final_target_float16=True is allowed, for using fp32 as an intermediate optimization step.
# For the rare fp32 use case, we disable packed kv/qkv since there is no fp32 TRT fused attention kernel.
if self.model_type in ["unet"] and not final_target_float16:
fusion_options.enable_packed_kv = False
fusion_options.enable_packed_qkv = False
m = optimize_model(
input_fp32_onnx_path,
model_type=self.model_type,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
use_gpu=True,
)
else:
model = onnx.load_model(input_fp32_onnx_path, load_external_data=True)
m = self.model_type_class_mapping[self.model_type](model)
if keep_outputs:
m.prune_graph(outputs=keep_outputs)
model_size = m.model.ByteSize()
# The model size might be negative (integer overflow) on Windows.
use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF
        # Note that ORT < 1.16 cannot save a model larger than 2GB.
        # This step is optional since it has no impact on inference latency.
        # The optimized model is not portable: it can only run with the same execution provider (CUDA EP in this case).
        # Once the model has been optimized by onnxruntime, graph optimization can be disabled in SessionOptions
        # to save session creation time. Another benefit is being able to inspect the final graph for development purposes.
from onnxruntime import __version__ as ort_version
if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir)
if float16:
logger.info("Convert to float16 ...")
m.convert_float_to_float16(
keep_io_types=keep_io_types,
op_block_list=fp32_op_list,
)
m.get_operator_statistics()
m.get_fused_operator_statistics()
m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)

View File

@ -0,0 +1,831 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import os
import pathlib
import random
import time
from typing import Any, Dict, List, Optional
import numpy as np
import nvtx
import torch
from cuda import cudart
from diffusion_models import PipelineInfo, get_tokenizer
from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
from engine_builder import EngineType
from engine_builder_ort_cuda import OrtCudaEngineBuilder
from engine_builder_ort_trt import OrtTensorrtEngineBuilder
from engine_builder_tensorrt import TensorrtEngineBuilder
from engine_builder_torch import TorchEngineBuilder
from PIL import Image
class StableDiffusionPipeline:
"""
    Stable Diffusion pipeline supporting the TensorRT, ONNX Runtime (CUDA or TensorRT EP), and PyTorch backends.
"""
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
scheduler="DDIM",
device="cuda",
output_dir=".",
verbose=False,
nvtx_profile=False,
use_cuda_graph=False,
framework_model_dir="pytorch_model",
engine_type: EngineType = EngineType.ORT_CUDA,
):
"""
Initializes the Diffusion pipeline.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
scheduler (str):
The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM].
device (str):
PyTorch device to run inference. Default: 'cuda'
output_dir (str):
Output directory for log files and image artifacts
verbose (bool):
Enable verbose logging.
nvtx_profile (bool):
Insert NVTX profiling markers.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
            framework_model_dir (str):
                Cache directory for framework checkpoints.
            engine_type (EngineType):
                Backend engine type, e.g. ORT_TRT or TRT.
"""
self.pipeline_info = pipeline_info
self.version = pipeline_info.version
self.vae_scaling_factor = pipeline_info.vae_scaling_factor()
self.max_batch_size = max_batch_size
self.framework_model_dir = framework_model_dir
self.output_dir = output_dir
for directory in [self.framework_model_dir, self.output_dir]:
if not os.path.exists(directory):
print(f"[I] Create directory: {directory}")
pathlib.Path(directory).mkdir(parents=True)
self.device = device
self.torch_device = torch.device(device, torch.cuda.current_device())
self.verbose = verbose
self.nvtx_profile = nvtx_profile
self.use_cuda_graph = use_cuda_graph
self.tokenizer = None
self.tokenizer2 = None
self.generator = torch.Generator(device="cuda")
self.actual_steps = None
self.current_scheduler = None
self.set_scheduler(scheduler)
# backend engine
self.engine_type = engine_type
if engine_type == EngineType.TRT:
self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.ORT_TRT:
self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.ORT_CUDA:
self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.TORCH:
self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
else:
raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")
# Load text tokenizer
if not self.pipeline_info.is_xl_refiner():
self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer")
if self.pipeline_info.is_xl():
self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2")
self.control_image_processor = None
if self.pipeline_info.is_xl() and self.pipeline_info.controlnet:
from diffusers.image_processor import VaeImageProcessor
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
)
# Create CUDA events
self.events = {}
for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]:
for marker in ["start", "stop"]:
self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
self.markers = {}
def is_backend_tensorrt(self):
return self.engine_type == EngineType.TRT
def set_scheduler(self, scheduler: str):
if scheduler == self.current_scheduler:
return
# Scheduler options
sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
if self.version in ("2.0", "2.1"):
sched_opts["prediction_type"] = "v_prediction"
else:
sched_opts["prediction_type"] = "epsilon"
if scheduler == "DDIM":
self.scheduler = DDIMScheduler(device=self.device, **sched_opts)
elif scheduler == "EulerA":
self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts)
elif scheduler == "UniPC":
self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts)
elif scheduler == "LCM":
self.scheduler = LCMScheduler(device=self.device, **sched_opts)
else:
raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM")
self.current_scheduler = scheduler
self.denoising_steps = None
def set_denoising_steps(self, denoising_steps: int):
if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)):
self.scheduler.set_timesteps(denoising_steps)
self.scheduler.configure()
self.denoising_steps = denoising_steps
def load_resources(self, image_height, image_width, batch_size):
        # If the engine is built with static input shapes, call this only once after the engine is built.
        # Otherwise, it needs to be called before every inference run.
self.backend.load_resources(image_height, image_width, batch_size)
def set_random_seed(self, seed):
if isinstance(seed, int):
self.generator.manual_seed(seed)
else:
self.generator.seed()
def get_current_seed(self):
return self.generator.initial_seed()
def teardown(self):
for e in self.events.values():
cudart.cudaEventDestroy(e)
if self.backend:
self.backend.teardown()
def run_engine(self, model_name, feed_dict):
return self.backend.run_engine(model_name, feed_dict)
def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
latents_dtype = torch.float16
latents_shape = (batch_size, unet_channels, latent_height, latent_width)
latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
# Scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def initialize_timesteps(self, timesteps, strength):
"""Initialize timesteps for refiner."""
self.scheduler.set_timesteps(timesteps)
offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
init_timestep = int(timesteps * strength) + offset
init_timestep = min(init_timestep, timesteps)
t_start = max(timesteps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
return timesteps, t_start
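    # Worked example (editor's note): with timesteps=30, strength=0.3 and a zero
    # steps_offset, init_timestep = int(30 * 0.3) = 9 and t_start = 30 - 9 = 21,
    # so only the last 9 of the 30 scheduler timesteps are returned for denoising.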
def initialize_refiner(self, batch_size, image, strength):
"""Add noise to a reference image."""
# Initialize timesteps
timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength)
latent_timestep = timesteps[:1].repeat(batch_size)
# Pre-process input image
image = self.preprocess_images(batch_size, (image,))[0]
# VAE encode init image
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.encode_image(image)
# Add noise to latents using timesteps
noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator)
latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep)
return timesteps, t_start, latents
def _get_add_time_ids(
self,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype,
requires_aesthetics_score,
):
if requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
return add_time_ids, add_neg_time_ids
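    # Example (editor's note): for SDXL base (requires_aesthetics_score=False) at
    # 1024x1024 with crops_coords_top_left=(0, 0), the concatenation
    # original_size + crops_coords_top_left + target_size yields the six values
    # (1024, 1024, 0, 0, 1024, 1024), so add_time_ids has shape (1, 6).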
def start_profile(self, name, color="blue"):
if self.nvtx_profile:
self.markers[name] = nvtx.start_range(message=name, color=color)
event_name = name + "-start"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
def stop_profile(self, name):
event_name = name + "-stop"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
if self.nvtx_profile:
nvtx.end_range(self.markers[name])
def preprocess_images(self, batch_size, images=()):
self.start_profile("preprocess", color="pink")
init_images = []
for i in images:
image = i.to(self.device)
if image.shape[0] != batch_size:
image = image.repeat(batch_size, 1, 1, 1)
init_images.append(image)
self.stop_profile("preprocess")
return tuple(init_images)
def preprocess_controlnet_images(
self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024
):
"""
Process a list of PIL.Image.Image as control images, and return a torch tensor.
"""
if images is None:
return None
self.start_profile("preprocess", color="pink")
if not self.pipeline_info.is_xl():
images = [
torch.from_numpy(
(np.array(image.convert("RGB")).astype(np.float32) / 255.0)[..., None].transpose(3, 2, 0, 1)
)
.to(device=self.device, dtype=torch.float16)
.repeat_interleave(batch_size, dim=0)
for image in images
]
else:
images = [
self.control_image_processor.preprocess(image, height=height, width=width)
.to(device=self.device, dtype=torch.float16)
.repeat_interleave(batch_size, dim=0)
for image in images
]
if do_classifier_free_guidance:
images = [torch.cat([i] * 2) for i in images]
images = torch.cat([image[None, ...] for image in images], dim=0)
self.stop_profile("preprocess")
return images
def encode_prompt(
self,
prompt,
negative_prompt,
encoder="clip",
tokenizer=None,
pooled_outputs=False,
output_hidden_states=False,
force_zeros_for_empty_prompt=False,
do_classifier_free_guidance=True,
dtype=torch.float16,
):
if tokenizer is None:
tokenizer = self.tokenizer
self.start_profile("clip", color="green")
def tokenize(prompt, output_hidden_states):
text_input_ids = (
tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
.input_ids.type(torch.int32)
.to(self.device)
)
hidden_states = None
if self.engine_type == EngineType.TORCH:
outputs = self.backend.engines[encoder](text_input_ids)
text_embeddings = outputs[0]
if output_hidden_states:
hidden_states = outputs["last_hidden_state"]
else:
outputs = self.run_engine(encoder, {"input_ids": text_input_ids})
text_embeddings = outputs["text_embeddings"]
if output_hidden_states:
hidden_states = outputs["hidden_states"]
return text_embeddings, hidden_states
# Tokenize prompt
text_embeddings, hidden_states = tokenize(prompt, output_hidden_states)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = text_embeddings.clone()
if hidden_states is not None:
hidden_states = hidden_states.clone()
# Note: negative prompt embedding is not needed for SD XL when guidance <= 1
if do_classifier_free_guidance:
# For SD XL base, handle force_zeros_for_empty_prompt
            is_empty_negative_prompt = all(not i for i in negative_prompt)
if force_zeros_for_empty_prompt and is_empty_negative_prompt:
uncond_embeddings = torch.zeros_like(text_embeddings)
if output_hidden_states:
uncond_hidden_states = torch.zeros_like(hidden_states)
else:
# Tokenize negative prompt
uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if output_hidden_states:
hidden_states = torch.cat([uncond_hidden_states, hidden_states])
self.stop_profile("clip")
if pooled_outputs:
# For text encoder in sdxl base
return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
if output_hidden_states:
# For text encoder 2 in sdxl base or refiner
return hidden_states.to(dtype=dtype)
# For text encoder in sd 1.5
return text_embeddings.to(dtype=dtype)
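    # Editor's note (assumption): for SD 1.5 with classifier-free guidance enabled,
    # this returns text_embeddings of shape (2 * batch_size, 77, 768); for SDXL the
    # hidden_states / pooled outputs are combined into add_kwargs by _infer() below.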
def denoise_latent(
self,
latents,
text_embeddings,
denoiser="unet",
timesteps=None,
step_offset=0,
guidance=7.5,
add_kwargs=None,
):
do_classifier_free_guidance = guidance > 1.0
self.start_profile("denoise", color="blue")
if not isinstance(timesteps, torch.Tensor):
timesteps = self.scheduler.timesteps
for step_index, timestep in enumerate(timesteps):
# Expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, step_offset + step_index, timestep
)
# Predict the noise residual
if self.nvtx_profile:
nvtx_unet = nvtx.start_range(message="unet", color="blue")
params = {
"sample": latent_model_input,
"timestep": timestep.to(latents.dtype),
"encoder_hidden_states": text_embeddings,
}
if add_kwargs:
params.update(add_kwargs)
noise_pred = self.run_engine(denoiser, params)["latent"]
if self.nvtx_profile:
nvtx.end_range(nvtx_unet)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
if type(self.scheduler) is UniPCMultistepScheduler:
latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
elif type(self.scheduler) is LCMScheduler:
latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0]
else:
latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep)
# The actual number of steps. It might be different from denoising_steps.
self.actual_steps = len(timesteps)
self.stop_profile("denoise")
return latents
def encode_image(self, image):
self.start_profile("vae_encoder", color="red")
init_latents = self.run_engine("vae_encoder", {"images": image})["latent"]
init_latents = self.vae_scaling_factor * init_latents
self.stop_profile("vae_encoder")
return init_latents
def decode_latent(self, latents):
self.start_profile("vae", color="red")
images = self.backend.vae_decode(latents)
self.stop_profile("vae")
return images
def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> Dict[str, Any]:
throughput = batch_size / (toc - tic)
latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1]
latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1]
latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1]
latency_vae_encoder = (
cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1]
if vae_enc
else None
)
latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None
latency = (toc - tic) * 1000.0
print("|----------------|--------------|")
print("| {:^14} | {:^12} |".format("Module", "Latency"))
print("|----------------|--------------|")
if vae_enc:
print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder))
print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip))
print(
"| {:^14} | {:>9.2f} ms |".format(
"UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps),
latency_unet,
)
)
print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae))
pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline"
if pil:
print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil))
print("|----------------|--------------|")
print(f"| {pipeline:^14} | {latency:>9.2f} ms |")
print("|----------------|--------------|")
print(f"Throughput: {throughput:.2f} image/s")
perf_data = {
"latency_clip": latency_clip,
"latency_unet": latency_unet,
"latency_vae": latency_vae,
"latency_pil": latency_pil,
"latency": latency,
"throughput": throughput,
}
if vae_enc:
perf_data["latency_vae_encoder"] = latency_vae_encoder
return perf_data
@staticmethod
def pt_to_pil(images):
images = (
((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
)
return [Image.fromarray(images[i]) for i in range(images.shape[0])]
@staticmethod
def pt_to_numpy(images: torch.FloatTensor):
"""
Convert a PyTorch tensor to a NumPy image.
"""
return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()
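    # Editor's note: both helpers map decoded images from [-1, 1] to display range;
    # pt_to_numpy returns float arrays in [0, 1], pt_to_pil returns uint8 PIL images.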
def metadata(self) -> Dict[str, Any]:
data = {
"actual_steps": self.actual_steps,
"seed": self.get_current_seed(),
"name": self.pipeline_info.name(),
"custom_vae": self.pipeline_info.custom_fp16_vae(),
"custom_unet": self.pipeline_info.custom_unet(),
}
if self.engine_type == EngineType.ORT_CUDA:
for engine_name, engine in self.backend.engines.items():
data.update(engine.metadata(engine_name))
return data
def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]):
session_id = str(random.randint(1000, 9999))
for i, image in enumerate(images):
seed = str(self.get_current_seed())
prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
print(f"Saving image {i+1} / {len(images)} to: {image_path}")
from PIL import PngImagePlugin
info = PngImagePlugin.PngInfo()
for k, v in metadata.items():
info.add_text(k, str(v))
info.add_text("prompt", prompt[i])
info.add_text("negative_prompt", negative_prompt[i])
image.save(image_path, "PNG", pnginfo=info)
def _infer(
self,
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=30,
guidance=5.0,
seed=None,
image=None,
strength=0.3,
controlnet_images=None,
controlnet_scales=None,
show_latency=False,
output_type="pil",
):
if show_latency:
torch.cuda.synchronize()
start_time = time.perf_counter()
assert len(prompt) == len(negative_prompt)
batch_size = len(prompt)
self.set_denoising_steps(denoising_steps)
self.set_random_seed(seed)
timesteps = None
step_offset = 0
with torch.inference_mode(), torch.autocast("cuda"):
if image is not None:
timesteps, step_offset, latents = self.initialize_refiner(
batch_size=batch_size,
image=image,
strength=strength,
)
else:
# Pre-initialize latents
latents = self.initialize_latents(
batch_size=batch_size,
unet_channels=4,
latent_height=(image_height // 8),
latent_width=(image_width // 8),
)
do_classifier_free_guidance = guidance > 1.0
if not self.pipeline_info.is_xl():
denoiser = "unet"
text_embeddings = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
add_kwargs = {}
else:
denoiser = "unetxl"
# Time embeddings
original_size = (image_height, image_width)
crops_coords_top_left = (0, 0)
target_size = (image_height, image_width)
aesthetic_score = 6.0
negative_aesthetic_score = 2.5
add_time_ids, add_negative_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype=latents.dtype,
requires_aesthetics_score=self.pipeline_info.is_xl_refiner(),
)
if do_classifier_free_guidance:
add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0)
add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1)
if self.pipeline_info.is_xl_refiner():
# CLIP text encoder 2
text_embeddings, pooled_embeddings2 = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip2",
tokenizer=self.tokenizer2,
pooled_outputs=True,
output_hidden_states=True,
dtype=latents.dtype,
)
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
else: # XL Base
# CLIP text encoder
text_embeddings = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip",
tokenizer=self.tokenizer,
output_hidden_states=True,
force_zeros_for_empty_prompt=True,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
# CLIP text encoder 2
text_embeddings2, pooled_embeddings2 = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip2",
tokenizer=self.tokenizer2,
pooled_outputs=True,
output_hidden_states=True,
force_zeros_for_empty_prompt=True,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
# Merged text embeddings
text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1)
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
if self.pipeline_info.controlnet:
controlnet_images = self.preprocess_controlnet_images(
latents.shape[0],
controlnet_images,
do_classifier_free_guidance=do_classifier_free_guidance,
height=image_height,
width=image_width,
)
add_kwargs.update(
{
"controlnet_images": controlnet_images,
"controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device),
}
)
# UNet denoiser
latents = self.denoise_latent(
latents,
text_embeddings,
timesteps=timesteps,
step_offset=step_offset,
denoiser=denoiser,
guidance=guidance,
add_kwargs=add_kwargs,
)
with torch.inference_mode():
# VAE decode latent
if output_type == "latent":
images = latents
else:
images = self.decode_latent(latents / self.vae_scaling_factor)
if output_type == "pil":
self.start_profile("pil", color="green")
images = self.pt_to_pil(images)
self.stop_profile("pil")
perf_data = None
if show_latency:
torch.cuda.synchronize()
end_time = time.perf_counter()
perf_data = self.print_summary(
start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil")
)
return images, perf_data
def run(
self,
prompt: List[str],
negative_prompt: List[str],
image_height: int,
image_width: int,
denoising_steps: int = 30,
guidance: float = 5.0,
seed: Optional[int] = None,
image: Optional[torch.Tensor] = None,
strength: float = 0.3,
controlnet_images: Optional[torch.Tensor] = None,
controlnet_scales: Optional[torch.Tensor] = None,
show_latency: bool = False,
output_type: str = "pil",
deterministic: bool = False,
):
"""
Run the diffusion pipeline.
Args:
prompt (List[str]):
The text prompt to guide image generation.
negative_prompt (List[str]):
The prompt not to guide the image generation.
image_height (int):
Height (in pixels) of the image to be generated. Must be a multiple of 8.
image_width (int):
Width (in pixels) of the image to be generated. Must be a multiple of 8.
            denoising_steps (int):
                Number of denoising steps. More steps usually lead to a higher quality image at the expense of slower inference.
            guidance (float):
                A higher guidance scale encourages the model to generate images more closely linked to the text prompt.
            seed (int):
                Seed for the random generator.
            image (Optional[torch.Tensor]):
                Reference image (used as the starting point for the refiner / img2img path).
            strength (float):
                Indicates the extent to transform the reference image, which is used as a starting point;
                more noise is added the higher the strength.
            controlnet_images (Optional[torch.Tensor]):
                Control images for ControlNet.
            controlnet_scales (Optional[torch.Tensor]):
                Conditioning scales for the ControlNet outputs.
            show_latency (bool):
                Whether to return latency data.
            output_type (str):
                It can be "latent", "pt" or "pil".
            deterministic (bool):
                Whether to enable deterministic algorithms in PyTorch.
"""
if deterministic:
torch.use_deterministic_algorithms(True)
if self.is_backend_tensorrt():
import tensorrt as trt
from trt_utilities import TRT_LOGGER
with trt.Runtime(TRT_LOGGER):
return self._infer(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=denoising_steps,
guidance=guidance,
seed=seed,
image=image,
strength=strength,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scales,
show_latency=show_latency,
output_type=output_type,
)
else:
return self._infer(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=denoising_steps,
guidance=guidance,
seed=seed,
image=image,
strength=strength,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scales,
show_latency=show_latency,
output_type=output_type,
)
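# --------------------------------------------------------------------------
# Editor's note: a minimal usage sketch, not part of the original module.
# The PipelineInfo argument, prompt, and image size are assumptions, and the
# backend engines are assumed to be already built/loaded (see the demo and
# engine_builder_* scripts in this directory for the complete flow).
# --------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline_info = PipelineInfo("1.5")  # assumed: Stable Diffusion 1.5
    pipeline = StableDiffusionPipeline(pipeline_info, scheduler="EulerA", output_dir="output")
    # Engines must be built or loaded through pipeline.backend before this point.
    pipeline.load_resources(image_height=512, image_width=512, batch_size=1)
    images, _ = pipeline.run(
        prompt=["a photo of an astronaut riding a horse on mars"],
        negative_prompt=[""],
        image_height=512,
        image_width=512,
        denoising_steps=30,
        guidance=7.5,
        seed=42,
    )
    pipeline.save_images(images, ["a photo of an astronaut riding a horse on mars"], [""], pipeline.metadata())
    pipeline.teardown()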

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
def init_trt_plugins():
# Register TensorRT plugins
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
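# Editor's note (assumption): init_trt_plugins() is typically called once, before
# building or deserializing TensorRT engines, so that plugin ops can be resolved:
#
#     from trt_utilities import init_trt_plugins
#     init_trt_plugins()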

View File

@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

Some files were not shown because too many files have changed in this diff.