357 lines
12 KiB
Python
357 lines
12 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
# Names exported by ``from onnx import *``. Kept as a literal list, grouped by
# the kind of object each name refers to.
__all__ = [
    # -- Constants --
    "ONNX_ML",
    "IR_VERSION",
    "IR_VERSION_2017_10_10",
    "IR_VERSION_2017_10_30",
    "IR_VERSION_2017_11_3",
    "IR_VERSION_2019_1_22",
    "IR_VERSION_2019_3_18",
    "IR_VERSION_2019_9_19",
    "IR_VERSION_2020_5_8",
    "IR_VERSION_2021_7_30",
    "IR_VERSION_2023_5_5",
    "EXPERIMENTAL",
    "STABLE",
    # -- Submodules --
    "checker",
    "compose",
    "defs",
    "gen_proto",
    "helper",
    "hub",
    "mapping",
    "numpy_helper",
    "parser",
    "printer",
    "shape_inference",
    "utils",
    "version_converter",
    # -- Protobuf message classes and enums --
    "AttributeProto",
    "FunctionProto",
    "GraphProto",
    "MapProto",
    "ModelProto",
    "NodeProto",
    "OperatorProto",
    "OperatorSetIdProto",
    "OperatorSetProto",
    "OperatorStatus",
    "OptionalProto",
    "SequenceProto",
    "SparseTensorProto",
    "StringStringEntryProto",
    "TensorAnnotation",
    "TensorProto",
    "TensorShapeProto",
    "TrainingInfoProto",
    "TypeProto",
    "ValueInfoProto",
    "Version",
    # -- Load / save helpers --
    "convert_model_to_external_data",
    "load_external_data_for_model",
    "load_model_from_string",
    "load_model",
    "load_tensor_from_string",
    "load_tensor",
    "save_model",
    "save_tensor",
    "write_external_data_tensors",
]
|
|
# isort:skip_file
|
|
|
|
import os
|
|
import typing
|
|
from typing import IO, Literal, Union
|
|
|
|
|
|
from onnx import serialization
|
|
from onnx.onnx_cpp2py_export import ONNX_ML
|
|
from onnx.external_data_helper import (
|
|
load_external_data_for_model,
|
|
write_external_data_tensors,
|
|
convert_model_to_external_data,
|
|
)
|
|
from onnx.onnx_pb import (
|
|
AttributeProto,
|
|
EXPERIMENTAL,
|
|
FunctionProto,
|
|
GraphProto,
|
|
IR_VERSION,
|
|
IR_VERSION_2017_10_10,
|
|
IR_VERSION_2017_10_30,
|
|
IR_VERSION_2017_11_3,
|
|
IR_VERSION_2019_1_22,
|
|
IR_VERSION_2019_3_18,
|
|
IR_VERSION_2019_9_19,
|
|
IR_VERSION_2020_5_8,
|
|
IR_VERSION_2021_7_30,
|
|
IR_VERSION_2023_5_5,
|
|
ModelProto,
|
|
NodeProto,
|
|
OperatorSetIdProto,
|
|
OperatorStatus,
|
|
STABLE,
|
|
SparseTensorProto,
|
|
StringStringEntryProto,
|
|
TensorAnnotation,
|
|
TensorProto,
|
|
TensorShapeProto,
|
|
TrainingInfoProto,
|
|
TypeProto,
|
|
ValueInfoProto,
|
|
Version,
|
|
)
|
|
from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto
|
|
from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
|
|
from onnx.version import version as __version__
|
|
|
|
# Import common subpackages so they're available when you 'import onnx'
|
|
from onnx import (
|
|
checker,
|
|
compose,
|
|
defs,
|
|
gen_proto,
|
|
helper,
|
|
hub,
|
|
mapping,
|
|
numpy_helper,
|
|
parser,
|
|
printer,
|
|
shape_inference,
|
|
utils,
|
|
version_converter,
|
|
)
|
|
|
|
# Type of a serialization format name accepted by the load/save helpers.
# The four literals are the formats with built-in support; plain ``str`` is
# accepted as well because users may register their own formats.
_SupportedFormat = Union[
    Literal["protobuf", "textproto", "onnxtxt", "json"], str  # noqa: PYI051
]
# Format used when the caller gives none and none can be inferred.
_DEFAULT_FORMAT = "protobuf"
|
|
|
|
|
|
def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes:
    """Return everything readable from ``f`` as bytes.

    Args:
        f: Either an object exposing a callable ``read`` attribute, or a
            ``str``/``os.PathLike`` naming a file to open in binary mode.

    Returns:
        The result of ``f.read()`` for file-like objects, otherwise the full
        contents of the file at path ``f``.
    """
    reader = getattr(f, "read", None)
    if callable(reader):
        # File-like object: delegate to its own ``read``.
        return typing.cast(bytes, reader())

    # Anything else is treated as a filesystem path.
    path = typing.cast(Union[str, os.PathLike], f)
    with open(path, "rb") as stream:
        return stream.read()
