Files
2024-10-30 22:14:35 +01:00

210 lines
7.7 KiB
Python

# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import warnings
# Names exported by ``from onnx.serialization import *``.
__all__ = ["registry"]
import typing
from typing import Any, Collection, Optional, Protocol, TypeVar
import google.protobuf.json_format
import google.protobuf.message
import google.protobuf.text_format
import onnx
# Character encoding used when converting the text-based formats
# (textproto, JSON, ONNX text) to and from bytes.
_ENCODING = "utf-8"

# Any protobuf message type; ties the type of the ``proto`` argument of the
# serializer methods to the type they return.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)
class ProtoSerializer(Protocol):
    """Structural interface for converting in-memory protobuf messages to and from a serialized form.

    An implementation names the format it handles, lists the file extensions
    that map to that format, and provides one method for each direction of
    the conversion.
    """

    # Name of the handled format, for example "protobuf".
    supported_format: str
    # File extensions mapped to this format, for example
    # frozenset({".onnx", ".pb"}). Each entry must include the leading dot.
    file_extensions: Collection[str]

    # NOTE: The methods are deliberately called serialize_proto and
    # deserialize_proto rather than the generic serialize/deserialize. This
    # leaves room for future protocols that serialize the ONNX in-memory IR,
    # so that a single class can implement both protocols.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Turn an in-memory proto into this format's serialized data type."""
        ...

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse data in this format's serialized data type into an in-memory proto."""
        ...
class _Registry:
    """Lookup table of serializers, indexed by format name and by file extension."""

    def __init__(self) -> None:
        # Format name (e.g. "protobuf") -> serializer handling that format.
        self._serializers: dict[str, ProtoSerializer] = {}
        # File extension (e.g. ".onnx") -> format name.
        self._extension_to_format: dict[str, str] = {}

    def register(self, serializer: ProtoSerializer) -> None:
        """Register a serializer under its format and all of its file extensions.

        A later registration for the same format, or one that claims an
        extension already mapped, overwrites the earlier mapping.

        Args:
            serializer: The serializer to register.
        """
        fmt = serializer.supported_format
        self._serializers[fmt] = serializer
        self._extension_to_format.update(
            {ext: fmt for ext in serializer.file_extensions}
        )

    def get(self, fmt: str) -> ProtoSerializer:
        """Get a serializer for a format.

        Args:
            fmt: The format to get a serializer for.

        Returns:
            ProtoSerializer: The serializer for the format.

        Raises:
            ValueError: If the format is not supported. The message lists the
                registered formats.
        """
        try:
            return self._serializers[fmt]
        except KeyError:
            # Show a plain sorted list rather than the ``dict_keys([...])``
            # repr that interpolating ``.keys()`` would produce.
            raise ValueError(
                f"Unsupported format: '{fmt}'. "
                f"Supported formats are: {sorted(self._serializers)}"
            ) from None

    def get_format_from_file_extension(self, file_extension: str) -> str | None:
        """Get the corresponding format from a file extension.

        Args:
            file_extension: The file extension to get a format for, including
                the leading dot (e.g. ".onnx"). Matching is exact.

        Returns:
            The format for the file extension, or None if not found.
        """
        return self._extension_to_format.get(file_extension)
class _ProtobufSerializer(ProtoSerializer):
    """Serializer for the binary protobuf encoding (``.onnx`` / ``.pb`` files)."""

    supported_format = "protobuf"
    file_extensions = frozenset([".onnx", ".pb"])

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Encode ``proto`` by calling its ``SerializeToString`` method.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If ``SerializeToString`` raises ``ValueError``. When the
                proto's ``ByteSize()`` reaches ``onnx.checker.MAXIMUM_PROTOBUF``,
                the error is replaced by one that explains the 2 GB limit.
        """
        has_serializer = hasattr(proto, "SerializeToString") and callable(
            proto.SerializeToString
        )
        if not has_serializer:
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return proto.SerializeToString()  # type: ignore
        except ValueError as err:
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                # Not a size problem: propagate the original error unchanged.
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from err

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse ``serialized`` into ``proto`` with ``ParseFromString`` and return ``proto``.

        Raises:
            TypeError: If ``serialized`` is not ``bytes``.
            google.protobuf.message.DecodeError: If ``ParseFromString`` reports
                a byte count different from ``len(serialized)``.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        # Some protobuf versions return None here instead of a byte count.
        consumed = typing.cast(Optional[int], proto.ParseFromString(serialized))
        expected = len(serialized)
        if consumed is None or consumed == expected:
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {expected}"
        )
class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the protobuf text format (``google.protobuf.text_format``)."""

    supported_format = "textproto"
    file_extensions = frozenset([".textproto", ".prototxt", ".pbtxt"])

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as protobuf text and return it as encoded bytes."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse protobuf text (``str`` or encoded ``bytes``) into ``proto``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            serialized = serialized.decode(_ENCODING)
        elif not isinstance(serialized, str):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(serialized, proto)
class _JsonSerializer(ProtoSerializer):
    """Serializer for the protobuf JSON mapping (``google.protobuf.json_format``)."""

    supported_format = "json"
    file_extensions = frozenset([".json", ".onnxjson"])

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as JSON, keeping proto field names, and return encoded bytes."""
        return google.protobuf.json_format.MessageToJson(
            proto, preserving_proto_field_name=True
        ).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON (``str`` or encoded ``bytes``) into ``proto``.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(serialized, bytes):
            serialized = serialized.decode(_ENCODING)
        elif not isinstance(serialized, str):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.json_format.Parse(serialized, proto)
class _TextualSerializer(ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation (``onnxtxt``)."""

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt", ".onnxtext"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and return the encoded text."""
        text = onnx.printer.to_text(proto)  # type: ignore[arg-type]
        return text.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX text into ``proto`` and return ``proto``.

        Like the other serializers in this module, the parsed content is
        stored in the ``proto`` argument itself (via ``CopyFrom``, which
        replaces its previous content), so the returned object is ``proto``.

        Args:
            serialized: ONNX text, as ``str`` or as bytes encoded with
                ``_ENCODING``.
            proto: The message to populate. Must be an ``onnx.ModelProto``,
                ``onnx.GraphProto``, ``onnx.FunctionProto`` or
                ``onnx.NodeProto``; its type selects the parser used.

        Returns:
            ``proto``, populated with the parsed content.

        Raises:
            TypeError: If ``serialized`` is neither ``bytes`` nor ``str``.
            ValueError: If ``proto`` is not one of the supported message types.
        """
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        else:
            text = serialized

        # The onnx.parser functions build a new message rather than filling
        # one in, so parse first and then copy into the caller's message.
        parsed: google.protobuf.message.Message
        if isinstance(proto, onnx.ModelProto):
            parsed = onnx.parser.parse_model(text)
        elif isinstance(proto, onnx.GraphProto):
            parsed = onnx.parser.parse_graph(text)
        elif isinstance(proto, onnx.FunctionProto):
            parsed = onnx.parser.parse_function(text)
        elif isinstance(proto, onnx.NodeProto):
            parsed = onnx.parser.parse_node(text)
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")
        proto.CopyFrom(parsed)
        return proto
# Module-wide serializer registry, pre-populated with the built-in formats
# in a fixed order: protobuf, textproto, json, onnxtxt.
registry = _Registry()
for _serializer_cls in (
    _ProtobufSerializer,
    _TextProtoSerializer,
    _JsonSerializer,
    _TextualSerializer,
):
    registry.register(_serializer_cls())
# Keep the loop variable out of the module namespace.
del _serializer_cls