I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)


@@ -0,0 +1,610 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import ast
import datetime
import gc
import logging
import os
import sys
import time
import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import measure_memory, setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_inputs(args: argparse.Namespace):
if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
raise Exception("Unable to auto-detect inputs for provided model")
def load_via_ffmpeg():
audio = whisper.load_audio(args.audio_path)
audio = whisper.pad_or_trim(audio)
return audio
def load_via_numpy():
with open(args.audio_path, "rb") as f:
audio = np.asarray(list(f.read()), dtype=np.uint8)
audio = np.array([audio])
return audio
inputs = {
"max_length": args.max_length,
"min_length": args.min_length,
"num_beams": args.num_beams,
"num_return_sequences": args.num_return_sequences,
"length_penalty": args.length_penalty,
"repetition_penalty": args.repetition_penalty,
}
if args.benchmark_type == "ort":
# convert_to_onnx export or ONNX E2E solution created by Olive
for k, v in inputs.items():
inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
if args.has_decoder_input_ids:
inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
if args.has_logits_processor:
inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
if args.has_temperature:
inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
# Measure time taken to load audio file
logger.info(f"Load audio: {args.audio_path}")
load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731
time_fn(args, load_audio_fn, args.has_audio_stream)
audio_data = load_audio_fn(args.has_audio_stream)
if args.has_audio_stream:
# ONNX E2E solution created by Olive
inputs["audio_stream"] = audio_data
return inputs
# Measure time taken to get input features
logger.info("Feature extraction: ")
return_type = "np" if args.benchmark_type == "ort" else "pt"
processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731
[audio], return_tensors=return_type, sampling_rate=args.sampling_rate
).input_features
time_fn(args, processor_fn, audio_data)
input_features = processor_fn(audio_data)
if args.benchmark_type == "ort":
# convert_to_onnx export
inputs["input_features"] = input_features
return inputs
inputs["inputs"] = input_features.to(
dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
)
inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
inputs["early_stopping"] = True
inputs["use_cache"] = True
if args.decoder_input_ids:
inputs["forced_decoder_ids"] = args.decoder_input_ids
return inputs
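# Illustrative shape of the "ort" inputs assembled above (hypothetical values, not from a real run):
#   {"max_length": np.array([448], dtype=np.int32), "length_penalty": np.array([1.0], dtype=np.float32), ...}
# plus "input_features" (convert_to_onnx export) or "audio_stream" (Olive E2E export).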
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark Whisper from Hugging Face
# 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
# 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
start_time = time.time()
model = AutoModelForSpeechSeq2Seq.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_cache=True,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
sess_options.register_custom_ops_library(get_library_path())
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
if args.tune:
ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(0)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
start_time = time.time()
model = ORTModelForSpeechSeq2Seq.from_pretrained(
args.hf_ort_dir_path,
provider=provider,
provider_options=provider_options,
session_options=sess_options,
use_io_binding=True, # Avoid memory copy overhead
)
end_time = time.time()
if args.benchmark_type == "ort":
# convert_to_onnx.py export
logger.info(f"Loading model from {args.ort_model_path}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path,
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
torch_device = torch.device(args.target_device)
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type == "ort"
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(warmup_inputs)
logger.info(outputs)
for _ in warmup_range:
fn(warmup_inputs)
# Benchmark
if args.device != "cpu":
torch.cuda.synchronize(torch_device)
start_time = time.time()
bench_range = (
range(args.num_runs)
if args.benchmark_type == "ort"
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
fn(benchmark_inputs)
if args.device != "cpu":
torch.cuda.synchronize(torch_device)
end_time = time.time()
    # Print a newline after trange so that metrics appear on their own lines rather than on the same line as the progress bar
if args.benchmark_type != "ort":
logger.info("")
batch_size = 1
latency = (end_time - start_time) / args.num_runs
throughput = batch_size / latency
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} qps")
return
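# Worked example of the metrics above (hypothetical numbers, not from an actual run):
# with num_runs=10 and 5.0 s of measured wall time, latency = 5.0 / 10 = 0.5 s and,
# with batch_size=1, throughput = 1 / 0.5 = 2.0 qps.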
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, inputs, model):
# Inference steps to measure
def get_pred_ids(inputs):
# Inference pass with predicted token ids generation
predicted_ids = model.generate(**inputs)
return predicted_ids
def gen_and_dec(inputs):
# Inference pass with generation and decoding
predicted_ids = get_pred_ids(inputs)
transcription = []
for _ in range(args.num_return_sequences):
transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
return predicted_ids, transcription
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_logits(inputs):
# # Inference pass without decoding
# outputs = model(**inputs)
# return outputs
generate_fn = gen_and_dec
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
if args.benchmark_type == "hf-ort":
# Rename log files per model component and turn profiling off to stop appending to log
new_prefix = new_logname[: -len(".json")]
old_logname = model.encoder.session.end_profiling()
new_logname = new_prefix + "-encoder.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
old_logname = model.decoder.session.end_profiling()
new_logname = new_prefix + "-decoder.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
old_logname = model.decoder_with_past.session.end_profiling()
new_logname = new_prefix + "-decoder-with-past.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating PyTorch...")
time_fn(args, generate_fn, inputs)
predicted_ids, transcription = generate_fn(inputs)
logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
logger.info(f"Transcription: {transcription[0]}")
measure_fn(args, generate_fn, inputs)
def run_ort_inference(args, inputs, model):
def prepare_ort_inputs(inputs, warmup=False):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
logger.error(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
if warmup and args.tune:
inputs["min_length"] = inputs["max_length"]
        # Remove user-provided inputs that the model does not accept
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
del inputs[unnecessary_input]
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
return io_binding
return inputs
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
return io_binding
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
def handle_output(output):
if args.eos_token_id in output:
first_end = np.where(output == args.eos_token_id)[0][0]
return output[: first_end + 1]
return output
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
ort_inputs = prepare_ort_inputs(inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluation
logger.info("\nEvaluating ONNX Runtime...")
ort_evaluate_inputs = ort_inputs
if args.tune:
ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
time_fn(args, generate_fn, ort_evaluate_inputs)
ort_outputs = generate_fn(ort_inputs)
if args.device != "cpu":
ort_outputs = ort_outputs.copy_outputs_to_cpu()
ort_outputs = ort_outputs[0]
if args.has_audio_stream:
# ONNX E2E model from Olive produces transcribed output
logger.info(f"Transcription: {ort_outputs[0][0]}")
else:
# convert_to_onnx model produces generated ids
actual_output = handle_output(ort_outputs[0][0])
logger.info(f"Generated token length: {len(actual_output)} tokens")
transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
# print to stdout as the output for comparison
print(f"{transcription}")
measure_fn(args, generate_fn, ort_inputs)
def run_inference(args, inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, inputs, model)
elif args.benchmark_type == "ort":
run_ort_inference(args, inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
)
parser.add_argument(
"-p",
"--precision",
type=str,
required=True,
default="fp32",
choices=["int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-model-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Optional args:
parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
# Args for decoding logic
# Required args:
parser.add_argument("--max-length", type=int, default=448)
parser.add_argument("--min-length", type=int, default=0)
parser.add_argument("--num-beams", type=int, default=1)
parser.add_argument("--num-return-sequences", type=int, default=1)
parser.add_argument("--length-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
# Optional args for E2E solution:
parser.add_argument(
"--decoder-input-ids",
type=str,
default="[]",
help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
)
parser.add_argument(
"--logits-processor",
type=int,
default=1,
help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Temperature value for generation.",
)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--tune",
default=False,
action="store_true",
help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
args.monitor_type = args.device
# Set runtime properties
if "ort" in args.benchmark_type:
args.execution_provider = f"{args.device.upper()}ExecutionProvider"
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (
args.execution_provider,
{
"device_id": args.device_id,
"tunable_op_enable": 1,
"tunable_op_tuning_enable": 1 if args.tune else 0,
},
)
args.device = "cuda"
# Check that model paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type == "ort":
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
# Convert decoder_input_ids string to list of ids
# (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
return args
def main():
args = parse_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
config = WhisperConfig.from_pretrained(args.model_name)
processor = WhisperProcessor.from_pretrained(args.model_name)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "processor", processor) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
setattr(args, "has_audio_stream", False) # noqa: B010
setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
# Measure cost to transcribe audio
model = get_model(args)
if args.benchmark_type == "ort":
# Check for optional inputs that could have been added during export
ort_model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
args.has_audio_stream = "audio_stream" in ort_model_inputs
setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010
setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010
setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010
if args.decoder_input_ids == []:
args.decoder_input_ids = [config.decoder_start_token_id]
inputs = get_inputs(args)
run_inference(args, inputs, model)
if __name__ == "__main__":
main()
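
A minimal usage sketch for the script above, assuming a hypothetical ONNX model path and audio file; the flags mirror parse_args() and the module invocation matches the commands built by benchmark_all.py below.

import subprocess

cmd = [
    "python", "-m", "models.whisper.benchmark",
    "--benchmark-type", "ort",
    "--model-name", "openai/whisper-tiny",
    "--precision", "fp32",
    "--ort-model-path", "./onnx_models/whisper-tiny_beamsearch.onnx",  # hypothetical path
    "--audio-path", "./audio/sample.wav",  # hypothetical path
    "--device", "cpu",
]
subprocess.run(cmd, check=True)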


@@ -0,0 +1,532 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
import logging
import os
import subprocess
import librosa
import torch
from benchmark_helper import setup_logger
from metrics import BenchmarkRecord
from transformers import WhisperConfig, WhisperProcessor
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--audio-path",
type=str,
required=True,
help="Path to folder of audio files for E2E evaluation",
)
parser.add_argument(
"-l",
"--language",
default=None,
help="Language of audio file",
)
parser.add_argument(
"-t",
"--task",
default=None,
choices=["transcribe", "translate"],
help="Task to complete",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=10,
)
parser.add_argument(
"--hf-pt-eager",
default=False,
action="store_true",
help="Benchmark in PyTorch without `torch.compile`",
)
parser.add_argument(
"--hf-pt-compile",
default=False,
action="store_true",
help="Benchmark in PyTorch with `torch.compile`",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-model-path",
type=str,
help="Path to ONNX model for ORT benchmarking",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face (e.g. openai/whisper-large-v2)",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=5,
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument(
"--log-folder",
type=str,
default=None,
help="Path to folder to save logs and results",
)
parser.add_argument("--tune", default=False, action="store_true")
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}-{args.precision}"
if not args.log_folder:
args.log_folder = log_folder_name
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
def process_log_file(device_id, log_file, base_results):
entries = []
# Detect steps in speech pipeline
step = None
load_audio_pattern = "Load audio: "
feat_ext_pattern = "Feature extraction: "
pytorch_pattern = "Evaluating PyTorch..."
onnxruntime_pattern = "Evaluating ONNX Runtime..."
load_audio_latency_s, load_audio_throughput_s = None, None
feat_ext_latency_s, feat_ext_throughput_s = None, None
token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
throughput, memory = None, None
# Detect metrics
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
token_length_pattern = "Generated token length: "
memory_pattern = "peak="
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
# Get step in speech recognition pipeline
if load_audio_pattern in line:
step = "load-audio"
elif feat_ext_pattern in line:
step = "feature-extraction"
elif pytorch_pattern in line or onnxruntime_pattern in line:
step = "process"
# Check metrics
if latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
if step == "load-audio":
load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
step = None
if step == "feature-extraction":
feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
step = None
elif token_length_pattern in line:
token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
per_token_latency_s = latency_s / token_length
per_token_latency_ms = per_token_latency_s * 1000
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Calculate real-time factor (RTF):
# RTF = total latency / audio duration
total_latency = (
(load_audio_latency_s if load_audio_latency_s else 0)
+ (feat_ext_latency_s if feat_ext_latency_s else 0)
+ (latency_s if latency_s else 0)
)
audio_duration = base_results[-1]
rtf = (total_latency / audio_duration) if audio_duration else -1
logger.info(f"Total latency: {total_latency} s")
logger.info(f"Audio duration: {audio_duration} s")
logger.info(f"Real-time factor: {rtf}")
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
token_length,
load_audio_latency_s,
load_audio_throughput_s,
feat_ext_latency_s if feat_ext_latency_s else -1,
feat_ext_throughput_s if feat_ext_throughput_s else -1,
latency_s,
per_token_latency_ms,
throughput,
memory,
rtf,
]
entries.append(entry)
return entries
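# Worked RTF example (hypothetical numbers): a total pipeline latency of 0.8 s for a
# 20.0 s audio clip gives RTF = 0.8 / 20.0 = 0.04, i.e. about 25x faster than real time.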
def save_results(results, filename):
import pandas as pd
df = pd.DataFrame(
results,
columns=[
"Warmup Runs",
"Measured Runs",
"Model Name",
"Engine",
"Precision",
"Device",
"Audio File",
"Duration (s)",
"Token Length",
"Load Audio Latency (s)",
"Load Audio Throughput (qps)",
"Feature Extractor Latency (s)",
"Feature Extractor Throughput (qps)",
"Latency (s)",
"Per Token Latency (ms/token)",
"Throughput (qps)",
"Memory (GB)",
"Real Time Factor (RTF)",
],
)
# Set column types
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
df["Measured Runs"] = df["Measured Runs"].astype("int")
df["Duration (s)"] = df["Duration (s)"].astype("float")
df["Token Length"] = df["Token Length"].astype("int")
df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float")
df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float")
df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float")
df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float")
df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float")
# get package name and version
import pkg_resources
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(
[
f"{i.key}=={i.version}"
for i in installed_packages
if i.key in ["ort-nightly-gpu", "ort-nightly", "onnxruntime", "onnxruntime-gpu"]
]
)
ort_pkg_name = ""
ort_pkg_version = ""
if installed_packages_list:
ort_pkg_name = installed_packages_list[0].split("==")[0]
ort_pkg_version = installed_packages_list[0].split("==")[1]
# Save results to csv with standard format
records = []
for _, row in df.iterrows():
if row["Engine"] == "onnxruntime":
record = BenchmarkRecord(
row["Model Name"], row["Precision"], row["Engine"], row["Device"], ort_pkg_name, ort_pkg_version
)
else:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], row["Engine"], row["Device"], torch.__name__, torch.__version__
)
record.config.customized["audio_file"] = row["Audio File"]
record.config.warmup_runs = row["Warmup Runs"]
record.config.measured_runs = row["Measured Runs"]
record.metrics.customized["duration"] = row["Duration (s)"]
record.metrics.customized["token_length"] = row["Token Length"]
record.metrics.customized["load_audio_latency"] = row["Load Audio Latency (s)"]
record.metrics.customized["load_audio_throughput"] = row["Load Audio Throughput (qps)"]
record.metrics.customized["feature_extractor_latency_s"] = row["Feature Extractor Latency (s)"]
record.metrics.customized["feature_extractor_throughput_qps"] = row["Feature Extractor Throughput (qps)"]
record.metrics.customized["per_token_latency_ms"] = row["Per Token Latency (ms/token)"]
record.metrics.customized["rtf"] = row["Real Time Factor (RTF)"]
record.metrics.latency_ms_mean = row["Latency (s)"] * 1000
record.metrics.throughput_qps = row["Throughput (qps)"]
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
records.append(record)
BenchmarkRecord.save_as_csv(filename, records)
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine, audio_file, duration):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [
args.warmup_runs,
args.num_runs,
args.model_name,
engine,
args.precision,
args.device,
audio_file,
duration,
]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
config = WhisperConfig.from_pretrained(args.model_name)
processor = WhisperProcessor.from_pretrained(args.model_name)
# Calculate forced decoder input ids
hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
ort_forced_decoder_ids = [config.decoder_start_token_id] + list( # noqa: RUF005
map(lambda token_id: token_id[1], hf_forced_decoder_ids)
)
hf_decoder_input_ids_cmd = (
["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else []
)
ort_decoder_input_ids_cmd = (
["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
)
ort_tune_cmd = ["--tune"] if args.tune else []
all_results = []
for audio_file in os.listdir(args.audio_path):
audio_path = os.path.join(args.audio_path, audio_file)
try:
duration = librosa.get_duration(path=audio_path)
except Exception as e:
duration = -1
logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
logger.warning(
f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
)
logger.info(f"Testing {audio_path}...")
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-pt-eager",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration)
all_results.extend(results)
# Benchmark PyTorch with torch.compile
if args.hf_pt_compile:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-pt-compile",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration)
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_dir_path:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-ort",
"--hf-ort-dir-path",
args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration)
all_results.extend(results)
# Benchmark ONNX Runtime
if args.ort_model_path:
benchmark_cmd = (
[ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"ort",
"--ort-model-path",
args.ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
]
+ ort_decoder_input_ids_cmd
+ ort_tune_cmd
)
logger.info("Benchmark ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
all_results.extend(results)
csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()
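
A hedged sketch of reusing process_log_file() on an existing log, with hypothetical values; base_results mirrors the list built in benchmark() above (warmup runs, measured runs, model name, engine, precision, device, audio file, duration in seconds).

from benchmark_all import process_log_file  # assumes the file above is saved as benchmark_all.py

base_results = [5, 10, "openai/whisper-large-v2", "onnxruntime", "fp32", "cpu", "sample.wav", 20.0]
entries = process_log_file(
    device_id=0,
    log_file="./whisper-large-v2-fp32/onnxruntime_2024-10-30_12:00:00.log",  # hypothetical log path
    base_results=base_results,
)
print(entries)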


@@ -0,0 +1,536 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import copy
import logging
import os
import torch
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
from whisper_chain import chain_model
from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
from onnxruntime import quantization
logger = logging.getLogger("")
PROVIDERS = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
conversion_args = parser.add_argument_group("Conversion Process Args")
optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
quant_args = parser.add_argument_group("INT8 Quantization Args")
#################################
# Conversion options for Whisper
#################################
conversion_args.add_argument(
"-m",
"--model_name_or_path",
required=False,
default=PRETRAINED_WHISPER_MODELS[0],
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
)
conversion_args.add_argument(
"--model_impl",
required=False,
default="hf",
choices=["hf", "openai"],
type=str,
help="Select implementation for export of encoder and decoder subgraphs",
)
conversion_args.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
conversion_args.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory",
)
conversion_args.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
conversion_args.set_defaults(optimize_onnx=False)
conversion_args.add_argument(
"--use_gpu",
required=False,
action="store_true",
help="Use GPU for model inference",
)
conversion_args.set_defaults(use_gpu=False)
conversion_args.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
)
conversion_args.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int64 instead of int32 for input_ids and attention_mask.",
)
conversion_args.set_defaults(use_int64_inputs=False)
conversion_args.add_argument(
"--disable_auto_mixed_precision",
required=False,
action="store_true",
help="Use pure fp16 instead of mixed precision",
)
conversion_args.set_defaults(disable_auto_mixed_precision=False)
conversion_args.add_argument(
"-r",
"--provider",
required=False,
type=str,
default="cpu",
choices=list(PROVIDERS.keys()),
help="Provider to benchmark. Default is CPUExecutionProvider.",
)
conversion_args.add_argument(
"--verbose",
required=False,
action="store_true",
help="Enable verbose logging",
)
conversion_args.set_defaults(verbose=False)
conversion_args.add_argument(
"-e",
"--use_external_data_format",
required=False,
action="store_true",
help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
)
conversion_args.set_defaults(use_external_data_format=False)
conversion_args.add_argument(
"-w",
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX model",
)
conversion_args.set_defaults(overwrite=False)
conversion_args.add_argument(
"--separate_encoder_and_decoder_init",
required=False,
action="store_true",
help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
)
conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
conversion_args.add_argument(
"--no_beam_search_op",
required=False,
action="store_true",
help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
)
conversion_args.set_defaults(no_beam_search_op=False)
conversion_args.add_argument(
"--state_dict_path",
type=str,
default="",
help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
#############################################################
# Optional inputs for Whisper
# (listed below in the order that WhisperBeamSearch expects)
#############################################################
optional_inputs.add_argument(
"-v",
"--use_vocab_mask",
required=False,
action="store_true",
help="Use vocab_mask as an extra graph input to enable specific logits processing",
)
optional_inputs.set_defaults(use_vocab_mask=False)
optional_inputs.add_argument(
"-u",
"--use_prefix_vocab_mask",
required=False,
action="store_true",
help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
)
optional_inputs.set_defaults(use_prefix_vocab_mask=False)
optional_inputs.add_argument(
"-f",
"--use_forced_decoder_ids",
required=False,
action="store_true",
help="Use decoder_input_ids as an extra graph input to the beam search op",
)
optional_inputs.set_defaults(use_forced_decoder_ids=False)
optional_inputs.add_argument(
"-l",
"--use_logits_processor",
required=False,
action="store_true",
help="Use logits_processor as an extra graph input to enable specific logits processing",
)
    optional_inputs.set_defaults(use_logits_processor=False)
optional_inputs.add_argument(
"--collect_cross_qk",
required=False,
action="store_true",
help="Beam search model collect stacked cross QK.",
)
optional_inputs.set_defaults(collect_cross_qk=False)
optional_inputs.add_argument(
"--extra_decoding_ids",
required=False,
action="store_true",
help="Need extra starting decoding ids for some feature like cross qk. Default if false.",
)
optional_inputs.set_defaults(extra_decoding_ids=False)
optional_inputs.add_argument(
"-t",
"--use_temperature",
required=False,
action="store_true",
help="Use temperature as an extra graph input for the WhisperBeamSearch op",
)
optional_inputs.set_defaults(use_temperature=False)
optional_inputs.add_argument(
"--no_repeat_ngram_size",
type=int,
default=0,
help="default to 0",
)
#############################################################
# Optional outputs for Whisper
# (listed below in the order that WhisperBeamSearch expects)
#############################################################
optional_outputs.add_argument(
"--output_sequence_scores",
required=False,
action="store_true",
help="Beam search model output scores for each generated sequence.",
)
optional_outputs.set_defaults(output_sequence_scores=False)
optional_outputs.add_argument(
"--output_scores",
required=False,
action="store_true",
help="Beam search model output scores over vocab per generated token.",
)
optional_outputs.set_defaults(output_scores=False)
optional_outputs.add_argument(
"--output_cross_qk",
required=False,
action="store_true",
help="Beam search model output collected qk as output. Also hint collect_cross_qk",
)
optional_outputs.set_defaults(output_cross_qk=False)
optional_outputs.add_argument(
"--cross_qk_onnx_model",
required=False,
type=str,
default=None,
help="The model which consumes cross_qk outputs.",
)
optional_outputs.add_argument(
"--output_no_speech_probs",
required=False,
action="store_true",
help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
)
optional_outputs.set_defaults(output_no_speech_probs=False)
###################################
# Quantization options for Whisper
###################################
quant_args.add_argument(
"--quantize_embedding_layer",
required=False,
action="store_true",
help="Quantize MatMul, GEMM, and Gather.",
)
quant_args.set_defaults(quantize_embedding_layer=False)
quant_args.add_argument(
"--quantize_per_channel",
required=False,
action="store_true",
help="Quantize weights per each channel.",
)
quant_args.set_defaults(quantize_per_channel=False)
quant_args.add_argument(
"--quantize_reduce_range",
required=False,
action="store_true",
help="Quantize weights with 7 bits.",
)
quant_args.set_defaults(quantize_reduce_range=False)
args = parser.parse_args(argv)
args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
return args
def export_onnx_models(
model_name_or_path,
model_impl,
cache_dir,
output_dir,
use_gpu,
use_external_data_format,
optimize_onnx,
precision,
verbose,
use_forced_decoder_ids: bool = False,
merge_encoder_and_decoder_init: bool = True,
overwrite: bool = False,
disable_auto_mixed_precision: bool = False,
use_int32_inputs: bool = True,
quantize_embedding_layer: bool = False,
quantize_per_channel: bool = False,
quantize_reduce_range: bool = False,
state_dict_path: str = "",
provider: str = "cpu",
):
device = torch.device("cuda:0" if use_gpu else "cpu")
models = WhisperHelper.load_model(
model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_hidden_layers > 24):
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
print(f"========> Handling {name} model......")
model.to(device)
filename_suffix = "_" + name
onnx_path = WhisperHelper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix,
new_folder=False,
)
if overwrite or not os.path.exists(onnx_path):
logger.info(f"Exporting ONNX model to {onnx_path}")
            # We have to clone the model before exporting to ONNX; otherwise verify_onnx will report a large difference.
device_to_export = torch.device("cpu")
cloned_model = copy.deepcopy(model).to(device_to_export)
WhisperHelper.export_onnx(
cloned_model,
device_to_export,
onnx_path,
verbose,
use_external_data_format,
use_int32_inputs=use_int32_inputs,
)
else:
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
# Optimize ONNX graph. Note that we have not implemented graph optimization for Whisper yet.
if optimize_onnx or precision != Precision.FLOAT32:
output_path = WhisperHelper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix + "_" + str(precision),
new_folder=False,
)
if overwrite or not os.path.exists(output_path):
if optimize_onnx:
logger.info(f"Optimizing model to {output_path}")
WhisperHelper.optimize_onnx(
onnx_path,
output_path,
precision == Precision.FLOAT16,
config.encoder_attention_heads,
config.d_model,
use_external_data_format,
auto_mixed_precision=not disable_auto_mixed_precision,
use_gpu=use_gpu,
provider=provider,
)
onnx_path = output_path
if precision == Precision.INT8:
quantization.quantize_dynamic(
onnx_path,
output_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
),
use_external_data_format=use_external_data_format,
per_channel=quantize_per_channel,
reduce_range=quantize_reduce_range,
extra_options={"MatMulConstBOnly": True},
)
else:
logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
else:
output_path = onnx_path
ort_session = create_onnxruntime_session(
output_path,
use_gpu=use_gpu,
provider=provider,
)
assert ort_session is not None
output_paths.append(output_path)
return output_paths
def main(argv=None):
args = parse_arguments(argv)
setup_logger(args.verbose)
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
if args.optimize_onnx:
logger.warning("Applying graph optimization for Whisper...")
output_paths = export_onnx_models(
args.model_name_or_path,
args.model_impl,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_forced_decoder_ids,
not args.separate_encoder_and_decoder_init,
args.overwrite,
args.disable_auto_mixed_precision,
not args.use_int64_inputs,
args.quantize_embedding_layer,
args.quantize_per_channel,
args.quantize_reduce_range,
args.state_dict_path,
args.provider,
)
max_diff = 0
if not args.no_beam_search_op:
logger.info("Chaining model ... :")
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
output_dir,
args.model_name_or_path,
suffix="_beamsearch",
new_folder=False,
)
for path in output_paths:
if "encoder_decoder" in path:
args.encoder_path = path
elif "decoder" in path:
args.decoder_path = path
chain_model(args)
output_paths.append(args.beam_model_output_dir)
# Check chained model
ort_session = create_onnxruntime_session(
args.beam_model_output_dir,
use_gpu=args.use_gpu,
provider=args.provider,
)
device = torch.device("cuda:0" if args.use_gpu else "cpu")
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
# Verify batched decoding with prompts for whisper openai implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:
logger.info("PyTorch and ONNX Runtime results are close")
except Exception as e:
logger.warning(
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
)
# Remove extra ONNX models saved in output directory
for fle in os.listdir(output_dir):
if "_beamsearch" not in fle:
os.remove(os.path.join(output_dir, fle))
output_paths = [args.beam_model_output_dir]
logger.info(f"Done! Outputs: {output_paths}")
return max_diff
if __name__ == "__main__":
main()
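
A hedged sketch of driving the export programmatically via main(argv), with a placeholder model name and output directory; it assumes the file above is importable as convert_to_onnx.

from convert_to_onnx import main  # assumes the file above is saved as convert_to_onnx.py

max_diff = main(
    [
        "-m", "openai/whisper-tiny",
        "--output", "./onnx_models",  # hypothetical output directory
        "--use_external_data_format",
        "--optimize_onnx",
    ]
)
print(f"Max parity difference vs. PyTorch: {max_diff}")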


@@ -0,0 +1,326 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import onnx
from benchmark_helper import Precision
from convert_generation import (
get_shared_initializers,
update_decoder_subgraph_output_cross_attention,
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
)
from onnx import TensorProto, helper
from transformers import WhisperConfig, WhisperTokenizer
logger = logging.getLogger(__name__)
def verify_inputs(beam_inputs, graph_inputs):
# Verify that ONNX graph's inputs match beam search op's inputs
beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs))
assert len(graph_inputs) == len(beam_required_inputs)
for graph_input, beam_input in zip(graph_inputs, beam_required_inputs):
# Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix
assert graph_input.name in beam_input
def clean_list(arr, remove_all_strings=True):
if remove_all_strings:
# Remove all empty strings in list
return list(filter(lambda elm: elm != "", arr))
# Remove empty strings at end of list
while len(arr) > 0:
if arr[-1] == "":
arr.pop()
else:
break
return arr
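# Illustrative behavior of clean_list (hypothetical values, not part of the original file):
#   clean_list(["a", "", "b", ""], remove_all_strings=True)  -> ["a", "b"]
#   clean_list(["a", "", "b", ""], remove_all_strings=False) -> ["a", "", "b"]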
def chain_model(args):
# Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
encoder_model.graph.name = "encoderdecoderinit subgraph"
decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
decoder_model.graph.name = "decoder subgraph"
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
# Create inputs/outputs for WhisperBeamSearch op
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
beam_inputs = [
"input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
"max_length",
"min_length",
"num_beams",
"num_return_sequences",
"length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
"vocab_mask" if args.use_vocab_mask else "",
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
"", # attention mask
"decoder_input_ids" if args.use_forced_decoder_ids else "",
"logits_processor" if args.use_logits_processor else "",
"cross_qk_layer_head" if args.collect_cross_qk else "",
"extra_decoding_ids" if args.extra_decoding_ids else "",
temperature_name if args.use_temperature else "",
]
sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
beam_outputs = [
"sequences",
sequence_scores_name if args.output_sequence_scores else "",
scores_name if args.output_scores else "",
"cross_qk" if args.collect_cross_qk else "",
"no_speech_probs_beam" if args.output_no_speech_probs else "",
]
graph_nodes = []
if args.precision == Precision.FLOAT16:
input_features_cast_node = helper.make_node(
"Cast",
inputs=["input_features"],
outputs=["input_features_fp16"],
name="CastInputFeaturesToFp16",
to=TensorProto.FLOAT16,
)
len_pen_cast_node = helper.make_node(
"Cast",
inputs=["length_penalty"],
outputs=["length_penalty_fp16"],
name="CastLengthPenaltyToFp16",
to=TensorProto.FLOAT16,
)
rep_pen_cast_node = helper.make_node(
"Cast",
inputs=["repetition_penalty"],
outputs=["repetition_penalty_fp16"],
name="CastRepetitionPenaltyToFp16",
to=TensorProto.FLOAT16,
)
graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
if args.use_temperature:
temp_cast_node = helper.make_node(
"Cast",
inputs=["temperature"],
outputs=["temperature_fp16"],
name="temperature_to_fp16",
to=TensorProto.FLOAT16,
)
graph_nodes.append(temp_cast_node)
if args.output_sequence_scores:
output_sequence_scores_cast_node = helper.make_node(
"Cast",
inputs=["sequence_scores_fp16"],
outputs=["sequence_scores"],
name="CastOutputSequenceScoresToFp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(output_sequence_scores_cast_node)
if args.output_scores:
output_scores_cast_node = helper.make_node(
"Cast",
inputs=["scores_fp16"],
outputs=["scores"],
name="CastScoresToFp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(output_scores_cast_node)
# Create WhisperBeamSearch op
beam_search_attrs = [
helper.make_attribute("eos_token_id", config.eos_token_id),
helper.make_attribute("pad_token_id", config.pad_token_id),
helper.make_attribute(
"decoder_start_token_id", config.decoder_start_token_id
), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
(
helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
if args.output_no_speech_probs
else ""
),
helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", True),
helper.make_attribute("model_type", 2),
helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
]
node = helper.make_node(
"WhisperBeamSearch",
inputs=clean_list(beam_inputs, remove_all_strings=False),
outputs=clean_list(beam_outputs, remove_all_strings=False),
name="BeamSearch",
domain="com.microsoft",
)
node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
# Graph inputs
input_features = helper.make_tensor_value_info(
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
)
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
prefix_vocab_mask = helper.make_tensor_value_info(
"prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
)
decoder_input_ids = helper.make_tensor_value_info(
"decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
)
logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
extra_decoding_ids = helper.make_tensor_value_info(
"extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
)
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
graph_inputs = clean_list(
[
input_features,
max_length,
min_length,
num_beams,
num_return_sequences,
length_penalty,
repetition_penalty,
vocab_mask if args.use_vocab_mask else "",
prefix_vocab_mask if args.use_prefix_vocab_mask else "",
decoder_input_ids if args.use_forced_decoder_ids else "",
logits_processor if args.use_logits_processor else "",
cross_qk_layer_head if args.collect_cross_qk else "",
extra_decoding_ids if args.extra_decoding_ids else "",
temperature if args.use_temperature else "",
]
)
# Graph outputs
sequences = helper.make_tensor_value_info(
"sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
)
sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
cross_qk = helper.make_tensor_value_info(
"cross_qk",
TensorProto.FLOAT,
["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
)
no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
graph_outputs = clean_list(
[
sequences,
sequence_scores if args.output_sequence_scores else "",
scores if args.output_scores else "",
cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
no_speech_probs if args.output_no_speech_probs else "",
]
)
# Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
if hasattr(args, "use_gpu") and args.use_gpu:
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
else:
logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
if hasattr(args, "collect_cross_qk") and args.collect_cross_qk:
update_decoder_subgraph_output_cross_attention(decoder_model.graph)
# Initializers/opsets
# Delete shared data between decoder/encoder and move to larger graph initializers
initializers = get_shared_initializers(encoder_model, decoder_model)
node.attribute.extend(
[
helper.make_attribute("decoder", decoder_model.graph),
helper.make_attribute("encoder", encoder_model.graph),
]
)
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
graph_nodes.append(node)
if args.output_no_speech_probs:
prob_cast_node = helper.make_node(
"Cast",
inputs=["no_speech_probs_beam"],
outputs=["no_speech_probs"],
name="no_speech_probs_cast_to_fp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(prob_cast_node)
# Make graph with WhisperBeamSearch op
beam_graph = helper.make_graph(
graph_nodes,
name="WhisperBeamSearch Graph",
inputs=graph_inputs,
outputs=graph_outputs,
initializer=initializers,
)
beam_graph_input_names = [gi.name for gi in graph_inputs]
beam_graph_output_names = [go.name for go in graph_outputs]
if args.cross_qk_onnx_model:
post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
post_qk_graph = post_qk_model.graph
beam_graph.initializer.extend(post_qk_graph.initializer)
beam_graph.node.extend(post_qk_graph.node)
# If a tensor from cross_qk_onnx_model has the same name as a tensor in the beam search graph, they are treated as the same tensor.
# Users should keep this rule in mind when providing a cross_qk_onnx_model to append to the beam search node.
for pgi in post_qk_graph.input:
if (
(pgi.name not in beam_graph_input_names)
and (pgi.name not in beam_graph_output_names)
and (pgi.name != "cross_qk")
):
beam_graph.input.extend([pgi])
beam_graph.output.extend(post_qk_graph.output)
# Verify graph's inputs match beam search's inputs
verify_inputs(beam_inputs, graph_inputs)
assert decoder_model.ir_version == encoder_model.ir_version
logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
# Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
beam_model = helper.make_model_gen_version(
beam_graph,
producer_name="onnxruntime.transformers",
opset_imports=opset_import,
ir_version=decoder_model.ir_version,
)
# Save WhisperBeamSearch graph and external data
if os.path.isfile(args.beam_model_output_dir):
logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
os.remove(args.beam_model_output_dir)
os.remove(args.beam_model_output_dir + ".data")
onnx.save(
beam_model,
args.beam_model_output_dir,
save_as_external_data=True,
all_tensors_to_one_file=True,
convert_attribute=True,
location=f"{os.path.basename(args.beam_model_output_dir)}.data",
)
onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
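# A minimal, hypothetical usage sketch of the chained model saved above; it is not invoked by this script.
# Assumptions: the model was written to "whisper_beamsearch.onnx", the feature extractor produces 80 mel
# bins (128 for large-v3), and only the required beam search inputs were enabled. Optional inputs such as
# decoder_input_ids, vocab_mask, or temperature must be added if they were included in graph_inputs.
if __name__ == "__main__":
    import numpy as np

    import onnxruntime as ort

    session = ort.InferenceSession("whisper_beamsearch.onnx", providers=["CPUExecutionProvider"])
    dummy_inputs = {
        "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # log-mel spectrogram
        "max_length": np.array([448], dtype=np.int32),
        "min_length": np.array([0], dtype=np.int32),
        "num_beams": np.array([1], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
    }
    # Output shape: (batch_size, num_return_sequences, max_length)
    sequences = session.run(["sequences"], dummy_inputs)[0]
    print(sequences.shape)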

View File

@ -0,0 +1,402 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Union
import numpy
import onnx
import torch
from io_binding_helper import TypeHelper
from models.t5.past_helper import PastKeyValuesHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig, file_utils
from whisper_openai_helper import WhisperDecoderInitOpenai
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperDecoderInit(torch.nn.Module):
"""A Whisper decoder to create initial past key values.
This model is only called once during starting decoding.
"""
def __init__(
self,
decoder: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.decoder = decoder
self.config = config
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.FloatTensor,
):
encoder_outputs = file_utils.ModelOutput()
encoder_outputs["last_hidden_state"] = encoder_hidden_states
encoder_outputs["hidden_states"] = None
encoder_outputs["attentions"] = None
out = self.decoder.model(
None,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
past_key_values=None,
use_cache=True,
return_dict=True,
)
logits = self.decoder.proj_out(out[0])
return logits, out.past_key_values, out.encoder_last_hidden_state
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with past key values"""
def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None):
super().__init__()
self.decoder = decoder
self.config = config
self.model_impl = model_impl
if model is not None:
self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)
def forward(self, decoder_input_ids, *past):
encoder_outputs = file_utils.ModelOutput()
dummy_encoder_hidden_states = torch.randn((decoder_input_ids.shape[0], 3000, int(self.config.d_model)))
encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states
encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
encoder_outputs["attentions"] = None
if self.model_impl == "openai":
dummy_encoder_hidden_states.unsqueeze(0)
dec_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, dummy_encoder_hidden_states, past=past
)
return dec_out, present
if len(past) == 0:
past_key_values = None
else:
past_key_values = PastKeyValuesHelper.back_group_by_layer(past)
decoder_out = self.decoder(
None,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
)
logits = decoder_out[0]
present_self, _ = PastKeyValuesHelper.group_by_self_and_cross(decoder_out.past_key_values)
return logits, present_self
class WhisperDecoderInputs:
def __init__(
self,
decoder_input_ids,
past_key_values=None,
):
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values
@staticmethod
def create_dummy(
config: WhisperConfig,
batch_size: int,
encode_sequence_length: int,
past_decode_sequence_length: int,
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
model_impl: str = "hf",
): # -> WhisperDecoderInputs:
"""Create dummy inputs for WhisperDecoder.
Args:
config (WhisperConfig): model config
batch_size (int): batch size
encode_sequence_length (int): sequence length of input_ids for encoder
past_decode_sequence_length (int): past sequence length of input_ids for decoder
device (torch.device): device of output tensors
float16 (bool): whether the past key/value inputs use float16 instead of float32
use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs
Returns:
WhisperDecoderInputs: dummy inputs for decoder
"""
num_attention_heads: int = config.encoder_attention_heads
num_layers: int = config.decoder_layers # + config.encoder_layers
vocab_size: int = config.vocab_size
# head_size = hidden_size / num_attention_heads.
# For example, whisper-large has d_model=1280 and num_heads=20, so head_size=64.
head_size: int = config.d_model // config.encoder_attention_heads
sequence_length: int = 1 # fixed for decoding
decoder_input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=(torch.int32 if use_int32_inputs else torch.int64),
device=device,
)
float_type = torch.float16 if float16 else torch.float32
if past_decode_sequence_length > 0:
self_attention_past_shape = [
batch_size,
num_attention_heads,
past_decode_sequence_length,
head_size,
]
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
head_size,
]
past = []
for _ in range(2 * num_layers):
past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
for _ in range(2 * num_layers):
past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
else:
past = None
return WhisperDecoderInputs(decoder_input_ids, past)
def to_list(self) -> List:
input_list = [self.decoder_input_ids]
if self.past_key_values:
input_list.extend(self.past_key_values)
return input_list
def to_fp32(self):
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
return WhisperDecoderInputs(
self.decoder_input_ids.clone(),
past,
)
class WhisperDecoderHelper:
@staticmethod
def export_onnx(
decoder: WhisperDecoder,
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
decoder (Union[WhisperDecoder, WhisperDecoderInit]): decoder object
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs
"""
assert isinstance(decoder, (WhisperDecoder, WhisperDecoderInit))
inputs = WhisperDecoderInputs.create_dummy(
decoder.config,
batch_size=2,
encode_sequence_length=3000,
past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
model_impl=decoder.model_impl,
)
input_list = inputs.to_list()
# Fix past disappearing bug - duplicate first past entry
# input_list.insert(2, input_list[2])
past_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=True)
present_self_names = present_names[: 2 * decoder.config.decoder_layers]
input_past_names = past_names if isinstance(decoder, WhisperDecoder) else []
output_present_names = present_self_names if isinstance(decoder, WhisperDecoder) else present_names
output_names = ["logits", *output_present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["input_ids"]
input_names.extend(input_past_names)
dynamic_axes = {
"input_ids": {0: "batch_size"},
"encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length / 2"},
"logits": {0: "batch_size", 1: "sequence_length"},
}
for name in input_past_names:
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
}
for name in output_present_names:
if "cross" in name:
dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
else: # self attention past state
if isinstance(decoder, WhisperDecoder):
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length + 1",
}
else:
dynamic_axes[name] = {
0: "batch_size",
# 2: 'sequence_length'
}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
decoder,
args=tuple(input_list),
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
if use_external_data_format:
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: WhisperDecoderInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
}
if inputs.past_key_values:
assert len(inputs.past_key_values) % 4 == 0
num_layers = int(len(inputs.past_key_values) / 4)
past_names = PastKeyValuesHelper.get_past_names(num_layers)
for i, past_tensor in enumerate(inputs.past_key_values):
ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: Union[WhisperDecoder, WhisperDecoderInit],
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
test_cases_max_diff = []
for (
batch_size,
encode_sequence_length,
past_decode_sequence_length,
) in test_cases[:max_cases]:
if isinstance(model, WhisperDecoderInit):
dec_seq_len = 0
else:
dec_seq_len = past_decode_sequence_length
inputs = WhisperDecoderInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
dec_seq_len,
device=device,
float16=float16,
use_int32_inputs=use_int32_inputs,
)
# We use the fp32 PyTorch model as the baseline even when the ONNX model is fp16
input_list = inputs.to_fp32().to_list()
# Run inference of PyTorch model
with torch.no_grad():
torch_outputs = model(*input_list)
ort_outputs = WhisperDecoderHelper.onnxruntime_inference(ort_session, inputs)
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
max_diff_all = max_diff
logger.debug(f"logits max_diff={max_diff}")
for i in range(2 * model.config.num_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
if isinstance(model, WhisperDecoderInit):
for i in range(2 * model.config.num_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
"batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
batch_size,
encode_sequence_length,
past_decode_sequence_length,
max_diff_all,
)
return max_diff_all
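# A minimal, hypothetical export sketch; it is not invoked by this module. It mirrors how load_model in
# whisper_helper.py wires the Hugging Face implementation: the full WhisperForConditionalGeneration is
# passed as the decoder module. The model name and output path below are assumptions.
if __name__ == "__main__":
    from transformers import WhisperForConditionalGeneration

    hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    cpu = torch.device("cpu")
    decoder_with_past = WhisperDecoder(hf_model, hf_model.config).eval().to(cpu)
    WhisperDecoderHelper.export_onnx(
        decoder_with_past,
        cpu,
        onnx_model_path="whisper_tiny_decoder.onnx",
        use_int32_inputs=True,
    )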

View File

@ -0,0 +1,164 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
from typing import List
import numpy
import onnx
import torch
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperEncoder(torch.nn.Module):
"""Whisper encoder outputs only the last hidden state"""
def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
super().__init__()
self.encoder = encoder
self.config = config
self.model_impl = model_impl
def forward(self, input_features):
if self.model_impl == "openai":
return self.encoder(input_features)
return self.encoder.model.encoder(input_features)[0]
class WhisperEncoderInputs:
def __init__(self, input_features):
self.input_ids: torch.FloatTensor = input_features
@staticmethod
def create_dummy(
batch_size: int,
sequence_length: int,
feature_size: int,
device: torch.device,
use_int32_inputs: bool = False,
):
"""Create dummy inputs for Whisper encoder.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
feature_size (int): feature size for spectrogram input
device (torch.device): device of output tensors
Returns:
WhisperEncoderInputs: dummy inputs for encoder
"""
input_features = torch.randn(
size=(batch_size, feature_size, sequence_length),
device=device,
)
return WhisperEncoderInputs(input_features)
def to_list(self) -> List:
if self.input_ids is None:
return []
return [self.input_ids]
class WhisperEncoderHelper:
@staticmethod
def export_onnx(
encoder,
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export encoder to ONNX
Args:
encoder (WhisperEncoder): encoder object
device (torch.device): device of encoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
"""
config = encoder.config
encoder_inputs = WhisperEncoderInputs.create_dummy(
batch_size=2,
sequence_length=3000,
feature_size=config.num_mel_bins,
device=device,
use_int32_inputs=use_int32_inputs,
)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
encoder,
args=tuple(encoder_inputs.to_list()),
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
export_params=True,
input_names=["input_features"],
output_names=["hidden_states"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "feature_size", 2: "sequence_length"},
"hidden_states": {0: "batch_size", 1: "sequence_length"},
},
opset_version=17,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
if use_external_data_format:
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: WhisperEncoderInputs):
"""Run inference of ONNX model."""
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
}
return ort_session.run(None, ort_inputs)
@staticmethod
def verify_onnx(
model: WhisperEncoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
inputs = WhisperEncoderInputs.create_dummy(
batch_size=4,
sequence_length=3000,
feature_size=model.config.num_mel_bins,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
torch_outputs = model(*input_list)
ort_outputs = WhisperEncoderHelper.onnxruntime_inference(ort_session, inputs)
max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))
logger.info(f"max_diff={max_diff}")
return max_diff
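# A minimal, hypothetical export sketch; it is not invoked by this module. The full Hugging Face model is
# passed to WhisperEncoder so that forward can resolve self.encoder.model.encoder. The model name and
# output path below are assumptions.
if __name__ == "__main__":
    from transformers import WhisperForConditionalGeneration

    hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    cpu = torch.device("cpu")
    encoder = WhisperEncoder(hf_model, hf_model.config).eval().to(cpu)
    WhisperEncoderHelper.export_onnx(encoder, cpu, onnx_model_path="whisper_tiny_encoder.onnx")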

View File

@ -0,0 +1,306 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional
import numpy
import onnx
import torch
from models.t5.past_helper import PastKeyValuesHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig
from whisper_decoder import WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderInputs
from whisper_openai_helper import WhisperDecoderInitOpenai
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperEncoderDecoderInit(torch.nn.Module):
"""A combination of WhisperEncoder and WhisperDecoderInit."""
def __init__(
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
model_impl: str = "hf",
model: torch.nn.Module = None,
):
super().__init__()
self.config = config
self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl)
self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id)
if model is not None:
self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)
self.model_impl = model_impl
def forward(
self,
encoder_input_ids: torch.Tensor,
decoder_input_ids: torch.Tensor = None,
remove_hooks: bool = False,
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks
)
return decinit_out, encoder_hidden_states, present
else:
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1])
present = present_self + present_cross
return decinit_out[0], encoder_hidden_states, present
class WhisperEncoderDecoderInitInputs:
def __init__(self, encoder_input_ids, decoder_input_ids=None):
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
@staticmethod
def create_dummy(
config: WhisperConfig,
batch_size: int,
encode_sequence_length: int,
use_decoder_input_ids: bool,
device: torch.device,
use_int32_inputs: bool = False,
): # -> WhisperEncoderDecoderInitInputs:
encoder_inputs: WhisperEncoderInputs = WhisperEncoderInputs.create_dummy(
batch_size,
sequence_length=3000,
feature_size=config.num_mel_bins,
device=device,
)
decoder_input_ids = None
if use_decoder_input_ids:
dtype = torch.int32 if use_int32_inputs else torch.int64
decoder_input_ids = torch.ones((batch_size, 2), dtype=dtype, device=device) * config.decoder_start_token_id
return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids)
def to_list(self) -> List:
input_list = [self.encoder_input_ids]
if self.decoder_input_ids is not None:
input_list.append(self.decoder_input_ids)
return input_list
class WhisperEncoderDecoderInitHelper:
@staticmethod
def export_onnx(
model: WhisperEncoderDecoderInit,
device: torch.device,
onnx_model_path: str,
use_decoder_input_ids: bool = True,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
model (WhisperEncoderDecoderInit): the model to export
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
"""
assert isinstance(model, WhisperEncoderDecoderInit)
inputs = WhisperEncoderDecoderInitInputs.create_dummy(
model.config,
batch_size=2,
encode_sequence_length=3000,
use_decoder_input_ids=True,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True)
present = out[2]
present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)
output_names = ["logits", "encoder_hidden_states", *present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["encoder_input_ids"]
# The ONNX exporter might mark a dimension like 'Transposepresent_value_self_1_dim_2' during shape inference.
# As a workaround, we first use a numeric dim_param (e.g. "1" for sequence_length) and later change it to a dim_value.
sequence_length = "1"
num_heads = str(model.config.encoder_attention_heads)
hidden_size = str(model.config.d_model)
head_size = str(model.config.d_model // model.config.encoder_attention_heads)
dynamic_axes = {
"encoder_input_ids": {0: "batch_size", 1: "feature_size"},
"encoder_hidden_states": {
0: "batch_size",
1: "encode_sequence_length",
2: hidden_size,
},
"logits": {
0: "batch_size",
1: "decode_sequence_length",
},
}
if use_decoder_input_ids:
input_names.append("decoder_input_ids")
dynamic_axes["decoder_input_ids"] = {
0: "batch_size",
1: "decode_sequence_length",
}
for name in present_names:
if "cross" in name:
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: "encode_sequence_length",
3: head_size,
}
else: # self attention past state
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: "decode_sequence_length",
3: head_size,
}
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
args=tuple(input_list),
f=temp_onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
# Workaround as mentioned earlier: change numeric dim_param to dim_value
model = onnx.load(temp_onnx_model_path)
for tensor in model.graph.output:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: WhisperEncoderDecoderInitInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
}
if inputs.decoder_input_ids is not None:
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: WhisperEncoderDecoderInit,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
ort_inputs = ort_session.get_inputs()
use_decoder_input_ids = len(ort_inputs) == 3
test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
test_cases_max_diff = []
for batch_size, encode_sequence_length in test_cases[:max_cases]:
inputs = WhisperEncoderDecoderInitInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
ort_outputs = WhisperEncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
# Run inference of PyTorch model
input_list = inputs.to_list()
torch_outputs = model(*input_list)
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
logger.debug(f"logits max_diff={max_diff}")
max_diff_all = max_diff
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
for i in range(2 * model.config.num_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
for i in range(2 * model.config.num_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * model.config.num_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
)
return max(test_cases_max_diff)
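# A minimal, hypothetical export sketch; it is not invoked by this module. As in load_model from
# whisper_helper.py, the full Hugging Face model is passed for both the encoder and decoder arguments
# when model_impl is "hf". The model name and output path below are assumptions.
if __name__ == "__main__":
    from transformers import WhisperForConditionalGeneration

    hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    cpu = torch.device("cpu")
    encoder_decoder_init = WhisperEncoderDecoderInit(hf_model, hf_model, hf_model.config).eval().to(cpu)
    WhisperEncoderDecoderInitHelper.export_onnx(
        encoder_decoder_init,
        cpu,
        onnx_model_path="whisper_tiny_encoder_decoder_init.onnx",
        use_decoder_input_ids=True,
        use_int32_inputs=True,
    )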

View File

@ -0,0 +1,524 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Dict, Tuple, Union
import numpy as np
import torch
from float16 import float_to_float16_max_diff
from onnx_model import OnnxModel
from optimizer import optimize_model
from packaging import version
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
from transformers import __version__ as transformers_version
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
PRETRAINED_WHISPER_MODELS = [
"whisper-tiny",
"whisper-tiny.en",
"whisper-base",
"whisper-base.en",
"whisper-small",
"whisper-small.en",
"whisper-medium",
"whisper-medium.en",
"whisper-large",
"whisper-large-v2",
"whisper-large-v3",
]
class WhisperHelper:
@staticmethod
def get_onnx_path(
output_dir: str,
model_name_or_path: str,
suffix: str = "",
new_folder: bool = False,
) -> str:
"""Build onnx path
Args:
output_dir (str): output directory
model_name_or_path (str): pretrained model name, or path to the model checkpoint
suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to an empty string.
new_folder (bool, optional): create a new directory for the model. Defaults to False.
Returns:
str: path of onnx model
"""
model_name = model_name_or_path
if os.path.isdir(model_name_or_path):
model_name = Path(model_name_or_path).parts[-1]
else:
model_name = model_name.split("/")[-1]
model_name += suffix
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
return os.path.join(directory, model_name + ".onnx")
@staticmethod
def load_model_openai(
model_name_or_path: str,
cache_dir: str,
device: torch.device,
) -> torch.nn.Module:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
Args:
model_name_or_path (str): pretrained model name or path
cache_dir (str): cache directory
device (torch.device): device to run the model
Returns:
torch.nn.Module: Whisper model loaded from the OpenAI checkpoint
"""
from whisper import _ALIGNMENT_HEADS, _MODELS, _download
from whisper.model import ModelDimensions, Whisper
in_memory = False
model_name = model_name_or_path.split("/")[-1][8:]
checkpoint_file, alignment_heads = None, None
if model_name in _MODELS:
checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
alignment_heads = _ALIGNMENT_HEADS[model_name]
with open(checkpoint_file, "rb") as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)
@staticmethod
def load_model(
model_name_or_path: str,
model_impl: str,
cache_dir: str,
device: torch.device,
merge_encoder_and_decoder_init: bool = True,
state_dict_path: str = "",
) -> Dict[str, torch.nn.Module]:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
Args:
model_name_or_path (str): pretrained model name or path
cache_dir (str): cache directory
device (torch.device): device to run the model
merge_encoder_and_decoder_init (bool, optional): Whether to merge the encoder and decoder initialization into one ONNX model. Defaults to True.
Returns:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)
if model_impl == "openai":
openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device)
model_encoder, model_decoder = openai_model.encoder, openai_model.decoder
passed_model = openai_model
else:
model_encoder, model_decoder = model, model
passed_model = None
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path), strict=False)
decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model)
decoder.eval().to(device)
if merge_encoder_and_decoder_init:
encoder_decoder_init = WhisperEncoderDecoderInit(
model_encoder,
model_decoder,
model.config,
decoder_start_token_id=None,
model_impl=model_impl,
model=passed_model,
)
return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
else:
encoder = WhisperEncoder(model.model.encoder, model.config)
encoder.eval().to(device)
decoder_init = WhisperDecoderInit(model.decoder, model.config)
decoder_init.eval().to(device)
return {
"encoder": encoder,
"decoder": decoder,
"decoder_init": decoder_init,
}
@staticmethod
def export_onnx(
model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit],
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_decoder_input_ids: bool = True,
use_int32_inputs: bool = False,
):
if isinstance(model, WhisperEncoder):
WhisperEncoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
)
elif isinstance(model, WhisperEncoderDecoderInit):
WhisperEncoderDecoderInitHelper.export_onnx(
model,
device,
onnx_model_path,
use_decoder_input_ids,
verbose,
use_external_data_format,
use_int32_inputs,
)
else:
WhisperDecoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
@staticmethod
def auto_mixed_precision(
onnx_model: OnnxModel,
op_block_list: Tuple[str] = (
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"Relu",
"Add",
),
):
"""Convert model to mixed precision.
It detects whether the original model has fp16 weights and sets the parameters for float16 conversion automatically.
Args:
onnx_model (OnnxModel): optimized ONNX model
op_block_list (List[str], optional): operators to keep in float32. Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
Returns:
parameters(dict): a dictionary of parameters used in float16 conversion
"""
op_full_set = set([node.op_type for node in onnx_model.nodes()])
fp32_op_set = set(op_block_list)
fp16_op_set = op_full_set.difference(fp32_op_set)
logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
# logits is the first output
logits_output_name = onnx_model.graph().output[0].name
# We use the weight of the last MatMul node to detect whether the model was trained and stored with float16 weights.
is_weight_fp16_precision = False
output_name_to_node = onnx_model.output_name_to_node()
assert logits_output_name in output_name_to_node
node = output_name_to_node[logits_output_name]
last_matmul_node = None
if node.op_type == "MatMul":
last_matmul_node = node
logger.info(f"Found last MatMul node for logits: {node.name}")
initializer = None
for input in node.input:
initializer = onnx_model.get_initializer(input)
if initializer is not None:
break
# When the max difference after converting the weights from float32 to float16 is lower than a threshold (1e-6),
# we can deduce that the weights were stored in float16 precision.
max_diff = float_to_float16_max_diff(initializer)
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
is_weight_fp16_precision = max_diff < 1e-6
else:
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
keep_io_types = []
node_block_list = []
if (not is_weight_fp16_precision) and (last_matmul_node is not None):
# When the original weights are float32, keeping logits and the last MatMul in float32 can give better precision.
keep_io_types = [logits_output_name]
node_block_list = [last_matmul_node.name]
parameters = {
"keep_io_types": keep_io_types,
"op_block_list": list(op_block_list),
"node_block_list": node_block_list,
"force_fp16_initializers": is_weight_fp16_precision,
}
logger.info(f"auto_mixed_precision parameters: {parameters}")
onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
return parameters
@staticmethod
def optimize_onnx(
onnx_model_path: str,
optimized_model_path: str,
is_float16: bool,
num_attention_heads: int,
hidden_size: int,
use_external_data_format: bool = False,
auto_mixed_precision: bool = True,
use_gpu: bool = False,
provider: str = "cpu",
):
"""Optimize ONNX model with an option to convert it to use mixed precision."""
from fusion_options import FusionOptions
optimization_options = FusionOptions("bart")
optimization_options.use_multi_head_attention = True
optimization_options.disable_multi_head_attention_bias = provider == "rocm"
m = optimize_model(
onnx_model_path,
model_type="bart",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=2 if not use_external_data_format else None,
optimization_options=optimization_options,
use_gpu=use_gpu,
only_onnxruntime=False,
)
if is_float16:
if auto_mixed_precision:
WhisperHelper.auto_mixed_precision(m)
else:
m.convert_model_float32_to_float16(cast_input_output=False)
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
@staticmethod
def pt_transcription_for_verify_onnx(
processor: WhisperProcessor,
pt_model: torch.nn.Module,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
# Try to import `datasets` pip package
try:
from datasets import load_dataset
except Exception as e:
logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True) # noqa: G201
install_cmd = "pip install datasets"
logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
os.system(install_cmd)
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features_ = []
if batch_size == 1:
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
else:
input_features_ = [
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
]
assert len(input_features_) == batch_size
input_features = torch.cat((input_features_[0], input_features_[1]))
max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
"max_length": max_length,
"min_length": min_length,
"num_beams": num_beams,
"num_return_sequences": num_return_sequences,
"length_penalty": length_penalty,
"repetition_penalty": repetition_penalty,
"early_stopping": True,
"use_cache": True,
}
if prompt_mode:
prompts = ["John has doubts", "Maria has grave doubts"]
prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
pt_transcription = []
pt_outputs = []
# Looping over model.generate is necessary here because the prompt_ids input must be a rank-1 tensor, as per
# https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
for i in range(batch_size):
inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
inputs["input_features"] = input_features_[i].to(device)
pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
pt_outputs.append(pt_output)
pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
inputs["input_features"] = input_features
del inputs["prompt_ids"]
else:
prompt_ids = []
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
pt_outputs = list(pt_outputs)
del inputs["early_stopping"]
del inputs["use_cache"]
return inputs, pt_transcription, pt_outputs, prompt_ids
@staticmethod
def select_transcription_options(
batch_size: int,
prompt_mode: bool,
):
if batch_size > 1 and prompt_mode:
expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_options = {
expected_transcription_no_comma_prompt1,
expected_transcription_no_comma_prompt2,
expected_transcription_misspelled_prompt1,
expected_transcription_misspelled_prompt2,
}
else:
expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
return expected_transcription_options
@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
processor,
pt_model,
device,
batch_size=batch_size,
prompt_mode=prompt_mode,
)
start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
"tensor(int64)": np.int64,
"tensor(int32)": np.int32,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
}
use_extra_decoding_ids = "extra_decoding_ids" in ort_names
for name, dtype in zip(ort_names, ort_dtypes):
if name == "input_features":
inputs[name] = inputs[name].detach().cpu().numpy()
elif name == "vocab_mask":
inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
if not prompt_mode:
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
else:
# This logic handles the case where prompts are not the same length.
# For example, if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1],
# the final decoder_input_ids will look like this after padding:
# [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
# [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
ort_prompts = []
for i in range(batch_size):
ort_prompts.append(decoder_prompt_ids[i].tolist())
max_len = max(len(p) for p in ort_prompts)
padded_prompts = []
for p in ort_prompts:
padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
padded_prompts.append(padded_prompt + forced_decoder_ids)
inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
elif name == "cross_qk_layer_head":
inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
elif name == "extra_decoding_ids":
inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0)
elif name == "temperature":
inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)
parity = 1
for i in range(batch_size):
parity *= (
pt_transcription[i] in expected_transcription_options
and ort_transcription[i] in expected_transcription_options
)
max_diff = 0
if not parity:
for i in range(batch_size):
if pt_outputs[i].shape != ort_outputs[i].shape:
diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
else:
diff = pt_outputs[i] - ort_outputs[i]
max_diff_i = max(diff.min(), diff.max(), key=abs)
max_diff = max(max_diff, max_diff_i)
if max_diff != 0:
logger.warning(f"PyTorch outputs: {pt_transcription}")
logger.warning(f"ONNX Runtime outputs: {ort_transcription}")
return max_diff
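# A minimal, hypothetical end-to-end sketch; it is not invoked by this module. It shows the intended flow
# of the helpers above (load, export, optimize); convert_to_onnx.py is the real driver script. The model
# name, cache directory, and output paths below are assumptions.
if __name__ == "__main__":
    cpu = torch.device("cpu")
    os.makedirs("./onnx_models", exist_ok=True)
    models = WhisperHelper.load_model("openai/whisper-tiny", model_impl="hf", cache_dir="./cache", device=cpu)
    for name, module in models.items():
        onnx_path = WhisperHelper.get_onnx_path("./onnx_models", "openai/whisper-tiny", suffix=f"_{name}")
        WhisperHelper.export_onnx(module, cpu, onnx_path, use_external_data_format=True)
        WhisperHelper.optimize_onnx(
            onnx_path,
            onnx_path.replace(".onnx", "_opt.onnx"),
            is_float16=False,
            num_attention_heads=module.config.encoder_attention_heads,
            hidden_size=module.config.d_model,
            use_external_data_format=True,
        )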

View File

@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import torch
logger = logging.getLogger(__name__)
class WhisperDecoderInitOpenai(torch.nn.Module):
"""WhisperDecoderInit for Openai."""
def __init__(
self,
model: torch.nn.Module,
decoder: torch.nn.Module,
):
super().__init__()
self.whisper_model = model
self.whisper_decoder = decoder
self.kv_cache = {}
@torch.no_grad()
def forward(
self,
tokens,
audio_features,
past=None,
remove_hooks=False,
):
# Create a kv_cache for past_values
past_kv_cache = dict()
if past is not None:
# Convert past values from 4D to 3D
past = [torch.transpose(val, 1, 2) for val in past]
past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
half_idx = len(past) // 2
for idx, block in enumerate(self.whisper_decoder.blocks):
past_kv_cache[block.attn.key] = past[2 * idx]
past_kv_cache[block.attn.value] = past[2 * idx + 1]
past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]
hooks = None
if not self.kv_cache:
self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()
logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)
# Add concat node for past values
if past is not None:
for block in self.whisper_decoder.blocks:
self.kv_cache[block.attn.key] = torch.cat(
[past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
).detach()
self.kv_cache[block.attn.value] = torch.cat(
[past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
).detach()
present_self, present_cross = [], []
# Group self and cross values
for block in self.whisper_decoder.blocks:
present_self.append(self.kv_cache[block.attn.key])
present_self.append(self.kv_cache[block.attn.value])
if past is None:
present_cross.append(self.kv_cache[block.cross_attn.key])
present_cross.append(self.kv_cache[block.cross_attn.value])
present_self = present_self + present_cross
# Add reshape and transpose ops to convert from 3D to 4D
present_self = [
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
]
# Remove forward hooks to avoid model cloning step
if hooks is not None and remove_hooks:
self.kv_cache = {}
for hook in hooks:
hook.remove()
return logits, present_self
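# A minimal, hypothetical sketch of driving WhisperDecoderInitOpenai directly; it is not invoked by this
# module and assumes the openai-whisper package is installed. The dummy audio features use n_audio_ctx=1500
# and n_audio_state=384, which match the "tiny" checkpoint; convert_to_onnx.py builds this wrapper through
# WhisperHelper.load_model instead.
if __name__ == "__main__":
    import whisper

    openai_model = whisper.load_model("tiny")
    decoder_init = WhisperDecoderInitOpenai(openai_model, openai_model.decoder)
    tokens = torch.tensor([[50258]], dtype=torch.long)  # <|startoftranscript|>
    audio_features = torch.randn(1, 1500, 384)  # (batch_size, n_audio_ctx, n_audio_state)
    logits, present = decoder_init(tokens, audio_features)
    print(logits.shape, len(present))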