|
|
|
|
|
|
def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None:
    """Write ``content`` to ``f``.

    Args:
        content: The bytes to write.
        f: Either an object exposing a callable ``write`` attribute, or a
            ``str``/``os.PathLike`` naming a file to open in binary write mode.
    """
    writer = getattr(f, "write", None)
    if callable(writer):
        # File-like object: delegate to its own ``write``.
        writer(content)
        return

    # Anything else is treated as a filesystem path.
    path = typing.cast(Union[str, os.PathLike], f)
    with open(path, "wb") as sink:
        sink.write(content)
|
|
|
|
|
|
def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None:
    """Return the absolute filesystem path associated with ``f``, if any.

    Args:
        f: A path (``str`` or ``os.PathLike``), a file-like object, or ``None``.

    Returns:
        The absolute path when ``f`` is itself a path, or when it exposes a
        ``name`` attribute that is a ``str``, ``bytes`` or ``os.PathLike``;
        otherwise ``None``.
    """
    if isinstance(f, (str, os.PathLike)):
        return os.path.abspath(f)

    # A ``name`` attribute is not always a path: file objects opened from a
    # descriptor report an ``int`` and some in-memory/spooled streams report
    # ``None``. Passing those to ``os.path.abspath`` raises ``TypeError``, so
    # only path-like names are accepted.
    name = getattr(f, "name", None)
    if isinstance(name, (str, bytes, os.PathLike)):
        # ``os.fsdecode`` keeps the declared ``str`` return type even when the
        # file was opened with a bytes path.
        return os.path.abspath(os.fsdecode(name))
    return None
|
|
|
|
|
|
def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer for a format in the serialization registry.

    Args:
        fmt: Explicit format name. When not ``None`` it is used as-is and
            ``f`` is ignored.
        f: Optional path or file-like object whose file extension is used to
            pick a format when ``fmt`` is ``None``.

    Returns:
        The serializer registered for the resolved format. When neither
        ``fmt`` nor the extension of ``f`` yields a format, the default
        format ("protobuf") is used.
    """
    registry = serialization.registry

    # An explicit format always wins.
    if fmt is not None:
        return registry.get(fmt)

    inferred = None
    file_path = _get_file_path(f)
    if file_path is not None:
        extension = os.path.splitext(file_path)[1]
        inferred = registry.get_format_from_file_extension(extension)

    # Nothing could be inferred (or no path was available): use the default.
    resolved = inferred or _DEFAULT_FORMAT
    assert resolved is not None

    return registry.get(resolved)
|
|
|
|
|
|
def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Read a serialized ModelProto and return it as an in-memory object.

    Args:
        f: A file-like object (one with a ``read`` function) or a
            string/PathLike holding a file name.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if it is omitted _and_ ``f``
            is not a path, 'protobuf' is used. Text formats are assumed to be
            encoded as "utf-8".
        load_external_data: Whether to also load external data. Leave it True
            when the data sits in the same directory as the model. Otherwise
            call :func:`load_external_data_for_model` with the directory that
            holds the external data.

    Returns:
        The loaded in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    raw = _load_bytes(f)
    model = serializer.deserialize_proto(raw, ModelProto())

    if not load_external_data:
        return model

    # External data is resolved relative to the directory of the model file,
    # which is only known when a path can be derived from ``f``.
    model_filepath = _get_file_path(f)
    if model_filepath:
        load_external_data_for_model(model, os.path.dirname(model_filepath))

    return model
|
|
|
|
|
|
def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Read a serialized TensorProto and return it as an in-memory object.

    Args:
        f: A file-like object (one with a ``read`` function) or a
            string/PathLike holding a file name.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if it is omitted _and_ ``f``
            is not a path, 'protobuf' is used. Text formats are assumed to be
            encoded as "utf-8".

    Returns:
        The loaded in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    raw = _load_bytes(f)
    return serializer.deserialize_proto(raw, TensorProto())
|
|
|
|
|
|
def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Deserialize a ModelProto from an in-memory serialized value.

    Args:
        s: The serialized ModelProto (``bytes`` or ``str``).
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file involved here, so nothing is inferred from a
            file extension.

    Returns:
        The loaded in-memory ModelProto.
    """
    serializer = _get_serializer(format)
    return serializer.deserialize_proto(s, ModelProto())
