I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

@@ -0,0 +1,278 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import copy
import logging
import os
import torch
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
logger = logging.getLogger("")
def parse_arguments():
parser = argparse.ArgumentParser()
pretrained_models = PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS
parser.add_argument(
"-m",
"--model_name_or_path",
required=False,
default=PRETRAINED_T5_MODELS[0],
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
)
parser.add_argument(
"--model_type",
required=False,
type=str,
default="t5",
choices=["t5", "mt5"],
help="Model type: either t5 (default) or mt5",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16],
help="Precision of model to run. fp32 for full precision, fp16 for half precision",
)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument(
"-s",
"--use_decoder_start_token",
required=False,
action="store_true",
help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
)
parser.set_defaults(use_decoder_start_token=False)
parser.add_argument(
"-w",
"--overwrite",
required=False,
action="store_true",
help="overwrite existing ONNX model",
)
parser.set_defaults(overwrite=False)
parser.add_argument(
"--disable_auto_mixed_precision",
required=False,
action="store_true",
help="use pure fp16 instead of mixed precision",
)
parser.set_defaults(disable_auto_mixed_precision=False)
parser.add_argument(
"--separate_encoder_and_decoder_init",
required=False,
action="store_true",
help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
)
parser.set_defaults(separate_encoder_and_decoder_init=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"--state_dict_path",
type=str,
default="",
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
args = parser.parse_args()
return args
def export_onnx_models(
model_name_or_path,
cache_dir,
output_dir,
use_gpu,
use_external_data_format,
optimize_onnx,
precision,
verbose,
use_decoder_start_token: bool = False,
merge_encoder_and_decoder_init: bool = True,
overwrite: bool = False,
disable_auto_mixed_precision: bool = False,
use_int32_inputs: bool = True,
model_type: str = "t5",
state_dict_path: str = "",
):
device = torch.device("cuda:0" if use_gpu else "cpu")
models = T5Helper.load_model(
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
model.to(device)
filename_suffix = "_" + name
onnx_path = T5Helper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix,
new_folder=False,
)
if overwrite or not os.path.exists(onnx_path):
logger.info(f"Exporting ONNX model to {onnx_path}")
# Clone the model before exporting to ONNX; otherwise verify_onnx would report a large difference.
cloned_model = copy.deepcopy(model).to(device)
T5Helper.export_onnx(
cloned_model,
device,
onnx_path,
verbose,
use_external_data_format,
use_decoder_input_ids=not use_decoder_start_token,
use_int32_inputs=use_int32_inputs,
)
else:
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
# Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet.
if optimize_onnx or precision != Precision.FLOAT32:
output_path = T5Helper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix + "_" + str(precision),
new_folder=False,
)
if overwrite or not os.path.exists(output_path):
logger.info(f"Optimizing model to {output_path}")
T5Helper.optimize_onnx(
onnx_path,
output_path,
precision == Precision.FLOAT16,
config.num_heads,
config.hidden_size,
use_external_data_format,
auto_mixed_precision=not disable_auto_mixed_precision,
use_gpu=use_gpu,
)
else:
logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
else:
output_path = onnx_path
ort_session = create_onnxruntime_session(
output_path,
use_gpu=use_gpu,
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
)
with torch.no_grad():
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
if max_diff > 1e-4:
logger.warning("PyTorch and OnnxRuntime results are NOT close")
output_paths.append(output_path)
return output_paths
def main():
args = parse_arguments()
setup_logger(args.verbose)
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision != Precision.FLOAT32:
assert args.optimize_onnx, "fp16 requires --optimize_onnx"
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
if args.optimize_onnx:
logger.warning("Graph optimization for T5 is not implemented yet.")
output_paths = export_onnx_models(
args.model_name_or_path,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_decoder_start_token,
not args.separate_encoder_and_decoder_init,
args.overwrite,
args.disable_auto_mixed_precision,
not args.use_int64_inputs,
args.model_type,
)
logger.info(f"Done! Outputs: {output_paths}")
if __name__ == "__main__":
main()
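# Example invocations (a sketch based on the arguments defined above; the script name
# convert_to_onnx.py and all paths are placeholders, not taken from this commit):
#   python convert_to_onnx.py -m t5-small --output ./onnx_models
#   python convert_to_onnx.py -m t5-small --use_gpu -o -p fp16 --output ./onnx_models
#   python convert_to_onnx.py -m google/mt5-small --model_type mt5 --separate_encoder_and_decoder_init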

@@ -0,0 +1,150 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
from typing import List, Tuple
import torch
logger = logging.getLogger(__name__)
class PastKeyValuesHelper:
"""Helper functions to process past key values for encoder-decoder model"""
@staticmethod
def get_past_names(num_layers, present: bool = False):
past_self_names = []
past_cross_names = []
for i in range(num_layers):
past_self_names.extend(
[f"present_key_self_{i}", f"present_value_self_{i}"]
if present
else [f"past_key_self_{i}", f"past_value_self_{i}"]
)
past_cross_names.extend(
[f"present_key_cross_{i}", f"present_value_cross_{i}"]
if present
else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
)
return past_self_names + past_cross_names
@staticmethod
def group_by_self_or_cross(present_key_values):
"""Split present state from grouped by layer to grouped by self/cross attention.
Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
"""
present_self = []
present_cross = []
for _i, present_layer_i in enumerate(present_key_values):
assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
(
present_key_self,
present_value_self,
present_key_cross,
present_value_cross,
) = present_layer_i
present_self.extend([present_key_self, present_value_self])
present_cross.extend([present_key_cross, present_value_cross])
return present_self, present_cross
@staticmethod
def group_by_layer(past, num_layers):
"""Reorder past state from grouped by self/cross attention to grouped by layer.
Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
"""
assert len(past) == 4 * num_layers
return tuple(
[
past[2 * i],
past[2 * i + 1],
past[2 * num_layers + 2 * i],
past[2 * num_layers + 2 * i + 1],
]
for i in range(num_layers)
)
@staticmethod
def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]):
"""Categorize present_key_values from self and cross attention to layer by layer.
Reorder past state from grouped by self/cross attention to grouped by layer.
Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
(past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
Args:
past_key_values: past key and values of a model, grouped by self and cross attention
Returns:
past_tuples: present key and values grouped by layer.
"""
past_tuples = ()
half_idx = len(past_key_values) // 2
for i in range(len(past_key_values) // 4):
idx = 2 * i
past_tuples += (
(
past_key_values[idx],
past_key_values[idx + 1],
past_key_values[half_idx + idx],
past_key_values[half_idx + idx + 1],
),
)
return past_tuples
@staticmethod
def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: bool = False):
"""Categorize present_key_values into self and cross attention.
Split present state from grouped by layer to grouped by self/cross attention.
Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
(past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
(past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
Args:
present_key_values: present key and values of a model, grouped by layer
concat: whether to concatenate self attention and cross attention key/values into one list
Returns:
present_self (Tuple[torch.Tensor]): present key and values from self attention
present_cross (Tuple[torch.Tensor]): present key and values from cross attention
"""
present_self: List[torch.Tensor] = []
present_cross: List[torch.Tensor] = []
for _, present_layer_i in enumerate(present_key_values):
assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
present_self.extend([present_key_self, present_value_self])
present_cross.extend([present_key_cross, present_value_cross])
if concat:
return present_self + present_cross
else:
return present_self, present_cross
@staticmethod
def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True):
"""Process input names of model wrapper.
Args:
past_key_values: Consider `self` and `cross` past_key_values
Returns:
names (List[string]): input names
"""
names = []
num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
prefix = "past_" if not encoder else "present_"
for i in range(num_layers):
names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
for i in range(num_layers):
names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
return names
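# A minimal self-check sketch (added for illustration, not part of the original helper):
# it round-trips tiny random tensors through group_by_self_and_cross and back_group_by_layer,
# then prints the decoder input names that get_past_names produces for a 2-layer model.
if __name__ == "__main__":
    num_layers = 2
    # One tuple of (key_self, value_self, key_cross, value_cross) per layer.
    layered = tuple(tuple(torch.rand(1, 2, 3, 4) for _ in range(4)) for _ in range(num_layers))
    flat = PastKeyValuesHelper.group_by_self_and_cross(layered, concat=True)
    regrouped = PastKeyValuesHelper.back_group_by_layer(tuple(flat))
    assert all(
        torch.equal(a, b) for la, lb in zip(layered, regrouped) for a, b in zip(la, lb)
    ), "round trip should preserve tensors"
    print(PastKeyValuesHelper.get_past_names(num_layers))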

@@ -0,0 +1,438 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Union
import numpy
import onnx
import torch
from io_binding_helper import TypeHelper
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from t5_encoder import T5EncoderInputs
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class T5DecoderInit(torch.nn.Module):
"""A T5 decoder with LM head to create initial past key values.
This model is called only once, at the start of decoding.
"""
def __init__(
self,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: Union[T5Config, MT5Config],
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_attention_mask: torch.Tensor,
encoder_hidden_states: torch.FloatTensor,
):
if decoder_input_ids is None:
batch_size = encoder_attention_mask.shape[0]
decoder_input_ids = (
torch.ones(
(batch_size, 1),
dtype=torch.long,
device=encoder_attention_mask.device,
)
* self.decoder_start_token_id
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
)
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
return lm_logits, past_self, past_cross
class T5Decoder(torch.nn.Module):
"""A T5 decoder with LM head and past key values"""
def __init__(self, decoder, lm_head, config):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
num_decoder_layers = self.config.num_decoder_layers
past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)
# This is a hack since only the third dimension of encoder_hidden_states is used here
dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=dummy_encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
)
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
# Do not return present_cross since they are identical to corresponding past_cross input
return lm_logits, present_self
class T5DecoderInputs:
def __init__(
self,
decoder_input_ids,
encoder_attention_mask,
past_key_values=None,
):
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values
@staticmethod
def create_dummy(
config: Union[T5Config, MT5Config],
batch_size: int,
encode_sequence_length: int,
past_decode_sequence_length: int,
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
): # -> T5DecoderInputs:
"""Create dummy inputs for T5Decoder.
Args:
config (T5Config or MT5Config): model configuration
batch_size (int): batch size
encode_sequence_length (int): sequence length of input_ids for encoder
past_decode_sequence_length (int): past sequence length of input_ids for decoder
device (torch.device): device of output tensors
float16 (bool): whether float16 (instead of float32) is used for the past state inputs
use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs
Returns:
T5DecoderInputs: dummy inputs for decoder
"""
num_attention_heads: int = config.num_heads
num_layers: int = config.num_decoder_layers
vocab_size: int = config.vocab_size
# Do not use head_size = hidden_size / num_attention_heads here.
# For example, mt5-small, d_model=512 and num_heads=6
head_size: int = config.d_kv
sequence_length: int = 1 # fixed for decoding
decoder_input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=(torch.int32 if use_int32_inputs else torch.int64),
device=device,
)
encoder_inputs = T5EncoderInputs.create_dummy(
batch_size,
encode_sequence_length,
vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
float_type = torch.float16 if float16 else torch.float32
if past_decode_sequence_length > 0:
self_attention_past_shape = [
batch_size,
num_attention_heads,
past_decode_sequence_length,
head_size,
]
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
head_size,
]
past = []
for _ in range(2 * num_layers):
past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
for _ in range(2 * num_layers):
past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
else:
past = None
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
def to_list(self) -> List:
input_list = [
self.decoder_input_ids,
self.encoder_attention_mask,
]
if self.past_key_values:
input_list.extend(self.past_key_values)
return input_list
def to_fp32(self):
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
return T5DecoderInputs(
self.decoder_input_ids.clone(),
self.encoder_attention_mask.clone(),
past,
)
class T5DecoderHelper:
@staticmethod
def export_onnx(
decoder: Union[T5Decoder, T5DecoderInit],
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
decoder (Union[T5Decoder, T5DecoderInit]): decoder object
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs
"""
assert isinstance(decoder, (T5Decoder, T5DecoderInit))
inputs = T5DecoderInputs.create_dummy(
decoder.config,
batch_size=2,
encode_sequence_length=3,
past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
num_decoder_layers = decoder.config.num_decoder_layers
past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
present_self_names = present_names[: 2 * num_decoder_layers]
input_past_names = past_names if isinstance(decoder, T5Decoder) else []
output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
output_names = ["logits", *output_present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_attention_mask: (batch_size, encode_sequence_length)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["input_ids"]
input_names.append("encoder_attention_mask")
input_names.extend(input_past_names)
dynamic_axes = {
"input_ids": {
0: "batch_size",
# 1: 'sequence_length'
},
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
"logits": {
0: "batch_size",
# 1: 'sequence_length'
},
}
for name in input_past_names:
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
}
for name in output_present_names:
if "cross" in name:
dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
else: # self attention past state
if isinstance(decoder, T5Decoder):
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length + 1",
}
else:
dynamic_axes[name] = {
0: "batch_size",
# 2: 'sequence_length'
}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
decoder,
args=tuple(input_list),
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
if use_external_data_format:
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
}
if inputs.past_key_values:
assert len(inputs.past_key_values) % 4 == 0
num_layers = int(len(inputs.past_key_values) / 4)
past_names = PastKeyValuesHelper.get_past_names(num_layers)
for i, past_tensor in enumerate(inputs.past_key_values):
ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: Union[T5Decoder, T5DecoderInit],
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
test_cases_max_diff = []
for (
batch_size,
encode_sequence_length,
past_decode_sequence_length,
) in test_cases[:max_cases]:
if isinstance(model, T5DecoderInit):
past_decode_sequence_length = 0 # noqa: PLW2901
inputs = T5DecoderInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
past_decode_sequence_length,
device=device,
float16=float16,
use_int32_inputs=use_int32_inputs,
)
# We use the fp32 PyTorch model as baseline even when the ONNX model is fp16
input_list = inputs.to_fp32().to_list()
# Run inference of PyTorch model
with torch.no_grad():
torch_outputs = model(*input_list)
ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
num_decoder_layers = model.config.num_decoder_layers
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
max_diff_all = max_diff
logger.debug(f"logits max_diff={max_diff}")
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
if isinstance(model, T5DecoderInit):
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
"batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
batch_size,
encode_sequence_length,
past_decode_sequence_length,
max_diff_all,
)
return max(test_cases_max_diff)
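# Usage sketch (hypothetical): `decoder_module`, `lm_head`, `config` and `device` stand for
# objects obtained from a loaded Hugging Face T5 model; the output path is a placeholder.
#   decoder = T5Decoder(decoder_module, lm_head, config).eval().to(device)
#   T5DecoderHelper.export_onnx(decoder, device, "t5_decoder.onnx", use_int32_inputs=True)
#   session = InferenceSession("t5_decoder.onnx", providers=["CPUExecutionProvider"])
#   max_diff = T5DecoderHelper.verify_onnx(decoder, session, device, use_int32_inputs=True)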

@@ -0,0 +1,171 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import random
import tempfile
from pathlib import Path
from typing import List, Union
import numpy
import onnx
import torch
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class T5Encoder(torch.nn.Module):
"""T5 encoder outputs only the last hidden state"""
def __init__(self, encoder, config: Union[T5Config, MT5Config]):
super().__init__()
self.encoder = encoder
self.config = config
def forward(self, input_ids, attention_mask):
return self.encoder(input_ids, attention_mask)[0]
class T5EncoderInputs:
def __init__(self, input_ids, attention_mask):
self.input_ids: torch.LongTensor = input_ids
self.attention_mask: torch.LongTensor = attention_mask
@staticmethod
def create_dummy(
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
): # -> T5EncoderInputs
"""Create dummy inputs for T5 encoder.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
vocab_size (int): vocabulary size
device (torch.device): device of output tensors
Returns:
T5EncoderInputs: dummy inputs for encoder
"""
dtype = torch.int32 if use_int32_inputs else torch.int64
input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=dtype,
device=device,
)
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
if sequence_length >= 2:
for i in range(batch_size):
padding_position = random.randint(0, sequence_length - 1)
attention_mask[i, :padding_position] = 0
return T5EncoderInputs(input_ids, attention_mask)
def to_list(self) -> List:
input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
return input_list
class T5EncoderHelper:
@staticmethod
def export_onnx(
encoder: T5Encoder,
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export encoder to ONNX
Args:
encoder (T5Encoder): encoder object
device (torch.device): device of encoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
"""
config = encoder.config
encoder_inputs = T5EncoderInputs.create_dummy(
batch_size=2,
sequence_length=4,
vocab_size=config.vocab_size,
device=device,
use_int32_inputs=use_int32_inputs,
)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
encoder,
args=tuple(encoder_inputs.to_list()),
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
export_params=True,
input_names=["input_ids", "attention_mask"],
output_names=["hidden_states"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"hidden_states": {0: "batch_size", 1: "sequence_length"},
},
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
if use_external_data_format:
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
"""Run inference of ONNX model."""
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
"attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()),
}
return ort_session.run(None, ort_inputs)
@staticmethod
def verify_onnx(
model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
inputs = T5EncoderInputs.create_dummy(
batch_size=4,
sequence_length=11,
vocab_size=model.config.vocab_size,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
torch_outputs = model(*input_list)
ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)
max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))
logger.info(f"max_diff={max_diff}")
return max_diff
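# Usage sketch (hypothetical): `t5_model` stands for a loaded T5ForConditionalGeneration and
# `device` for a torch.device; the output path is a placeholder.
#   encoder = T5Encoder(t5_model.encoder, t5_model.config).eval().to(device)
#   T5EncoderHelper.export_onnx(encoder, device, "t5_encoder.onnx", use_int32_inputs=True)
#   session = InferenceSession("t5_encoder.onnx", providers=["CPUExecutionProvider"])
#   max_diff = T5EncoderHelper.verify_onnx(encoder, session, device, use_int32_inputs=True)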

@@ -0,0 +1,299 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Union
import numpy
import onnx
import torch
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from t5_decoder import T5DecoderInit
from t5_encoder import T5Encoder, T5EncoderInputs
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class T5EncoderDecoderInit(torch.nn.Module):
"""A combination of T5Encoder and T5DecoderInit."""
def __init__(
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: Union[T5Config, MT5Config],
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.config = config
self.t5_encoder = T5Encoder(encoder, config)
self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
def forward(
self,
encoder_input_ids: torch.Tensor,
encoder_attention_mask: torch.Tensor,
decoder_input_ids: torch.Tensor = None,
):
encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
lm_logits, past_self, past_cross = self.t5_decoder_init(
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
)
return lm_logits, encoder_hidden_states, past_self, past_cross
class T5EncoderDecoderInitInputs:
def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
@staticmethod
def create_dummy(
config: Union[T5Config, MT5Config],
batch_size: int,
encode_sequence_length: int,
use_decoder_input_ids: bool,
device: torch.device,
use_int32_inputs: bool = False,
): # -> T5EncoderDecoderInitInputs:
encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
batch_size,
encode_sequence_length,
config.vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
decoder_input_ids = None
if use_decoder_input_ids:
dtype = torch.int32 if use_int32_inputs else torch.int64
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
def to_list(self) -> List:
input_list = [self.encoder_input_ids, self.encoder_attention_mask]
if self.decoder_input_ids is not None:
input_list.append(self.decoder_input_ids)
return input_list
class T5EncoderDecoderInitHelper:
@staticmethod
def export_onnx(
model: T5EncoderDecoderInit,
device: torch.device,
onnx_model_path: str,
use_decoder_input_ids: bool = True,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
model (T5EncoderDecoderInit): the model to export
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
"""
assert isinstance(model, T5EncoderDecoderInit)
inputs = T5EncoderDecoderInitInputs.create_dummy(
model.config,
batch_size=2,
encode_sequence_length=3,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
output_names = ["logits", "encoder_hidden_states", *present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_attention_mask: (batch_size, encode_sequence_length)
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["encoder_input_ids", "encoder_attention_mask"]
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
sequence_length = "1"
num_heads = str(model.config.num_heads)
hidden_size = str(model.config.d_model)
head_size = str(model.config.d_kv)
dynamic_axes = {
"encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_hidden_states": {
0: "batch_size",
1: "encode_sequence_length",
2: hidden_size,
},
"logits": {
0: "batch_size",
1: sequence_length,
},
}
if use_decoder_input_ids:
input_names.append("decoder_input_ids")
dynamic_axes["decoder_input_ids"] = {
0: "batch_size",
1: sequence_length,
}
for name in present_names:
if "cross" in name:
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: "encode_sequence_length",
3: head_size,
}
else: # self attention past state
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: sequence_length,
3: head_size,
}
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
args=tuple(input_list),
f=temp_onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
# Workaround as mentioned earlier: change numeric dim_param to dim_value
model = onnx.load(temp_onnx_model_path)
for tensor in model.graph.output:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
}
if inputs.decoder_input_ids is not None:
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: T5EncoderDecoderInit,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
ort_inputs = ort_session.get_inputs()
use_decoder_input_ids = len(ort_inputs) == 3
test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
test_cases_max_diff = []
for batch_size, encode_sequence_length in test_cases[:max_cases]:
inputs = T5EncoderDecoderInitInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
# Run inference of PyTorch model
input_list = inputs.to_list()
torch_outputs = model(*input_list)
num_decoder_layers = model.config.num_decoder_layers
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
logger.debug(f"logits max_diff={max_diff}")
max_diff_all = max_diff
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
)
return max(test_cases_max_diff)
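# Usage sketch (hypothetical): `t5_model` stands for a loaded T5ForConditionalGeneration and
# `device` for a torch.device; the output path is a placeholder.
#   model = T5EncoderDecoderInit(
#       t5_model.encoder, t5_model.decoder, t5_model.lm_head, t5_model.config
#   ).eval().to(device)
#   T5EncoderDecoderInitHelper.export_onnx(model, device, "t5_encoder_decoder_init.onnx")
#   session = InferenceSession("t5_encoder_decoder_init.onnx", providers=["CPUExecutionProvider"])
#   max_diff = T5EncoderDecoderInitHelper.verify_onnx(model, session, device, use_int32_inputs=False)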

@@ -0,0 +1,272 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Dict, List, Union
import torch
from float16 import float_to_float16_max_diff
from onnx_model import OnnxModel
from optimizer import optimize_model
from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit
from t5_encoder import T5Encoder, T5EncoderHelper
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
PRETRAINED_MT5_MODELS = ["google/mt5-small", "google/mt5-base", "google/mt5-large", "google/mt5-xl", "google/mt5-xxl"]
class T5Helper:
@staticmethod
def get_onnx_path(
output_dir: str,
model_name_or_path: str,
suffix: str = "",
new_folder: bool = False,
) -> str:
"""Build onnx path
Args:
output_dir (str): output directory
model_name_or_path (str): pretrained model name, or path to the model checkpoint
suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" appended to the file name. Defaults to "".
new_folder (bool, optional): create a new directory for the model. Defaults to False.
Returns:
str: path of onnx model
"""
model_name = model_name_or_path
if os.path.isdir(model_name_or_path):
model_name = Path(model_name_or_path).parts[-1]
else:
model_name.split("/")[-1]
model_name += suffix
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
return os.path.join(directory, model_name + ".onnx")
@staticmethod
def load_model(
model_name_or_path: str,
cache_dir: str,
device: torch.device,
merge_encoder_and_decoder_init: bool = True,
model_type: str = "t5",
state_dict_path: str = "",
) -> Dict[str, torch.nn.Module]:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
Args:
model_name_or_path (str): pretrained model name or path
cache_dir (str): cache directory
device (torch.device): device to run the model
merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
model_type (str, optional): model type, "t5" or "mt5". Defaults to "t5".
state_dict_path (str, optional): path of a custom state dictionary to load. Defaults to "".
Returns:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
if model_type == "t5":
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
elif model_type == "mt5":
model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
else:
raise ValueError("only support mode_type=t5 or mt5")
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path))
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
decoder.eval().to(device)
if merge_encoder_and_decoder_init:
encoder_decoder_init = T5EncoderDecoderInit(
model.encoder,
model.decoder,
model.lm_head,
model.config,
decoder_start_token_id=None,
)
return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
else:
encoder = T5Encoder(model.encoder, model.config)
encoder.eval().to(device)
decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config)
decoder_init.eval().to(device)
return {
"encoder": encoder,
"decoder": decoder,
"decoder_init": decoder_init,
}
@staticmethod
def export_onnx(
model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_decoder_input_ids: bool = True,
use_int32_inputs: bool = False,
):
if isinstance(model, T5Encoder):
T5EncoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
elif isinstance(model, T5EncoderDecoderInit):
T5EncoderDecoderInitHelper.export_onnx(
model,
device,
onnx_model_path,
use_decoder_input_ids,
verbose,
use_external_data_format,
use_int32_inputs,
)
else:
T5DecoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
@staticmethod
def auto_mixed_precision(
onnx_model: OnnxModel,
op_block_list: List[str] = [ # noqa: B006
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"Relu",
"Add",
],
):
"""Convert model to mixed precision.
It detects whether the original model has fp16 weights, and sets parameters for float16 conversion automatically.
Args:
onnx_model (OnnxModel): optimized ONNX model
op_block_list (List[str], optional): operators to keep in float32. Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
Returns:
parameters(dict): a dictionary of parameters used in float16 conversion
"""
op_full_set = {node.op_type for node in onnx_model.nodes()}
fp32_op_set = set(op_block_list)
fp16_op_set = op_full_set.difference(fp32_op_set)
logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
# logits is the first output
logits_output_name = onnx_model.graph().output[0].name
# We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
is_weight_fp16_precision = False
output_name_to_node = onnx_model.output_name_to_node()
assert logits_output_name in output_name_to_node
node = output_name_to_node[logits_output_name]
last_matmul_node = None
if node.op_type == "MatMul":
last_matmul_node = node
logger.info(f"Found last MatMul node for logits: {node.name}")
initializer = None
for input in node.input:
initializer = onnx_model.get_initializer(input)
if initializer is not None:
break
# when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
# we can deduce that the weights are stored in float16 precision.
max_diff = float_to_float16_max_diff(initializer)
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
is_weight_fp16_precision = max_diff < 1e-6
else:
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
keep_io_types = []
node_block_list = []
if (not is_weight_fp16_precision) and (last_matmul_node is not None):
# When the original weights are float32, keeping logits and the last MatMul in float32 can give better precision.
keep_io_types = [logits_output_name]
node_block_list = [last_matmul_node.name]
parameters = {
"keep_io_types": keep_io_types,
"op_block_list": op_block_list,
"node_block_list": node_block_list,
"force_fp16_initializers": is_weight_fp16_precision,
}
logger.info(f"auto_mixed_precision parameters: {parameters}")
onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
return parameters
@staticmethod
def optimize_onnx(
onnx_model_path: str,
optimized_model_path: str,
is_float16: bool,
num_attention_heads: int,
hidden_size: int,
use_external_data_format: bool = False,
auto_mixed_precision: bool = True,
use_gpu: bool = False,
):
"""Optimize ONNX model with an option to convert it to use mixed precision."""
from fusion_options import FusionOptions
optimization_options = None
if is_float16:
optimization_options = FusionOptions("t5")
optimization_options.enable_skip_layer_norm = False
m = optimize_model(
onnx_model_path,
model_type="t5",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=2 if not use_external_data_format else 0,
optimization_options=optimization_options,
use_gpu=False,
only_onnxruntime=not use_gpu,
)
if is_float16:
if auto_mixed_precision:
T5Helper.auto_mixed_precision(m)
else:
m.convert_model_float32_to_float16(cast_input_output=False)
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
@staticmethod
def verify_onnx(
model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
if isinstance(model, T5Encoder):
return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
if isinstance(model, T5EncoderDecoderInit):
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
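# End-to-end sketch (hypothetical model name and paths) tying the helpers together, mirroring
# what export_onnx_models() in the conversion script above does:
#   models = T5Helper.load_model("t5-small", "./cache_models", device)
#   for name, module in models.items():
#       onnx_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name)
#       T5Helper.export_onnx(module, device, onnx_path, use_int32_inputs=True)
#       T5Helper.optimize_onnx(onnx_path, onnx_path.replace(".onnx", "_fp16.onnx"), True,
#                              module.config.num_heads, module.config.hidden_size, use_gpu=True)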