# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""Graph utilities for checking whether an ONNX proto message is legal."""

from __future__ import annotations

__all__ = [
    "check_attribute",
    "check_function",
    "check_graph",
    "check_model",
    "check_node",
    "check_sparse_tensor",
    "check_tensor",
    "check_value_info",
    "DEFAULT_CONTEXT",
    "LEXICAL_SCOPE_CONTEXT",
    "ValidationError",
    "C",
    "MAXIMUM_PROTOBUF",
]

import os
import sys
from typing import Any, Callable, TypeVar

from google.protobuf.message import Message

import onnx.defs
import onnx.helper  # used by check_function below
import onnx.onnx_cpp2py_export.checker as C  # noqa: N812
import onnx.shape_inference
from onnx import (
    IR_VERSION,
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
    ValueInfoProto,
)

# The size limit for a single protobuf file is 2GB.
MAXIMUM_PROTOBUF = 2000000000

# TODO: This thing where we reserialize the protobuf back into the
# string, only to deserialize it at the call site, is really goofy.
# Stop doing that.


# NB: Please don't edit this context!
DEFAULT_CONTEXT = C.CheckerContext()
DEFAULT_CONTEXT.ir_version = IR_VERSION
# TODO: Maybe ONNX-ML should also be defaulted?
DEFAULT_CONTEXT.opset_imports = {"": onnx.defs.onnx_opset_version()}

LEXICAL_SCOPE_CONTEXT = C.LexicalScopeContext()


FuncType = TypeVar("FuncType", bound=Callable[..., Any])

def _ensure_proto_type(proto: Message, proto_type: type[Message]) -> None:
    if not isinstance(proto, proto_type):
        raise TypeError(
            f"The proto message needs to be of type '{proto_type.__name__}'"
        )


def check_value_info(
    value_info: ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    _ensure_proto_type(value_info, ValueInfoProto)
    return C.check_value_info(value_info.SerializeToString(), ctx)


def check_tensor(tensor: TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    _ensure_proto_type(tensor, TensorProto)
    return C.check_tensor(tensor.SerializeToString(), ctx)


def check_attribute(
    attr: AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    _ensure_proto_type(attr, AttributeProto)
    return C.check_attribute(attr.SerializeToString(), ctx, lexical_scope_ctx)

def check_node(
    node: NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    _ensure_proto_type(node, NodeProto)
    return C.check_node(node.SerializeToString(), ctx, lexical_scope_ctx)

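
# Illustrative sketch (not part of the library): checking a single NodeProto with a
# custom checker context built the same way as DEFAULT_CONTEXT above. The Relu node
# below is hypothetical; any NodeProto works.
#
#     ctx = C.CheckerContext()
#     ctx.ir_version = IR_VERSION
#     ctx.opset_imports = {"": onnx.defs.onnx_opset_version()}
#     node = onnx.helper.make_node("Relu", inputs=["x"], outputs=["y"])
#     check_node(node, ctx)
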
def check_function(
    function: FunctionProto,
    ctx: C.CheckerContext | None = None,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    _ensure_proto_type(function, FunctionProto)
    if ctx is None:
        # No context supplied: derive one from the function's own opset imports.
        ctx = C.CheckerContext()
        ctx.ir_version = onnx.helper.find_min_ir_version_for(
            function.opset_import, ignore_unknown=True
        )
        ctx.opset_imports = {
            domain_version.domain: domain_version.version
            for domain_version in function.opset_import
        }
    C.check_function(function.SerializeToString(), ctx, lexical_scope_ctx)

def check_graph(
    graph: GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    _ensure_proto_type(graph, GraphProto)
    return C.check_graph(graph.SerializeToString(), ctx, lexical_scope_ctx)


def check_sparse_tensor(
    sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    _ensure_proto_type(sparse, SparseTensorProto)
    C.check_sparse_tensor(sparse.SerializeToString(), ctx)

def check_model(
    model: ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception is raised if the model's ir_version is not set properly, is
    higher than the checker's ir_version, or if the model has duplicate keys
    in metadata_props.

    If IR version >= 3, the model must specify opset_import.
    If IR version < 3, the model cannot have any opset_import specified.

    Args:
        model: Model to check. If model is a path, the function checks the
            model from that path. If the serialized model is larger than 2GB,
            the function should be called with the model path.
        full_check: If True, the function also runs a shape inference check.
        skip_opset_compatibility_check: If True, the function skips the opset
            compatibility check.
        check_custom_domain: If True, the function checks all domains; otherwise
            it only checks built-in domains.
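
    Example (an illustrative sketch; assumes onnx is installed and that a valid
    model file named "model.onnx", a hypothetical name, exists on disk)::

        import onnx

        model = onnx.load("model.onnx")
        onnx.checker.check_model(model)          # check an in-memory ModelProto
        onnx.checker.check_model("model.onnx")   # check by path (use this for models >2GB)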
"""
|
|
# If model is a path instead of ModelProto
|
|
if isinstance(model, (str, os.PathLike)):
|
|
C.check_model_path(
|
|
os.fspath(model),
|
|
full_check,
|
|
skip_opset_compatibility_check,
|
|
check_custom_domain,
|
|
)
|
|
else:
|
|
protobuf_string = (
|
|
model if isinstance(model, bytes) else model.SerializeToString()
|
|
)
|
|
# If the protobuf is larger than 2GB,
|
|
# remind users should use the model path to check
|
|
if sys.getsizeof(protobuf_string) > MAXIMUM_PROTOBUF:
|
|
raise ValueError(
|
|
"This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
|
|
)
|
|
C.check_model(
|
|
protobuf_string,
|
|
full_check,
|
|
skip_opset_compatibility_check,
|
|
check_custom_domain,
|
|
)
|
|
|
|
|
|
ValidationError = C.ValidationError