|
|
|
|
|
|
def load_tensor_from_string(
    s: bytes,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Deserialize a TensorProto from an in-memory serialized value.

    Args:
        s: The serialized TensorProto.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file involved here, so nothing is inferred from a
            file extension.

    Returns:
        The loaded in-memory TensorProto.
    """
    serializer = _get_serializer(format)
    return serializer.deserialize_proto(s, TensorProto())
|
|
|
|
|
|
def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Save a ModelProto to ``f``, optionally moving raw tensor data to external files first.

    Args:
        proto: An in-memory ModelProto. ``bytes`` are also accepted and are
            first deserialized as a ModelProto using the default format.
        f: A file-like object (one with a ``write`` function), a string
            holding a file name, or a pathlike object.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if it is omitted _and_ ``f``
            is not a path, 'protobuf' is used. Text formats are assumed to be
            encoded as "utf-8".
        save_as_external_data: If true, save tensors to external file(s).
        all_tensors_to_one_file: Only effective when save_as_external_data is
            True. If true, all tensors go to the single external file given by
            ``location``. If false, each tensor goes to a file named after the
            tensor.
        location: Only effective when save_as_external_data is True. The
            external file that all tensors are saved to, relative to the model
            path. If not specified, the model name is used.
        size_threshold: Only effective when save_as_external_data is True.
            Size threshold for tensor data: only tensors whose data is >= the
            threshold are converted to external data. Use size_threshold=0 to
            convert every tensor that has raw data.
        convert_attribute: Only effective when save_as_external_data is True.
            If true, convert all tensors to external data. If false, convert
            only non-attribute tensors.
    """
    if isinstance(proto, bytes):
        default_serializer = _get_serializer(_DEFAULT_FORMAT)
        proto = default_serializer.deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
        )

    # ``write_external_data_tensors`` needs a base directory, which is only
    # available when a path can be derived from ``f``.
    model_filepath = _get_file_path(f)
    if model_filepath is not None:
        base_dir = os.path.dirname(model_filepath)
        proto = write_external_data_tensors(proto, base_dir)

    serializer = _get_serializer(format, model_filepath)
    _save_bytes(serializer.serialize_proto(proto), f)
|
|
|
|
|
|
def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Save a TensorProto to ``f``.

    Args:
        proto: An in-memory TensorProto.
        f: A file-like object (one with a ``write`` function), a string
            holding a file name, or a pathlike object.
        format: The serialization format. If omitted, it is inferred from the
            file extension when ``f`` is a path; if it is omitted _and_ ``f``
            is not a path, 'protobuf' is used. Text formats are assumed to be
            encoded as "utf-8".
    """
    serializer = _get_serializer(format, f)
    _save_bytes(serializer.serialize_proto(proto), f)
|
|
|
|
|
|
# Short aliases for the model load/save helpers, kept for backward
# compatibility.
load = load_model
load_from_string = load_model_from_string
save = save_model
|