172 lines
5.7 KiB
Python
172 lines
5.7 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""onnx shape inference. Shape inference is not guaranteed to be
|
|
complete.
|
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Sequence
|
|
|
|
import onnx
|
|
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
|
|
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto
|
|
|
|
|
|
def infer_shapes(
|
|
model: ModelProto | bytes,
|
|
check_type: bool = False,
|
|
strict_mode: bool = False,
|
|
data_prop: bool = False,
|
|
) -> ModelProto:
|
|
"""Apply shape inference to the provided ModelProto.
|
|
|
|
Inferred shapes are added to the value_info field of the graph.
|
|
|
|
If the inferred values conflict with values already provided in the
|
|
graph, that means that the provided values are invalid (or there is a
|
|
bug in shape inference), and the result is unspecified.
|
|
|
|
Arguments:
|
|
model: ModelProto.
|
|
check_type: Checks the type-equality for input and output.
|
|
strict_mode: Stricter shape inference, it will throw errors if any;
|
|
Otherwise, simply stop if any error.
|
|
data_prop: Enables data propagation for limited operators to perform shape computation.
|
|
|
|
Returns:
|
|
(ModelProto) model with inferred shape information
|
|
"""
|
|
if isinstance(model, (ModelProto, bytes)):
|
|
model_str = model if isinstance(model, bytes) else model.SerializeToString()
|
|
inferred_model_str = C.infer_shapes(
|
|
model_str, check_type, strict_mode, data_prop
|
|
)
|
|
return onnx.load_from_string(inferred_model_str)
|
|
if isinstance(model, str):
|
|
raise TypeError(
|
|
"infer_shapes only accepts ModelProto or bytes,"
|
|
"you can use infer_shapes_path for the model path (String)."
|
|
)
|
|
|
|
raise TypeError(
|
|
f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
|
|
)
|
|
|
|
|
|
def infer_shapes_path(
|
|
model_path: str | os.PathLike,
|
|
output_path: str | os.PathLike = "",
|
|
check_type: bool = False,
|
|
strict_mode: bool = False,
|
|
data_prop: bool = False,
|
|
) -> None:
|
|
"""Take model path for shape_inference.
|
|
|
|
This function is the same as :func:`infer_shape` but supports >2GB models.
|
|
The function outputs the inferred model to the `output_path`. The original model path
|
|
is used if not specified.
|
|
"""
|
|
if isinstance(model_path, ModelProto):
|
|
raise TypeError(
|
|
"infer_shapes_path only accepts model Path (String),"
|
|
"you can use infer_shapes for the ModelProto."
|
|
)
|
|
try:
|
|
model_path = os.fspath(model_path)
|
|
except TypeError as exp:
|
|
raise TypeError(
|
|
"infer_shapes_path only accepts model path as a string or PathLike, "
|
|
f"incorrect model path type: {type(model_path)}"
|
|
) from exp
|
|
try:
|
|
output_path = os.fspath(output_path)
|
|
except TypeError as exp:
|
|
raise TypeError(
|
|
"infer_shapes_path only accepts output path as a string or PathLike, "
|
|
f"incorrect output path type: {type(output_path)}"
|
|
) from exp
|
|
|
|
if output_path == "":
|
|
output_path = model_path
|
|
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
|
|
|
|
|
|
def infer_node_outputs(
|
|
schema: onnx.defs.OpSchema,
|
|
node: onnx.NodeProto,
|
|
input_types: dict[str, onnx.TypeProto],
|
|
input_data: dict[str, onnx.TensorProto] | None = None,
|
|
input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
|
|
opset_imports: list[onnx.OperatorSetIdProto] | None = None,
|
|
ir_version: int = onnx.IR_VERSION,
|
|
) -> dict[str, onnx.TypeProto]:
|
|
if not schema.has_type_and_shape_inference_function: # type: ignore
|
|
return {}
|
|
if input_data is None:
|
|
input_data = {}
|
|
if input_sparse_data is None:
|
|
input_sparse_data = {}
|
|
if opset_imports is None:
|
|
passed_opset_imports = {}
|
|
else:
|
|
passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}
|
|
|
|
# catch KeyError if node's input does not exist in input_types
|
|
passed_input_types = {
|
|
key: input_types[key].SerializeToString() for key in node.input if key != ""
|
|
}
|
|
# input_types will also be used as outer_scope_value_types so do not filter by node's input here
|
|
for key in input_types:
|
|
if key not in passed_input_types:
|
|
passed_input_types[key] = input_types[key].SerializeToString()
|
|
passed_input_data = {
|
|
key: input_data[key].SerializeToString()
|
|
for key in node.input
|
|
if key in input_data
|
|
}
|
|
passed_sparse_input_data = {
|
|
key: input_sparse_data[key].SerializeToString()
|
|
for key in node.input
|
|
if key in input_sparse_data
|
|
}
|
|
|
|
outputs = schema._infer_node_outputs(
|
|
node.SerializeToString(),
|
|
passed_input_types,
|
|
passed_input_data,
|
|
passed_sparse_input_data,
|
|
passed_opset_imports,
|
|
ir_version,
|
|
) # type: ignore[call-arg]
|
|
return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}
|
|
|
|
|
|
def infer_function_output_types(
|
|
function: FunctionProto,
|
|
input_types: Sequence[TypeProto],
|
|
attributes: Sequence[AttributeProto],
|
|
) -> list[TypeProto]:
|
|
"""Apply type-and-shape-inference to given function body, with given input types
|
|
and given input attribute values.
|
|
"""
|
|
result = C.infer_function_output_types(
|
|
function.SerializeToString(),
|
|
[x.SerializeToString() for x in input_types],
|
|
[x.SerializeToString() for x in attributes],
|
|
)
|
|
|
|
def to_type_proto(x) -> TypeProto:
|
|
type_proto = onnx.TypeProto()
|
|
type_proto.ParseFromString(x)
|
|
return type_proto
|
|
|
|
return [to_type_proto(x) for x in result]
|
|
|
|
|
|
InferenceError = C.InferenceError
|