210 lines
7.7 KiB
Python
210 lines
7.7 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
import warnings
|
|
|
|
# Only the shared, pre-populated serializer registry is part of the public API.
__all__ = ["registry"]
|
|
|
|
import typing
|
|
from typing import Any, Collection, Optional, Protocol, TypeVar
|
|
|
|
import google.protobuf.json_format
|
|
import google.protobuf.message
|
|
import google.protobuf.text_format
|
|
|
|
import onnx
|
|
|
|
# Codec used by the text-based serializers (textproto, JSON, ONNX textual
# syntax) when converting between ``str`` and ``bytes``.
_ENCODING = "utf-8"

# Any protobuf message type; lets serializer methods be typed as returning the
# same concrete message type they were handed.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)
|
|
|
|
|
|
class ProtoSerializer(Protocol):
    """Structural interface for converting in-memory protobuf messages to and from a serialized form.

    Implementations are registered with ``_Registry`` and looked up either by
    their ``supported_format`` name or via one of their ``file_extensions``.
    """

    # Name of the format this serializer handles, e.g. "protobuf".
    supported_format: str
    # File extensions mapped to this format, e.g. frozenset({".onnx", ".pb"}).
    # Each entry must include the leading dot.
    file_extensions: Collection[str]

    # NOTE: The methods are named serialize_proto/deserialize_proto instead of the
    # more generic serialize/deserialize. This leaves room for future protocols that
    # serialize/deserialize the ONNX in-memory IR, so one class can implement both.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Convert an in-memory proto into its serialized representation."""
        ...

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse serialized data into an in-memory proto and return it."""
        ...
|
|
|
|
|
|
class _Registry:
    """Lookup table of ``ProtoSerializer`` objects keyed by format name.

    Besides the format -> serializer mapping, it keeps a file-extension ->
    format mapping so that a format can be inferred from a file name.
    """

    def __init__(self) -> None:
        # Format name (e.g. "protobuf") -> serializer handling that format.
        self._serializers: dict[str, ProtoSerializer] = {}
        # File extension including the dot (e.g. ".onnx") -> format name.
        self._extension_to_format: dict[str, str] = {}

    def register(self, serializer: ProtoSerializer) -> None:
        """Register ``serializer`` for its format and each of its file extensions.

        A later registration for the same format or the same extension
        overrides the earlier one.

        Args:
            serializer: The serializer to add.
        """
        fmt = serializer.supported_format
        self._serializers[fmt] = serializer
        self._extension_to_format.update(
            {ext: fmt for ext in serializer.file_extensions}
        )

    def get(self, fmt: str) -> ProtoSerializer:
        """Get a serializer for a format.

        Args:
            fmt: The format to get a serializer for.

        Returns:
            ProtoSerializer: The serializer for the format.

        Raises:
            ValueError: If the format is not supported. The message lists the
                registered format names.
        """
        try:
            return self._serializers[fmt]
        except KeyError:
            # list(...) renders "['protobuf', ...]" instead of leaking the
            # "dict_keys([...])" repr into a user-facing message.
            raise ValueError(
                f"Unsupported format: '{fmt}'. "
                f"Supported formats are: {list(self._serializers)}"
            ) from None

    def get_format_from_file_extension(self, file_extension: str) -> str | None:
        """Get the corresponding format from a file extension.

        The lookup is an exact dictionary match, so the extension must include
        the leading dot and match the registered spelling.

        Args:
            file_extension: The file extension to get a format for.

        Returns:
            The format for the file extension, or None if not found.
        """
        return self._extension_to_format.get(file_extension)
|
|
|
|
|
|
class _ProtobufSerializer(ProtoSerializer):
    """Binary protobuf wire format, used for ``.onnx`` and ``.pb`` files."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Encode ``proto`` into protobuf wire-format bytes.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If encoding fails. When the message's ``ByteSize()`` is
                at least ``onnx.checker.MAXIMUM_PROTOBUF``, the error is replaced
                by one pointing to ``save_as_external_data``; otherwise the
                original error propagates.
        """
        to_bytes = getattr(proto, "SerializeToString", None)
        if not callable(to_bytes):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            encoded = to_bytes()
        except ValueError as err:
            # Only messages at or above the size cap get the actionable message;
            # anything else is re-raised untouched.
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from err
        return encoded  # type: ignore

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse wire-format ``serialized`` into ``proto`` and return ``proto``.

        Raises:
            TypeError: If ``serialized`` is not ``bytes``.
            google.protobuf.message.DecodeError: If the parser reports consuming
                a byte count different from ``len(serialized)``.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        total = len(serialized)
        consumed = typing.cast(Optional[int], proto.ParseFromString(serialized))
        # ParseFromString may return None; only a reported count is checked.
        if consumed is None or consumed == total:
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {total}"
        )
|
|
|
|
|
|
class _TextProtoSerializer(ProtoSerializer):
    """Protobuf text format, used for ``.textproto``, ``.prototxt`` and ``.pbtxt`` files."""

    supported_format = "textproto"
    file_extensions = frozenset({".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` in protobuf text format and encode it with ``_ENCODING``."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse protobuf text format (``str`` or ``_ENCODING``-encoded ``bytes``) into ``proto``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)
|
|
|
|
|
|
class _JsonSerializer(ProtoSerializer):
    """Protobuf JSON mapping, used for ``.json`` and ``.onnxjson`` files."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as JSON keeping the original proto field names, encoded with ``_ENCODING``."""
        as_json = google.protobuf.json_format.MessageToJson(
            proto, preserving_proto_field_name=True
        )
        return as_json.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON (``str`` or ``_ENCODING``-encoded ``bytes``) into ``proto``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if not isinstance(serialized, (str, bytes)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        text = serialized.decode(_ENCODING) if isinstance(serialized, bytes) else serialized
        return google.protobuf.json_format.Parse(text, proto)
|
|
|
|
|
|
class _TextualSerializer(ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation (``.onnxtxt`` / ``.onnxtext``)."""

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt", ".onnxtext"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Print ``proto`` in ONNX textual syntax and encode it with ``_ENCODING``.

        Which proto types are accepted is decided by ``onnx.printer.to_text``.
        """
        text = onnx.printer.to_text(proto)  # type: ignore[arg-type]
        return text.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX textual syntax into ``proto`` and return ``proto``.

        The parser is chosen from the type of ``proto``: ModelProto, GraphProto,
        FunctionProto or NodeProto. The parsed message is copied into ``proto``,
        so callers that rely on in-place population get the same behavior as with
        the protobuf, textproto and JSON serializers.

        Warns:
            UserWarning: Always, because the onnxtxt format is experimental.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
            ValueError: If ``proto`` is not one of the supported types.
        """
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        text = serialized.decode(_ENCODING) if isinstance(serialized, bytes) else serialized

        # Checked in order; the first matching proto type selects its parser.
        parsers = (
            (onnx.ModelProto, onnx.parser.parse_model),
            (onnx.GraphProto, onnx.parser.parse_graph),
            (onnx.FunctionProto, onnx.parser.parse_function),
            (onnx.NodeProto, onnx.parser.parse_node),
        )
        for proto_type, parse in parsers:
            if isinstance(proto, proto_type):
                # CopyFrom replaces proto's previous contents with the parsed message.
                proto.CopyFrom(parse(text))
                return proto
        raise ValueError(f"Unsupported proto type: {type(proto)}")
|
|
|
|
|
|
# Module-wide registry, pre-populated with the built-in serializers for the
# "protobuf", "textproto", "json" and "onnxtxt" formats (in that order).
registry = _Registry()
for _serializer_cls in (
    _ProtobufSerializer,
    _TextProtoSerializer,
    _JsonSerializer,
    _TextualSerializer,
):
    registry.register(_serializer_cls())
# Keep the loop variable out of the module namespace.
del _serializer_cls
|