I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)


@@ -0,0 +1,703 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time
import numpy as np
import onnx
import psutil
import torch
from benchmark_helper import measure_memory, setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import onnxruntime as ort
logger = logging.getLogger(__name__)
# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
def get_ort_model_inputs_len(args, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
return 0
if args.benchmark_type == "hf-ort":
try:
# New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
return len(model.inputs_names)
except Exception:
# Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
return len(model.decoder.input_names)
return len(model.get_inputs())
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None
# For past_present_share_buffer:
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
# Set max_seq_len to config value for other models
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
elif args.benchmark_type in {"hf-ort"}:
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
# Using split models in Optimum (e.g. created by Optimum export)
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
else:
# Using merged model in Optimum (e.g. created by convert_to_onnx export)
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
elif args.benchmark_type == "ort-convert-to-onnx":
# Microsoft export from convert_to_onnx
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
return init_inputs, iter_inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
# 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
# 3) Benchmark LLaMA-2 from local download of model
# 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
# 5) Benchmark LLaMA-2 from convert_to_onnx
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_cache=True,
cache_dir=args.cache_dir,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export or convert_to_onnx.py export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
decoder_file_name = None
decoder_with_past_file_name = None
for filename in os.listdir(args.hf_ort_dir_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
if "decoder_model" in filename or filename == "model.onnx":
decoder_file_name = filename
if "decoder_with_past_model" in filename:
decoder_with_past_file_name = filename
if "decoder_merged_model" in filename:
decoder_file_name = filename
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
args.hf_ort_dir_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
use_merged=(True if decoder_file_name == "model.onnx" else None),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path.format(args.rank),
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(inputs)
logger.info(outputs)
# Synchronization helpers: sync ORT runs via IO binding, PyTorch runs via CUDA, and no-op on CPU
def input_sync(*_):
    if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
        args.io_binding.synchronize_inputs()  # ORT synchronize
    elif args.device != "cpu" and torch.cuda.is_available():
        torch.cuda.synchronize()  # PyTorch synchronize

def output_sync(*_):
    if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
        args.io_binding.synchronize_outputs()  # ORT synchronize
    elif args.device != "cpu" and torch.cuda.is_available():
        torch.cuda.synchronize()  # PyTorch synchronize
for _ in warmup_range:
input_sync()
fn(inputs)
output_sync()
# Benchmark
total_time = 0
bench_range = (
range(args.num_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
input_sync()
start_time = time.time()
fn(inputs)
output_sync()
end_time = time.time()
total_time += end_time - start_time
# Print a newline after trange so the metrics below appear on their own lines instead of on the progress bar's line
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")
latency = total_time / args.num_runs
throughput = args.batch_size / latency
if args.rank == 0:
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
if args.rank == 0:
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
# Inference steps to measure
def get_logits(inputs):
# Inference pass without decoding
outputs = model(**inputs)
return outputs
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_pred_ids(inputs):
# # Inference pass with predicted token ids generation
# predicted_ids = model.generate(**inputs)
# return predicted_ids
# def gen_and_dec(inputs):
# # Inference pass with generation and decoding
# predicted_ids = get_pred_ids(inputs)
# transcription = []
# for bs in range(args.batch_size):
# for rs in range(args.num_return_sequences):
# transcription.append(
# args.tokenizer.batch_decode(
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
# )[0]
# )
# return transcription
generate_fn = get_logits
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
time_fn(args, generate_fn, init_inputs)
measure_fn(args, generate_fn, init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
time_fn(args, generate_fn, iter_inputs)
measure_fn(args, generate_fn, iter_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
return inputs, kv_cache_ortvalues
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}
if args.profile:
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args(rank=0):
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=[
"hf-pt-eager",
"hf-pt-compile",
"hf-ort",
"ort-msft",
"ort-convert-to-onnx",
],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
)
# Args for choosing the model
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-dir-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
default="32 64 128 256 512",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Args for decoding logic
parser.add_argument("--max-length", type=int, default=32)
parser.add_argument("--num-return-sequences", type=int, default=1)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--cache-dir",
type=str,
required=True,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
args.device = "cuda"
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert (
len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
), "Please provide only one (batch_size, sequence_length) combination for profiling"
return args
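# Illustrative invocation (paths and values are placeholders; the driver script later in this
# commit runs this file as the `models.llama.benchmark` module):
#
#   $ python -m models.llama.benchmark \
#       --benchmark-type ort-convert-to-onnx \
#       --ort-model-path ./llama2-7b-fp16/model.onnx \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --precision fp16 \
#       --batch-sizes "1 2" \
#       --sequence-lengths "32 64" \
#       --device cuda \
#       --cache-dir ./model_cache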
def main():
rank = get_rank()
world_size = get_size()
args = get_args(rank)
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
args.rank = rank
args.world_size = world_size
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
config = AutoConfig.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
setattr(args, "config", config) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
# Get model and model info
model = get_model(args)
ort_model_inputs_len = get_ort_model_inputs_len(args, model)
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
else:
setattr(args, "use_buffer_share", False) # noqa: B010
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
if args.rank == 0:
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
run_inference(args, init_inputs, iter_inputs, model)
if __name__ == "__main__":
main()


@@ -0,0 +1,492 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
import logging
import os
import subprocess
import torch
from benchmark_helper import setup_logger
from metrics import BenchmarkRecord
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--batch-sizes",
type=str,
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
type=str,
default="8 16 32 64 128 256 512",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=1000,
)
parser.add_argument(
"--hf-pt-eager",
default=False,
action="store_true",
help="Benchmark in PyTorch without `torch.compile`",
)
parser.add_argument(
"--hf-pt-compile",
default=False,
action="store_true",
help="Benchmark in PyTorch with `torch.compile`",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-msft-model-path",
type=str,
default="",
help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx",
)
parser.add_argument(
"--ort-convert-to-onnx-model-path",
type=str,
default="",
help="Path to ONNX model from convert_to_onnx",
)
parser.add_argument(
"--cache-dir",
type=str,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=10,
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument(
"--log-folder",
type=str,
default=None,
help="Path to folder to save logs and results",
)
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}_{args.precision}"
if not args.log_folder:
args.log_folder = log_folder_name
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
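# Illustrative invocation (the script file name and paths are assumptions, not taken from this diff):
#
#   $ python benchmark_all.py \
#       --hf-pt-eager \
#       --ort-convert-to-onnx-model-path ./llama2-7b-fp16/model.onnx \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --precision fp16 \
#       --device cuda \
#       --warmup-runs 5 \
#       --num-runs 100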
def process_log_file(device_id, log_file, base_results):
entries = []
batch_size, sequence_length, step = None, None, None
latency_s, latency_ms, throughput, memory = None, None, None, None
batch_pattern = "Batch Size: "
sequence_pattern = "Sequence Length: "
prompt_step_pattern = "to get past_key_values"
per_token_step_pattern = "with past_key_values"
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
memory_pattern = "peak="
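# The patterns above match the metric lines logged by the benchmark script's time_fn, e.g.
# (illustrative values):
#   Batch Size: 1
#   Sequence Length: 32
#   Latency: 0.0123 s
#   Throughput: 81.3 tps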
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
if batch_pattern in line:
batch_size = int(line[len(batch_pattern) :])
elif sequence_pattern in line:
sequence_length = int(line[len(sequence_pattern) :])
elif prompt_step_pattern in line:
step = "prompt"
elif per_token_step_pattern in line:
step = "per-token"
elif latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
latency_ms = latency_s * 1000
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
batch_size,
sequence_length,
step,
latency_s,
latency_ms,
throughput,
memory,
]
entries.append(entry)
return entries
def save_results(results, filename):
import pandas as pd
df = pd.DataFrame(
results,
columns=[
"Warmup Runs",
"Measured Runs",
"Model Name",
"Engine",
"Precision",
"Device",
"Batch Size",
"Sequence Length",
"Step",
"Latency (s)",
"Latency (ms)",
"Throughput (tps)",
"Memory (GB)",
],
)
# Set column types
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
df["Measured Runs"] = df["Measured Runs"].astype("int")
df["Batch Size"] = df["Batch Size"].astype("int")
df["Sequence Length"] = df["Sequence Length"].astype("int")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
df["Throughput (tps)"] = df["Throughput (tps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
# get package name and version
import pkg_resources
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(
[
f"{i.key}=={i.version}"
for i in installed_packages
if i.key in ["ort-nightly-gpu", "ort-nightly", "onnxruntime", "onnxruntime-gpu"]
]
)
ort_pkg_name = ""
ort_pkg_version = ""
if installed_packages_list:
ort_pkg_name = installed_packages_list[0].split("==")[0]
ort_pkg_version = installed_packages_list[0].split("==")[1]
# Save results to csv with standard format
records = []
for _, row in df.iterrows():
if row["Engine"] in ["optimum-ort", "onnxruntime"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version
)
elif row["Engine"] in ["pytorch-eager", "pytorch-compile"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "pytorch", row["Device"], torch.__name__, torch.__version__
)
else:
record = BenchmarkRecord(row["Model Name"], row["Precision"], row["Engine"], row["Device"], "", "")
record.config.warmup_runs = row["Warmup Runs"]
record.config.measured_runs = row["Measured Runs"]
record.config.batch_size = row["Batch Size"]
record.config.seq_length = row["Sequence Length"]
record.config.customized["measure_step"] = row["Step"]
record.config.customized["engine"] = row["Engine"]
record.metrics.customized["latency_s_mean"] = row["Latency (s)"]
record.metrics.latency_ms_mean = row["Latency (ms)"]
record.metrics.customized["throughput_tps"] = row["Throughput (tps)"]
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
records.append(record)
BenchmarkRecord.save_as_csv(filename, records)
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [args.warmup_runs, args.num_runs, args.model_name, engine, args.precision, args.device]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
all_results = []
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-eager",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-eager")
all_results.extend(results)
# Benchmark PyTorch with torch.compile
if args.hf_pt_compile:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-compile",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-compile")
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_dir_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-ort",
"--hf-ort-dir-path",
args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "optimum-ort")
all_results.extend(results)
# Benchmark Microsoft model in ONNX Runtime
if args.ort_msft_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-msft",
"--ort-model-path",
args.ort_msft_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark Microsoft model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "ort-msft")
all_results.extend(results)
# Benchmark convert_to_onnx model in ONNX Runtime
if args.ort_convert_to_onnx_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-convert-to-onnx",
"--ort-model-path",
args.ort_convert_to_onnx_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")
all_results.extend(results)
csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()


@@ -0,0 +1,606 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
#
# Prerequisites:
# 1) Install `huggingface-cli`:
#
# $ pip install huggingface_hub
#
# 2) Authenticate with Hugging Face's CLI:
#
# $ huggingface-cli login
#
# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
#
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes
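#
# Illustrative invocation (the script file name and paths are assumptions, not taken from this diff):
#
#   $ python benchmark_e2e.py \
#       --benchmark-type ort \
#       --onnx-model-path ./llama2-7b-fp16/model.onnx \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --prompts-file ./models/llama/prompts.json \
#       --precision fp16 \
#       --device cuda \
#       --batch-sizes "1 2" \
#       --prompt-lengths "16 64" \
#       --auth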
from __future__ import annotations
import argparse
import datetime
import gc
import itertools
import json
import logging
import os
import textwrap
import time
import numpy as np
import pandas as pd
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_model(args: argparse.Namespace):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
model = None
if args.onnx_precision == "int4" and args.device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
try:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
except Exception as e:
# If flash_attention_2 or sdpa is not supported for the model, loading raises an exception.
# Rather than stopping the process, fall back to the eager attention implementation.
logger.warning(f"Falling back to eager attention implementation: {e}")
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="eager",
).to(args.target_device)
model.eval()
if args.benchmark_type == "pt-compile":
model = torch.compile(model)
else:
sess_options = ort.SessionOptions()
ep = (
("CUDAExecutionProvider", {"device_id": args.device_id})
if args.device == "cuda"
else "CPUExecutionProvider"
)
model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])
return model
def run_inference(args, model, runs, inputs, outputs):
if args.benchmark_type == "pt-compile":
with torch.no_grad():
outputs = model(**inputs)
# Synchronize inputs
io_binding = None
if args.benchmark_type in {"pt-eager", "pt-compile"}:
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
io_binding.synchronize_inputs()
# Run inference
start = time.perf_counter()
for _ in range(runs):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
with torch.no_grad():
outputs = model(**inputs)
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end = time.perf_counter()
avg = (end - start) / runs
return avg, outputs
def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
clear_cache()
inputs, outputs = get_initial_inputs_and_outputs(
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
)
_, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
return inputs, outputs
def clear_cache():
gc.collect()
torch.cuda.empty_cache()
def save_results(results, filename, gen_length):
df = pd.DataFrame(
results,
columns=[
"Batch Size",
"Prompt Length",
"Prompt Processing Latency (ms)",
"Prompt Processing Throughput (tps)",
"Sampling Latency (ms)",
"Sampling Throughput (tps)",
"First Token Generated Latency (ms)",
"First Token Generated Throughput (tps)",
f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
f"Average Latency of First {gen_length} Tokens Generated (ms)",
f"Average Throughput of First {gen_length} Tokens Generated (tps)",
"Wall-Clock Latency (s)",
"Wall-Clock Throughput (tps)",
],
)
df.to_csv(filename, index=False)
logger.info(f"Results saved in {filename}!")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=["pt-eager", "pt-compile", "ort"],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=False,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a",
"--auth",
default=False,
action="store_true",
help="Use Hugging Face authentication token to access model",
)
parser.add_argument(
"-t",
"--trust",
default=False,
action="store_true",
help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
)
parser.add_argument(
"-c",
"--cache-dir",
type=str,
default=os.path.join(".", "model_cache"),
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
)
parser.add_argument(
"--hf-dir-path",
type=str,
default="",
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
)
parser.add_argument(
"-o",
"--onnx-model-path",
required=False,
help="Path to ONNX model",
)
parser.add_argument(
"-f",
"--prompts-file",
required=True,
default=os.path.join(".", "models", "llama", "prompts.json"),
help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
)
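# Illustrative prompts file contents (the prompts themselves are placeholders); each key is the
# tokenized length of its prompt:
#   {
#       "16": "<a prompt that tokenizes to 16 tokens>",
#       "64": "<a prompt that tokenizes to 64 tokens>"
#   }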
parser.add_argument(
"--use_buffer_share",
default=False,
action="store_true",
help="Use when GroupQueryAttention (GQA) is in ONNX model",
)
parser.add_argument(
"--anomaly-filtering",
default=False,
action="store_true",
help="Use this flag to filter anomaly accelerator times for tokens generated. \
This may give more accurate latency and throughput metrics for tokens generated. \
Wall-clock metrics are still reported with anomaly times though.",
)
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--prompt-lengths",
default="16 64 256 1024",
)
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"-g",
"--generation-length",
type=int,
default=256,
help="Number of new tokens to generate",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=100)
parser.add_argument("--seed", type=int, default=2)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "ort":
assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.prompt_lengths = args.prompt_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
setattr(args, "onnx_precision", args.precision) # noqa: B010
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
engine = "ort" if args.benchmark_type == "ort" else "pt"
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "torch_dtype", torch_dtype) # noqa: B010
setattr(args, "engine", engine) # noqa: B010
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.use_buffer_share = args.use_buffer_share and engine == "ort"
return args
def main():
args = get_args()
setup_logger(False)
logger.info(args.__dict__)
# Get prompts and prompt sizes
size_to_prompt = None
with open(args.prompts_file) as f:
size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})
# Get config, tokenizer, and model
config = AutoConfig.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
tokenizer = AutoTokenizer.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
model = get_model(args)
all_csv_metrics = []
for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
batch_size, prompt_length = int(batch_size), int(prompt_length) # noqa: PLW2901
logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
clear_cache()
max_length = prompt_length + args.generation_length
if prompt_length not in size_to_prompt:
raise NotImplementedError(
textwrap.dedent(
f"""
A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
If {prompt_length} < actual prompt's length, the benchmark E2E tool will repeat the first word in the prompt until {prompt_length} = actual prompt's length.
If {prompt_length} > actual prompt's length, the benchmark E2E tool will automatically trim the actual prompt's length so that {prompt_length} = actual prompt's length.
2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
"""
)
)
prompt = [size_to_prompt[prompt_length]] * batch_size
csv_metrics = [batch_size, prompt_length]
try:
# Measure prompt processing
logger.info("Measuring prompt processing...")
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
# Calculate prompt metrics
accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
logger.info(f"Average Throughput of Prompt Processing: {accelerator_prompt_thrpt} tps")
csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])
# Measure token generation
logger.info("Measuring token generation...")
clear_cache()
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
all_token_ids = inputs["input_ids"].clone()
current_length = all_token_ids.shape[-1]
num_heads = config.num_key_value_heads
head_size = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)
# 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
accelerator_times = []
sampling_times = [] # cost to sample after each model run
wall_clock_start_time = time.perf_counter()
while current_length <= max_length:
# Run inference
accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
accelerator_times.append(accelerator_time_latency_s)
# Sample with argmax (greedy search)
sampling_start_time = time.perf_counter()
if outputs["logits"].shape[1] > 1:
prompt_end_indices = inputs["attention_mask"].sum(1) - 1
idxs = (
prompt_end_indices.unsqueeze(dim=1)
.repeat(1, config.vocab_size)
.view(batch_size, 1, config.vocab_size)
)
next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
else:
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
sampling_end_time = time.perf_counter()
sampling_times.append(sampling_end_time - sampling_start_time)
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
current_length += 1
# Update inputs for next inference run
inputs["input_ids"] = tokens_to_add
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
)
if "position_ids" in inputs:
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
outputs["logits"].zero_()
# Update KV caches for next inference run
if args.engine == "pt":
# Update KV caches for PyTorch
inputs["past_key_values"] = outputs["past_key_values"]
elif not args.use_buffer_share:
# Update KV caches for ONNX Runtime if buffer sharing is not used
for i in range(config.num_hidden_layers):
inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]
new_sequence_length = inputs["attention_mask"].shape[1]
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
present_value = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
outputs.update(
{
f"present.{i}.key": present_key.contiguous(),
f"present.{i}.value": present_value.contiguous(),
}
)
wall_clock_end_time = time.perf_counter()
# Filter out any anomaly accelerator times (e.g. for `torch.compile`)
accelerator_times.pop(0) # Remove prompt processing time
if args.anomaly_filtering:
anomaly_threshold_factor = 10
min_time_s = min(accelerator_times)
orig_size = len(accelerator_times)
accelerator_times = list(
filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
)
new_size = len(accelerator_times)
logger.info(
f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
)
#######################################################
# Calculate sampling and first token generated metrics
#######################################################
# Calculate sampling metrics
avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
avg_sampling_latency_ms = avg_sampling_latency_s * 1000
avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")
# Calculate first token generated metrics
first_token_latency_s = accelerator_times[0]
first_token_latency_ms = first_token_latency_s * 1000
first_token_thrpt = batch_size * (1 / first_token_latency_s)
logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")
####################################################
# Calculate first `halfway` token generated metrics
####################################################
halfway = args.generation_length // 2
halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
halfway_token_latency_ms = halfway_token_latency_s * 1000
halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")
#########################################
# Calculate all tokens generated metrics
#########################################
all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
all_token_latency_ms = all_token_latency_s * 1000
all_token_thrpt = batch_size * (1 / all_token_latency_s)
logger.info(
f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
)
logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")
###############################
# Calculate wall clock metrics
###############################
wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
logger.info(f"Wall-Clock Throughput: {wall_clock_thrpt} tps")
# Add metrics to CSV
logger.info("Adding results to CSV")
csv_metrics.extend(
[
avg_sampling_latency_ms,
avg_sampling_thrpt,
first_token_latency_ms,
first_token_thrpt,
halfway_token_latency_ms,
halfway_token_thrpt,
all_token_latency_ms,
all_token_thrpt,
wall_clock_latency_s,
wall_clock_thrpt,
]
)
all_csv_metrics.append(csv_metrics)
except Exception as e:
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_csv_metrics, filename, args.generation_length)
if __name__ == "__main__":
main()

File diff suppressed because it is too large


@@ -0,0 +1,57 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import torch.distributed as dist
def init_dist():
if "LOCAL_RANK" in os.environ:
int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
else:
# don't need to do init for single process
pass
def _get_comm():
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
return comm
except ImportError:
return None
def get_rank():
comm = _get_comm()
return comm.Get_rank() if comm is not None else 0
def get_size():
comm = _get_comm()
return comm.Get_size() if comm is not None else 1
def barrier():
comm = _get_comm()
if comm is not None:
comm.Barrier()
def print_out(*args):
if get_rank() == 0:
print(*args)
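# Illustrative usage (launcher commands are assumptions, not taken from this diff): launch the
# benchmark with a distributed launcher so the expected environment variables are set, e.g.
#   $ torchrun --nproc_per_node 2 benchmark.py ...   (sets LOCAL_RANK, RANK, WORLD_SIZE)
#   $ mpirun -n 2 python benchmark.py ...            (sets the OMPI_COMM_WORLD_* variables)
# get_rank() and get_size() then report the MPI rank and world size, or 0 and 1 for a single process.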


@@ -0,0 +1,503 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer
from onnxruntime import InferenceSession, OrtValue
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if use_past_kv:
# Shape: (batch_size, 1)
position_ids = position_ids[:, -1].unsqueeze(-1)
# Shape: (batch_size, sequence_length)
return position_ids
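# Illustrative example (not part of the original code), for a left-padded batch:
#   attention_mask = [[0, 0, 1, 1, 1]]
#   cumsum(-1) - 1                 -> [[-1, -1, 0, 1, 2]]
#   masked_fill_(mask == 0, 1)     -> [[ 1,  1, 0, 1, 2]]  (padding positions set to 1)
#   with use_past_kv=True, only the last position is kept -> [[2]]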
# Inputs for first pass to get initial past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)
def get_sample_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
if not return_dict:
# For export
return (input_ids, attention_mask, position_ids)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return inputs
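
# Example (illustrative, assuming `config` is a loaded LLaMA AutoConfig): with engine="ort",
# the returned dict holds NumPy int64 arrays, each of shape (batch_size, seq_len).
#   inputs = get_sample_inputs(config, torch.device("cpu"), batch_size=2, seq_len=8, engine="ort", return_dict=True)
#   inputs["input_ids"].shape  # (2, 8); same for "attention_mask" and "position_ids"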
# Inputs for subsequent passes with past_key_values
# input_ids: (batch_size, 1)
# attention_mask: (batch_size, past_sequence_length + 1)
# position_ids: (batch_size, 1)
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
def get_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for all passes with past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, past_sequence_length + sequence_length)
# position_ids: (batch_size, sequence_length)
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
def get_merged_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_buffer_share: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
if use_buffer_share:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
config: AutoConfig,
batch_size: int,
past_seq_len: int,
seq_len: int,
max_seq_len: int,
use_fp16: bool,
use_buffer_share: bool,
split_kv: bool,
):
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
if not split_kv:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
"k_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"v_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"pos": np.array(past_seq_len, dtype=np.int64),
}
else:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
np.int32
),
"pos": np.array(past_seq_len, dtype=np.int64),
}
for i in range(config.num_hidden_layers):
ort_inputs.update(
{
f"k_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
f"v_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
}
)
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
num_heads = config.num_key_value_heads // world_size
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
return past_kv
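
# Example (illustrative, assuming `config` is a loaded LLaMA AutoConfig): the flattened dict
# uses the ONNX input naming convention expected by the exported models.
#   past_kv = get_past_kv_inputs(config, batch_size=1, past_seq_len=4, use_fp16=True)
#   flat = flatten_past_kv_inputs(past_kv)
#   sorted(flat)[:2]  # ['past_key_values.0.key', 'past_key_values.0.value']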
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
):
ort_inputs = {}
for k, v in pt_inputs.items():
if isinstance(v, np.ndarray):
ort_inputs[k] = v
elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v))
else:
ort_inputs[k] = v.detach().cpu().numpy()
# Reshape KV caches if using past-present-share-buffer
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
for k, v in ort_inputs.items():
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
            # Copy v of shape (batch_size, num_heads, past_seq_len, head_size)
            # into new_v of shape (batch_size, num_heads, max_seq_len, head_size)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = new_v
return ort_inputs
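
# Example (illustrative): only the sequence axis changes; the first past_seq_len entries
# keep the existing cache and the rest are zero-filled scratch space.
#   v = np.zeros((1, 32, 8, 128), dtype=np.float16)  # (batch_size, num_heads, past_seq_len, head_size)
#   out = enable_past_present_share_buffer({"past_key_values.0.key": v}, past_seq_len=8, max_seq_len=2048)
#   out["past_key_values.0.key"].shape  # (1, 32, 2048, 128)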
# Verify ONNX Runtime inputs with model
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(ort_inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
print(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
del ort_inputs[unnecessary_input]
return ort_inputs
# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
model: InferenceSession,
ort_inputs: dict,
device: str,
device_id: int,
use_buffer_share: bool,
kv_cache_ortvalues: dict,
):
io_binding = model.io_binding()
model_inputs = set(map(lambda i: i.name, model.get_inputs()))
for k, v in ort_inputs.items():
# Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
# GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
# but `position_ids` is used as a PyTorch model input
if k not in model_inputs:
continue
# Bind OrtValue inputs to device
if use_buffer_share and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
kv_cache_ortvalues[k] = v_device
else:
kv_cache_ortvalues[k].update_inplace(v)
io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
else:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
for output in model.get_outputs():
name = output.name
if use_buffer_share and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
else:
io_binding.bind_output(name, device_type=device, device_id=device_id)
return io_binding, kv_cache_ortvalues
# Add IO bindings for execution providers using PyTorch tensors
# Use when you need to run inference many times
def add_io_bindings_as_tensors(
model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
device = None
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float16": np.float16,
"torch.float32": np.float32,
}
# Bind inputs/outputs to IO binding
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
# Bind KV cache outputs to KV cache inputs
v = (
inputs[name.replace("present", "past_key_values")]
if use_buffer_share and "present" in name
else outputs[name]
)
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
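
# Illustrative flow (sketch): with `inputs` and `outputs` as produced by
# get_initial_inputs_and_outputs below, run with the binding and read results directly
# from the pre-allocated output tensors.
#   io_binding = add_io_bindings_as_tensors(session, inputs, outputs, use_fp16=True, use_buffer_share=True)
#   session.run_with_iobinding(io_binding)
#   logits = outputs["logits"]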
# Get actual inputs when using real data (instead of sample data) and initialize outputs
def get_initial_inputs_and_outputs(
config: AutoConfig,
tokenizer: AutoTokenizer,
requested_length: int,
prompt: list[str],
device: torch.device,
use_fp16: bool,
use_buffer_share: bool,
engine: str,
):
tokenizer.pad_token = tokenizer.eos_token
encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
torch_dtype = torch.float16 if use_fp16 else torch.float32
# input_ids: pad token id is 0
# attention_mask: pad token id is 0
# position_ids: pad token id is 1
input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Check if tokenized prompt length matches the requested prompt length
tokenized_length = input_ids.shape[-1]
if tokenized_length > requested_length:
# Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids = input_ids[:, :requested_length]
attention_mask = attention_mask[:, :requested_length]
position_ids = get_position_ids(attention_mask, use_past_kv=False)
elif tokenized_length < requested_length:
# Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
for _ in range(requested_length - tokenized_length):
input_ids = torch.hstack((input_ids_first_col, input_ids))
attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
tokenized_length = input_ids.shape[-1]
assert tokenized_length == requested_length
# Create inputs
inputs = {
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
}
if engine != "ort":
inputs["past_key_values"] = []
# Get shape of KV cache inputs
batch_size, sequence_length = input_ids.shape
max_sequence_length = config.max_position_embeddings
num_heads = config.num_key_value_heads
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
# Create KV cache inputs
for i in range(config.num_hidden_layers):
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
if engine == "ort":
inputs.update(
{
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous(),
}
)
else:
inputs["past_key_values"].append((past_key, past_value))
outputs = None
if engine == "ort":
# Create outputs
logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
outputs = {"logits": logits.contiguous()}
if not use_buffer_share:
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
present_value = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
outputs.update(
{f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
)
return inputs, outputs
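
# Putting it together (illustrative sketch; `config` is assumed to be a loaded LLaMA
# AutoConfig and `session` an onnxruntime.InferenceSession over the merged decoder model):
#   pt_inputs = get_merged_sample_with_past_kv_inputs(
#       config, torch.device("cpu"), batch_size=1, seq_len=32, past_seq_len=0,
#       max_seq_len=config.max_position_embeddings, use_fp16=True, return_dict=True,
#   )
#   ort_inputs = convert_inputs_for_ort(pt_inputs, use_buffer_share=False)
#   ort_inputs = verify_ort_inputs(session, ort_inputs)
#   logits = session.run(None, ort_inputs)[0]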

View File

@ -0,0 +1,309 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
import time
import numpy as np
import torch
from benchmark_helper import setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
convert_inputs_for_ort,
get_merged_sample_with_past_kv_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from llama_torch import setup_torch_model
from transformers import AutoConfig
import onnxruntime as ort
logger = logging.getLogger("")
def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
max_sequence_length = config.max_position_embeddings
return past_sequence_length, curr_sequence_length, max_sequence_length
def get_inputs(args: argparse.Namespace, config: AutoConfig):
# Dummy values for parity
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
if args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
seq_len=sequence_length,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
return_dict=True,
world_size=world_size,
)
elif args.use_past_kv:
inputs = get_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
world_size=world_size,
)
else:
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
return inputs
def verify_parity(
args: argparse.Namespace,
location: str,
use_auth_token: bool,
kv_cache_ortvalues: dict,
pytorch_model: None | torch.nn.Module = None,
config: None | AutoConfig = None,
):
    # On machines with less than 36 GB of GPU memory, unload the PyTorch LLaMA model promptly
    # so that its GPU memory is freed for ONNX Runtime.
py_model = pytorch_model
if py_model is None:
config, py_model = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
inputs = get_inputs(args, config)
# Run inference with PyTorch
if args.execution_provider != "cpu":
torch.cuda.synchronize()
start_time = time.time()
pt_outputs = py_model(**inputs).logits.detach().cpu().numpy()
if args.execution_provider != "cpu":
torch.cuda.synchronize()
end_time = time.time()
logger.info(f"PyTorch took {end_time - start_time} s")
if args.small_gpu and py_model is not None:
del py_model
torch.cuda.empty_cache()
# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
inputs = convert_inputs_for_ort(
inputs,
use_buffer_share=args.use_buffer_share,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
)
ep = f"{args.execution_provider.upper()}ExecutionProvider"
if ep == "CUDAExecutionProvider":
ep = (ep, {"device_id": args.rank})
ort_model = ort.InferenceSession(
args.onnx_model_path,
sess_options=ort.SessionOptions(),
providers=[ep],
)
inputs = verify_ort_inputs(ort_model, inputs)
# Add IO bindings for non-CPU execution providers
if args.execution_provider != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
ort_model,
ort_inputs=inputs,
device=args.execution_provider,
device_id=int(args.rank),
use_buffer_share=args.use_buffer_share,
kv_cache_ortvalues=kv_cache_ortvalues,
)
io_binding.synchronize_inputs()
start_time = time.time()
ort_model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end_time = time.time()
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
del ort_model
else:
start_time = time.time()
ort_outputs = ort_model.run(None, inputs)
end_time = time.time()
ort_outputs = ort_outputs[0] # Get logits
logger.info(f"ONNX Runtime took {end_time - start_time} s")
# Compare PyTorch and ONNX Runtime accuracy
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
return kv_cache_ortvalues
def get_args(argv: list[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=False,
help="Model name in Hugging Face",
)
parser.add_argument(
"-t",
"--torch_model_directory",
required=False,
default=os.path.join("."),
help="Path to folder containing PyTorch model and associated files if saved on disk",
)
parser.add_argument(
"-o",
"--onnx_model_path",
required=True,
default=os.path.join("."),
help="Path to ONNX model (with external data files saved in the same folder as the model)",
)
parser.add_argument(
"-ep",
"--execution_provider",
required=False,
default="cpu",
choices=["cpu", "cuda", "rocm"],
help="Execution provider to verify parity with",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print verbose logs",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"-p",
"--use_past_kv",
action="store_true",
help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
)
parser.set_defaults(use_past_kv=False)
parser.add_argument(
"-g",
"--use_buffer_share",
action="store_true",
help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
)
parser.set_defaults(use_buffer_share=False)
parser.add_argument(
"--merged",
action="store_true",
help="Use merged model (i.e. decoder_merged_model.onnx).",
)
parser.set_defaults(merged=False)
parser.add_argument(
"-fp",
"--precision",
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision of model",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./model_cache",
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
    # This argument is mainly for CI, since the CI machines have at most 24 GB of GPU memory.
parser.add_argument(
"--small_gpu",
action="store_true",
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32"
if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
else "fp16"
)
return args
def main(argv: list[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")
rank = get_rank()
# Load model and config
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.rank = rank
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
kv_cache_ortvalues = {}
if not args.merged:
verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
else:
config = llama = None
if not args.small_gpu:
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
# Verify prompt processing in merged model (decoder_model.onnx)
args.use_past_kv = False
kv_cache_ortvalues = verify_parity(
args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
)
# Verify token generation in merged model (decoder_with_past_model.onnx)
args.use_past_kv = True
verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
if __name__ == "__main__":
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)
main()
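
# Example invocation (illustrative; the script name and paths are placeholders, while the
# flags match the parser defined above):
#   python llama_parity.py \
#       -m meta-llama/Llama-2-7b-hf \
#       -o ./llama2-7b-fp16/decoder_merged_model_fp16.onnx \
#       -ep cuda -fp fp16 --merged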

View File

@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import torch
from dist_settings import barrier, get_rank, get_size
from transformers import AutoConfig, AutoModelForCausalLM
logger = logging.getLogger("")
def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
world_size = get_size()
logger.info(f"world_size: {world_size}")
rank = get_rank()
barrier()
if not os.path.exists(args.cache_dir):
os.makedirs(args.cache_dir, exist_ok=True)
for i in range(world_size):
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(
location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=auth,
trust_remote_code=auth,
config=l_config,
torch_dtype=torch_dtype,
cache_dir=args.cache_dir,
)
if world_size > 1:
llama.parallel_model()
if device:
llama.to(device)
llama.eval()
llama.requires_grad_(False)
barrier()
return l_config, llama
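
# Usage sketch (illustrative): `args` only needs a `cache_dir` attribute here, and `location`
# can be a Hugging Face model id or a local checkpoint directory (CUDA device assumed below).
#   import argparse
#   args = argparse.Namespace(cache_dir="./model_cache")
#   config, model = setup_torch_model(
#       args, "meta-llama/Llama-2-7b-hf", auth=True,
#       torch_dtype=torch.float16, device=torch.device("cuda:0"),
#   )
#   with torch.no_grad():
#       logits = model(input_ids=torch.tensor([[1, 2, 3]], device="cuda:0")).logits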

View File

@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session
from datasets import load_dataset
from llama_inputs import get_position_ids
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer
class QuantKVDataLoader:
def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
self.batch_size = 1
self.pad_max = args.pad_max
tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
dataset = load_dataset(args.smooth_quant_dataset, split="train")
dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
self.dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=self.collate_batch,
)
self.decoder_model = (
create_onnxruntime_session(
onnx_model_path,
args.execution_provider != "cpu", # use_gpu
provider=args.execution_provider,
verbose=args.verbose,
)
if onnx_model_path
else None
)
def collate_batch(self, batch):
input_ids_batched = []
attention_mask_batched = []
position_ids_batched = []
labels = []
for text in batch:
# Set inputs for model
input_ids = text["input_ids"]
attention_mask = torch.ones(len(input_ids))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
label = len(input_ids) - 1
# Pad input data because all model inputs must have same shape
pad_len = self.pad_max - input_ids.shape[0]
input_ids = pad(input_ids, (0, pad_len), value=1)
attention_mask = pad(attention_mask, (0, pad_len), value=0)
position_ids = pad(position_ids, (0, pad_len), value=0)
input_ids_batched.append(input_ids)
attention_mask_batched.append(attention_mask)
position_ids_batched.append(position_ids)
labels.append(label)
input_ids_batched = torch.vstack(input_ids_batched)
attention_mask_batched = torch.vstack(attention_mask_batched)
position_ids_batched = torch.vstack(position_ids_batched)
labels = torch.tensor(labels)
return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
def __iter__(self):
try:
for (input_ids, attention_mask, position_ids), labels in self.dataloader:
# Inputs for decoder_model.onnx
inputs = {
"input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
"attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
"position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
}
label = labels.detach().cpu().numpy()
if self.decoder_model is not None:
# Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
outputs = self.decoder_model.run(None, inputs)
                    for i in range((len(outputs) - 1) // 2):
inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
past_sequence_length = inputs["past_key_values.0.key"].shape[2]
inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
inputs["position_ids"] = (
get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
)
# Yield (inputs, label) tuple for Intel's Neural Compressor:
# https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
yield (inputs, label)
except StopIteration:
return
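
# Usage sketch (illustrative; the attribute values below are examples, and the argument parser
# that normally builds `args` lives in the quantization driver script, not in this file):
#   import argparse
#   args = argparse.Namespace(
#       pad_max=196,
#       original_model_name="meta-llama/Llama-2-7b-hf",
#       use_auth_token=True,
#       smooth_quant_dataset="NeelNanda/pile-10k",
#       execution_provider="cpu",
#       verbose=False,
#   )
#   dataloader = QuantKVDataLoader(args)  # calibration data for decoder_model.onnx
#   # dataloader = QuantKVDataLoader(args, onnx_model_path="decoder_model.onnx")  # for decoder_with_past_model.onnx
#   for inputs, label in dataloader:
#       pass  # feed (inputs, label) pairs to the quantization tool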