I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)

@@ -0,0 +1,576 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
from pathlib import Path
import onnx
import torch
from benchmark_helper import Precision
from fusion_options import AttentionOpType
from onnx_model import OnnxModel
from transformers import AutoConfig, AutoModelForCausalLM
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
class ConvertPhi2ToONNX:
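    """Export microsoft/phi-2 to ONNX and produce ONNX Runtime-optimized fp32/fp16/int4 variants
    for the selected attention operator type."""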
def __init__(
self,
device: torch.device,
model_class: str = "microsoft/phi-2",
cache_dir: str = "./cache",
):
self.model_class = model_class
self.device = device
self.cache_dir = cache_dir
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
self.phi_model = None
self.batch_size = 2
self.sequence_length = 8
self.attn_op_type = None
self.precision = None
self.block_size = 16
self.accuracy_level = None
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
self.block_size = block_size
self.accuracy_level = accuracy_level
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
self.attn_op_type = attn_op_type
self.precision = precision
def erase_onnx_model(self, onnx_path: str) -> None:
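        """Delete an ONNX file and, if the graph references one, its external-data file."""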
assert onnx_path.endswith(".onnx")
if not os.path.exists(onnx_path):
return
model = onnx.load_model(onnx_path, load_external_data=False)
onnx_data_path = None
for initializer in model.graph.initializer:
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
onnx_data_path = "./" + initializer.external_data[0].value
break
logging.info(f"Erasing {onnx_path}...")
os.remove(onnx_path)
if onnx_data_path is not None:
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
logging.info(f"Erasing {onnx_data_path}...")
os.remove(onnx_data_path)
def get_phi2_torch_model(self):
logging.info("Loading phi2 torch model...")
if self.phi_model is not None:
return
self.phi_model = AutoModelForCausalLM.from_pretrained(
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
)
self.phi_model.eval()
self.phi_model.to(self.device)
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
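        """Build random input_ids and run one torch forward pass to obtain the attention_mask
        and past_key_values needed for export."""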
input_ids = torch.randint(
low=0,
high=self.phi_config.vocab_size,
size=(batch_size, sequence_length),
dtype=torch.int64,
device=self.device,
)
self.get_phi2_torch_model()
torch_inputs = self.phi_model.prepare_inputs_for_generation(
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
)
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
def dynamo_export(self, onnx_path: str):
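        """Export the torch model with torch.onnx.dynamo_export (dynamic shapes enabled), then run
        the ONNX checker and shape inference on the saved file."""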
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
from torch._dynamo import config
config.capture_scalar_outputs = True
logging.info("Exporting Phi2 torch model to ONNX...")
torch.onnx.dynamo_export(
self.phi_model,
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
).save(onnx_path)
onnx.checker.check_model(onnx_path)
onnx.shape_inference.infer_shapes_path(onnx_path)
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
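        """Run the phi fusion optimizer, then convert to fp16/bf16 and/or quantize MatMuls to 4 bits
        according to self.precision and self.attn_op_type."""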
from fusion_options import FusionOptions
from optimizer import optimize_model
optimization_options = FusionOptions("phi")
optimization_options.set_attention_op_type(self.attn_op_type)
optimizer = optimize_model(
onnx_path,
model_type="phi",
num_heads=self.phi_config.num_attention_heads,
hidden_size=self.phi_config.hidden_size,
opt_level=0,
optimization_options=optimization_options,
only_onnxruntime=False,
)
fused_op_count = optimizer.get_fused_operator_statistics()
if optimizer.is_fully_optimized(fused_op_count):
logging.info("Model is fully optimized.")
else:
logging.info("Model is not fully optimized.")
if self.precision == Precision.FLOAT32:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
if (
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
            # Keep the last three Attention layers in float32 or bfloat16 to avoid overflow.
node_block_list = (
[
"Attention_29",
"Attention_30",
"Attention_31",
]
if self.attn_op_type != AttentionOpType.PagedAttention
else []
) # TODO: temp setting for paged attention
logging.info("Converting onnx model to float16/bfloat16...")
optimizer.convert_float_to_float16(
keep_io_types=False,
node_block_list=node_block_list,
use_symbolic_shape_infer=True,
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
)
logging.info("Converting onnx model to float16/bfloat16 done.")
if self.precision == Precision.FLOAT16:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
else:
assert self.precision == Precision.INT4
quant = MatMul4BitsQuantizer(
model=optimizer.model,
block_size=self.block_size,
is_symmetric=True,
accuracy_level=self.accuracy_level,
)
quant.process()
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
    # This function currently only works for the phi-2 model.
def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str):
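        """Replace the attention_mask-derived subgraphs with explicit seqlens_k and
        total_sequence_length graph inputs so GroupQueryAttention can run under CUDA graph capture."""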
onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True))
from onnx import TensorProto, helper
graph = onnx_model.graph()
new_inputs = []
for vi in graph.input:
if "attention_mask" in vi.name:
vi_seqlen_k = helper.make_tensor_value_info(
"seqlens_k",
elem_type=TensorProto.INT32,
shape=["batch_size"],
)
vi_total_seq_len = helper.make_tensor_value_info(
"total_sequence_length",
elem_type=TensorProto.INT32,
shape=[1],
)
new_inputs.extend([vi_seqlen_k, vi_total_seq_len])
else:
new_inputs.append(vi)
graph.ClearField("input")
graph.input.extend(new_inputs)
gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention")
gqa = gqas[0]
seqlens_path = onnx_model.match_parent_path(
gqa,
["Cast", "Sub", "ReduceSum", "Cast"],
[5, 0, 0, 0],
)
if seqlens_path is None:
raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.")
total_seq_len_path = onnx_model.match_parent_path(
gqa,
["Cast", "Gather", "Shape"],
[6, 0, 0],
)
if total_seq_len_path is None:
raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.")
onnx_model.remove_nodes(seqlens_path)
onnx_model.remove_nodes(total_seq_len_path)
for gqa in gqas:
gqa.input[5] = "seqlens_k"
gqa.input[6] = "total_sequence_length"
onnx_model.save(onnx_model.model, out_onnx_path, save_as_external_data=True)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--fp32_cpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for CPU",
)
parser.add_argument(
"--int4_cpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for CPU",
)
parser.add_argument(
"--fp32_gpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--int4_gpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu_sm8x",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--int4_gpu_sm8x",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--fp16_vllm",
required=False,
action="store_true",
help="Generate fp16 ONNX model for ORT VLLM",
)
parser.add_argument(
"--int4_vllm",
required=False,
action="store_true",
help="Generate int4 ONNX model for ORT VLLM",
)
parser.add_argument(
"--use_cuda_graph",
required=False,
action="store_true",
help="Use CUDA Graph in decoding process",
)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX models",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./cache",
help="The cache directory for the pytorch model",
)
parser.add_argument(
"--device_id",
required=False,
type=int,
default=0,
help="The device id for the pytorch model",
)
parser.add_argument(
"--run_example",
required=False,
action="store_true",
help="Run ORT inference example",
)
parser.add_argument(
"--run_benchmark",
required=False,
action="store_true",
help="Run ORT benchmark",
)
parser.add_argument(
"--skip_export",
required=False,
action="store_true",
help="Skip exporting ONNX model",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="phi2_onnx_models",
)
parser.add_argument(
"--block_size",
required=False,
default=16,
type=int,
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
)
parser.add_argument(
"--int4_accuracy_level",
required=False,
type=int,
help="Accuracy level of the 4-bit quantized MatMul computation. "
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
if not args.skip_export:
if not os.path.exists(original_onnx_path) or args.overwrite:
converter.dynamo_export(original_onnx_path)
model_type_to_args = {
"fp32_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
),
"int4_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
),
"fp32_gpu": (
AttentionOpType.Attention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
),
"fp16_gpu": (
AttentionOpType.Attention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
),
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
"fp16_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
),
"int4_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
),
"fp16_vllm": (
AttentionOpType.PagedAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
),
"int4_vllm": (
AttentionOpType.PagedAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
),
}
if not args.skip_export:
from multiprocessing import Process
def run_optimize_phi2_onnx(
converter: ConvertPhi2ToONNX,
original_onnx_path: str,
attention_type: AttentionOpType,
precision: Precision,
optimized_onnx_path: str,
):
converter.init_attn_type_and_precision(attention_type, precision)
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
if args.use_cuda_graph:
assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x
converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path)
processes = []
if args.fp32_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
)
)
if args.int4_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
)
)
if args.fp32_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
)
)
if args.fp16_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
)
)
if args.int4_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
)
)
if args.fp16_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
)
)
if args.int4_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
)
)
if args.fp16_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
)
)
if args.int4_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
)
)
[p.start() for p in processes]
[p.join() for p in processes]
if args.run_example or args.run_benchmark:
from inference_example import run_phi2
if args.fp16_gpu_sm8x:
logging.info("Running fp16_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu_sm8x:
logging.info("Running int4_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.fp32_gpu:
logging.info("Running fp32_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp32_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
use_fp16=False,
run_benchmark=args.run_benchmark,
)
if args.fp16_gpu:
logging.info("Running fp16_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu:
logging.info("Running int4_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
if __name__ == "__main__":
main()
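# Example invocation (script name assumed; flags as defined in parse_arguments):
#   python convert_to_onnx.py --fp16_gpu_sm8x --run_example --output_dir phi2_onnx_models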

@@ -0,0 +1,414 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import numpy as np
import torch
from transformers import AutoTokenizer
import onnxruntime as ort
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float32": np.float32,
"torch.float16": np.float16,
}
def cuda_memcpy(dst, src):
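    # Device-to-device copy via cuda-python, so pre-allocated buffers are updated in place
    # without changing the addresses captured by a CUDA graph.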
from cuda import cudart
cudart.cudaMemcpy(
dst.data_ptr(),
src.data_ptr(),
src.element_size() * src.nelement(),
cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
)
class ORTGenerator:
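    """Greedy-search generation for the exported phi-2 decoder with ONNX Runtime I/O binding,
    optional KV-cache buffer sharing, and optional CUDA graph capture."""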
def __init__(self, decoder_path):
self.onnx_decoder_path = decoder_path
self.num_heads = 32
self.head_size = 80
self.num_layers = 32
self.max_sequence_length = 2048
self.device_id = 0
self.use_cuda_graph = False
self.use_traced_inputs = False
self.static_inputs_map = {}
def append_static_inputs(self, batch_size):
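        """Pre-allocate fixed buffers (input_ids, step, seqlens_k, per-layer KV caches, and logits
        on CUDA; total_sequence_length on CPU) for this batch size so CUDA graph replays reuse the
        same addresses."""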
# Only use this function with GQA and with use_cuda_graph=True
if batch_size in self.static_inputs_map:
return
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda", self.device_id)
static_io = {}
static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
for i in range(self.num_layers):
cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
self.static_inputs_map[batch_size] = static_io
def get_initial_inputs_and_outputs(self, encodings_dict):
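        """Build the prompt-run inputs and outputs (input_ids, mask or seqlens, per-layer KV caches,
        logits), using the pre-allocated static buffers when traced inputs are enabled."""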
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
batch_size, sequence_length = input_ids.shape
self.use_traced_inputs = (
self.use_cuda_graph
and (batch_size in self.static_inputs_map)
and self.use_buffer_share
and not self.packed_kv
)
step = (
torch.tensor([0], device=self.device, dtype=torch.int64)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["step"]
)
seqlens_k = (
torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["seqlens_k"]
)
cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
total_seq_length = (
torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["total_sequence_length"]
)
total_seq_length[0] = sequence_length
inputs = {
"input_ids": input_ids.contiguous(),
"attention_mask": attention_mask.contiguous(),
}
if self.use_step:
inputs["step"] = step.contiguous()
if self.use_cuda_graph:
inputs["seqlens_k"] = seqlens_k.contiguous()
inputs["total_sequence_length"] = total_seq_length.contiguous()
del inputs["attention_mask"]
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
past_shape = (
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, past_seq_length, self.head_size)
)
if not self.use_traced_inputs:
for i in range(self.num_layers):
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
(
inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
if not self.packed_kv
else inputs.update({f"past_{i}": past.contiguous()})
)
else:
for i in range(self.num_layers):
inputs.update(
{
f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
}
)
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
outputs = {"logits": logits.contiguous()}
if not self.use_buffer_share:
present_shape = (
(2, batch_size, self.num_heads, sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
return inputs, outputs
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
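        """Bind every input and output tensor to its existing torch buffer; with buffer sharing,
        present_* outputs are bound to the matching past_* input buffers."""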
io_binding = model.io_binding()
device = None
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
if self.use_buffer_share and "present" in name:
v = inputs[name.replace("present", "past")]
io_binding.bind_output(
name=name,
device_type=v.device.type,
device_id=v.device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
def create_session(
self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
):
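        """Create the ONNX Runtime session (CUDA EP with optional CUDA graph when device_id >= 0,
        otherwise CPU EP) and load the phi-2 tokenizer."""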
self.device_id = device_id
sess_options = ort.SessionOptions()
sess_options.log_verbosity_level = 4
sess_options.log_severity_level = 4
self.use_cuda_graph = use_cuda_graph
ep = (
("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
if self.device_id >= 0
else "CPUExecutionProvider"
)
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
self.ro = ort.RunOptions()
self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
self.use_fp16 = use_fp16
self.use_buffer_share = use_buffer_share
self.packed_kv = packed_kv
self.use_step = use_step
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
self.tokenizer.pad_token = "[PAD]"
def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
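        """Greedy decoding loop: a prompt run with CUDA graph disabled, then token-by-token runs with
        CUDA graph enabled when traced inputs are used, updating the bound inputs in place."""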
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
all_token_ids = inputs["input_ids"].clone()
batch_size, sequence_length = all_token_ids.shape
current_length = sequence_length
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
if benchmark:
import time
latency = []
prompt_run = True
while current_length < max_length:
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
if benchmark:
start = time.time()
io_binding.synchronize_inputs()
if prompt_run:
if self.use_cuda_graph:
# Disable CUDA graph for the prompt run
self.ro.add_run_config_entry("gpu_graph_id", "-1")
self.sess.run_with_iobinding(io_binding, self.ro)
if self.use_cuda_graph:
# Enable CUDA graph for the decoding run
self.ro.add_run_config_entry(
"gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
)
prompt_run = False
else:
self.sess.run_with_iobinding(io_binding, self.ro)
io_binding.synchronize_outputs()
if benchmark:
end = time.time()
latency.append(end - start)
# Sample with argmax (greedy search)
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
            has_eos = has_eos | (next_tokens == self.tokenizer.eos_token_id)
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
# Return early if all batch entries have reached EOS token id
if torch.all(has_eos):
break
# Update inputs for next inference run
current_length += 1
inputs["input_ids"] = tokens_to_add.to(torch.int32)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
if self.use_step:
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
inputs["step"] = self.static_inputs_map[batch_size]["step"]
if self.use_cuda_graph:
previous_seqlens_k = inputs["seqlens_k"]
inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
inputs["total_sequence_length"][0] = current_length
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
else:
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
).to(torch.int32)
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
if self.use_traced_inputs:
outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
outputs["logits"].zero_()
if not self.use_buffer_share:
for i in range(self.num_layers):
if not self.packed_kv:
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
else:
inputs[f"past_{i}"] = outputs[f"present_{i}"]
new_sequence_length = inputs["attention_mask"].shape[1]
present_shape = (
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{
f"present_key_{i}": present.contiguous(),
f"present_value_{i}": present.clone().contiguous(),
}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
if benchmark:
print(
f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
)
print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
return
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
def generate(self, prompt, max_length, cuda_graph_annotation):
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
batch_size, sequence_length = prompt_shape
max_length = sequence_length + token_num
encodings_dict = {}
encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
# Warm up run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
# Benchmark run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
def run_phi2(
onnx_model_path,
use_buffer_share,
device_id,
packed_kv=False,
use_fp16=True,
use_step=False,
use_cuda_graph=False,
run_benchmark=False,
):
generator = ORTGenerator(onnx_model_path)
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
def simple_run(prompt):
example_batch_size = len(prompt)
if use_cuda_graph:
generator.append_static_inputs(batch_size=example_batch_size)
texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
for i in range(len(texts)):
print("Prompt: ", prompt[i])
print("Texts: ", texts[i])
prompt = [
'''```python
def print_prime(n):
"""
Print all primes between 1 and n
"""'''
]
if not run_benchmark:
simple_run(prompt)
# Run simple benchmark. Time the decoder only.
if run_benchmark:
token_num = 32
for batch_size in [1, 2, 4, 8]:
generator.append_static_inputs(batch_size)
for sequence_length in [16, 512]:
prompt_shape = (batch_size, sequence_length)
generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
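# Example (model path assumed; mirrors how the conversion script above calls run_phi2):
#   run_phi2(
#       onnx_model_path="phi2_onnx_models/phi2_decoder_fp16_gpu_sm8x.onnx",
#       use_buffer_share=True,
#       device_id=0,
#       use_step=True,
#   )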