# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import collections.abc
import numbers
import struct
from cmath import isnan
from typing import (
    Any,
    Callable,
    Dict,
    KeysView,
    List,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import google.protobuf.message
import numpy as np

import onnx._custom_element_types as custom_np_types
from onnx import (
    IR_VERSION,
    AttributeProto,
    FunctionProto,
    GraphProto,
    MapProto,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    OptionalProto,
    SequenceProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
    defs,
    mapping,
    subbyte,
)

VersionRowType = Union[Tuple[str, int, int, int], Tuple[str, int, int, int, int]]
VersionTableType = List[VersionRowType]
AssignmentBindingType = List[Tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
    ("1.17.0", 10, 22, 5, 1),
]

VersionMapType = Dict[Tuple[str, int], int]


def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """Create a map from (opset-domain, opset-version) to ir-version from the table above."""
    result: VersionMapType = {}

    def process(release_version: str, ir_version: int, *args: Any) -> None:
        del release_version  # Unused
        for pair in zip(["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], args):
            if pair not in result:
                result[pair] = ir_version
            if pair[0] == "ai.onnx.training":
                result["ai.onnx.preview.training", pair[1]] = ir_version

    for row in table:
        process(*row)
    return result


OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)


def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Given a list of opset ids, determine the minimum IR version required.

    Args:
        opsetidlist: A sequence of OperatorSetIdProto.
        ignore_unknown: If True, ignore unknown domains and return the default
            minimum version for those domains.

    Returns:
        The minimum IR version required (integer)
    """
    default_min_version = 3

    def find_min(domain: str | None, version: int) -> int:
        key = (domain or "ai.onnx", version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        if ignore_unknown:
            return default_min_version
        raise ValueError("Unsupported opset-version.")

    if opsetidlist:
        return max(find_min(x.domain, x.version) for x in opsetidlist)
    return default_min_version  # if no opsets specified


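# Usage sketch (illustrative comment, not part of the original module):
#     >>> opset = OperatorSetIdProto()
#     >>> opset.domain, opset.version = "", 19
#     >>> find_min_ir_version_for([opset])
#     9

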
def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: str | None = None,
    doc_string: str | None = None,
    domain: str | None = None,
    overload: str | None = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Args:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use the default domain (which is empty)
        overload (string, default None): optional field, used to
            resolve calls to model-local functions
        **kwargs (dict): the attributes of the node. The acceptable values
            are documented in :func:`make_attribute`.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if overload is not None:
        node.overload = overload
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items())
            if value is not None
        )
    return node


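# Usage sketch (illustrative comment, not part of the original module):
#     >>> node = make_node("Pad", ["x", "pads"], ["y"], mode="constant")
#     >>> node.op_type, node.attribute[0].name
#     ('Pad', 'mode')

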
def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    operatorsetid = OperatorSetIdProto()
    operatorsetid.domain = domain
    operatorsetid.version = version
    return operatorsetid


def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Sequence[TensorProto] | None = None,
    doc_string: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
    sparse_initializer: Sequence[SparseTensorProto] | None = None,
) -> GraphProto:
    """Construct a GraphProto

    Args:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto
        doc_string (string): graph documentation
        value_info: list of ValueInfoProto
        sparse_initializer: list of SparseTensorProto
    Returns:
        GraphProto
    """
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph


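# Usage sketch (illustrative comment, not part of the original module):
#     >>> x = make_tensor_value_info("x", TensorProto.FLOAT, [1])
#     >>> y = make_tensor_value_info("y", TensorProto.FLOAT, [1])
#     >>> graph = make_graph([make_node("Relu", ["x"], ["y"])], "g", [x], [y])
#     >>> graph.node[0].op_type
#     'Relu'

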
def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    opsetid = OperatorSetIdProto()
    opsetid.domain = domain
    opsetid.version = version
    return opsetid


def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Sequence[str] | None = None,
    attribute_protos: Sequence[AttributeProto] | None = None,
    doc_string: str | None = None,
    overload: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Construct a FunctionProto from the given name, inputs, outputs, nodes and opset imports."""
    if attributes is None:
        attributes = []
    if attribute_protos is None:
        attribute_protos = []
    if value_info is None:
        value_info = []
    f = FunctionProto()
    f.domain = domain
    f.name = fname
    f.input.extend(inputs)
    f.output.extend(outputs)
    f.node.extend(nodes)
    f.opset_import.extend(opset_imports)
    f.attribute.extend(attributes)
    f.attribute_proto.extend(attribute_protos)
    if doc_string:
        f.doc_string = doc_string
    if overload is not None:
        f.overload = overload
    f.value_info.extend(value_info)
    return f


def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto

    Args:
        graph (GraphProto): the graph to wrap, e.g. what *make_graph* returns
        **kwargs: any attribute to add to the returned instance
    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Sequence[OperatorSetIdProto] | None = None
    opset_imports = kwargs.pop("opset_imports", None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions: Sequence[FunctionProto] | None = None
    functions = kwargs.pop("functions", None)  # type: ignore
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)
    return model


# An extension of make_model that infers an IR_VERSION for the model,
# if not specified, on a best-effort basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    ir_version_field = "ir_version"
    if ir_version_field not in kwargs:
        opset_imports_field = "opset_imports"
        imports = kwargs.get(opset_imports_field, [])
        kwargs[ir_version_field] = find_min_ir_version_for(imports)
    return make_model(graph, **kwargs)


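# Usage sketch (illustrative comment, not part of the original module):
#     >>> graph = make_graph([], "empty_graph", [], [])
#     >>> model = make_model_gen_version(graph, opset_imports=[make_opsetid("", 19)])
#     >>> model.ir_version
#     9

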
def set_metadata_props(
    proto: (
        ModelProto
        | GraphProto
        | FunctionProto
        | NodeProto
        | TensorProto
        | ValueInfoProto
    ),
    dict_value: dict[str, str],
) -> None:
    del proto.metadata_props[:]
    for k, v in dict_value.items():
        entry = proto.metadata_props.add()
        entry.key = k
        entry.value = v


def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    set_metadata_props(model, dict_value)


def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
    return [
        (ca[i // 2].real if (i % 2 == 0) else ca[i // 2].imag)  # type: ignore[misc]
        for i in range(len(ca) * 2)
    ]


# convert a float32 value to a bfloat16 (as int)
# By default, this conversion rounds-to-nearest-even and supports NaN.
# Setting `truncate` to True enables a simpler conversion. In this mode the
# conversion is performed by simply dropping the 2 least significant bytes of
# the significand. In this mode an error of up to 1 bit may be introduced and
# preservation of NaN values is not guaranteed.
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
    ival = int.from_bytes(struct.pack("<f", fval), "little")
    if truncate:
        return ival >> 16
    # NaN requires at least 1 significand bit set
    if isnan(fval):
        return 0x7FC0  # sign=0, exp=all-ones, sig=0b1000000
    # drop bottom 16 bits;
    # round remaining bits using round-to-nearest-even
    rounded = ((ival >> 16) & 1) + 0x7FFF
    return (ival + rounded) >> 16


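# Usage sketch (illustrative comment, not part of the original module):
#     >>> hex(float32_to_bfloat16(1.0))
#     '0x3f80'
#     >>> hex(float32_to_bfloat16(float("nan")))
#     '0x7fc0'

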
def float32_to_float8e4m3(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = True,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e4m3 (as int).

    See :ref:`onnx-detail-float8` for technical details.

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values
        uz: no negative zero
        saturate: if True, any value out of range, including inf, becomes
            the maximum value; otherwise, it becomes NaN. The
            description of operator Cast fully describes the
            differences.

    Returns:
        converted value (as an int)
    """
    if not fn:
        raise NotImplementedError(
            "float32_to_float8e4m3 not implemented with fn=False."
        )
    x = fval / scale
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if np.isinf(x):
            if saturate:
                return ret | 127
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 116:  # noqa: PLR2004
            ret = 0
        elif e < 120:  # noqa: PLR2004
            # denormalized number
            ex = e - 119
            if ex >= -2:  # noqa: PLR2004
                ret |= 1 << (2 + ex)
                ret |= m >> (21 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (20 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 135:  # noqa: PLR2004
            # normalized number
            ex = e - 119  # 127 - 8
            if ex == 0:
                ret |= 0x4
                ret |= m >> 21
            else:
                ret |= ex << 3
                ret |= m >> 20
            if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    return 0x80
        elif saturate:
            ret |= 0x7F  # 01111111
        else:
            ret = 0x80
        return int(ret)
    else:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return ret | 126
            return 0x7F | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 117:  # noqa: PLR2004
                pass
            elif e < 121:  # noqa: PLR2004
                # denormalized number
                ex = e - 120
                if ex >= -2:  # noqa: PLR2004
                    ret |= 1 << (2 + ex)
                    ret |= m >> (21 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (20 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 136:  # noqa: PLR2004
                # normalized number
                ex = e - 120
                if ex == 0:
                    ret |= 0x4
                    ret |= m >> 21
                else:
                    ret |= ex << 3
                    ret |= m >> 20
                    if (ret & 0x7F) == 0x7F:  # noqa: PLR2004
                        ret &= 0xFE
                if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
                    if (ret & 0x7F) < 0x7E:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif not saturate:
                        ret |= 0x7F
            elif saturate:
                ret |= 126  # 01111110
            else:
                ret |= 0x7F
        return int(ret)


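# Usage sketch (illustrative comment, not part of the original module):
#     >>> hex(float32_to_float8e4m3(1.0))
#     '0x38'
#     >>> hex(float32_to_float8e4m3(448.0))  # largest e4m3fn value
#     '0x7e'

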
def float32_to_float8e5m2(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = False,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e5m2 (as int).

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values
        uz: no negative zero
        saturate: if True, any value out of range, including inf, becomes
            the maximum value; otherwise, it becomes NaN. The
            description of operator Cast fully describes the
            differences.

    Returns:
        converted value (as an int)
    """
    x = fval / scale
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign

    if fn and uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if (b & 0x7FFFFFFF) == 0x7F800000:  # noqa: PLR2004
            # inf
            if saturate:
                return ret | 0x7F
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 109:  # noqa: PLR2004
            ret = 0
        elif e < 112:  # noqa: PLR2004
            # denormalized number
            ex = e - 111
            if ex >= -1:
                ret |= 1 << (1 + ex)
                ret |= m >> (22 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (21 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 143:  # noqa: PLR2004
            # normalized number
            ex = e - 111
            ret |= ex << 2
            ret |= m >> 21
            if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    ret = 0x80
        elif e == 255 and m == 0:  # inf  # noqa: PLR2004
            ret = 0x80
        elif saturate:
            ret |= 0x7F  # last possible number
        else:
            ret = 0x80
        return int(ret)
    elif not fn and not uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return 0x7B | ret
            return 0x7C | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 110:  # noqa: PLR2004
                pass
            elif e < 113:  # noqa: PLR2004
                # denormalized number
                ex = e - 112
                if ex >= -1:
                    ret |= 1 << (1 + ex)
                    ret |= m >> (22 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (21 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 143:  # noqa: PLR2004
                # normalized number
                ex = e - 112
                ret |= ex << 2
                ret |= m >> 21
                if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                    if (ret & 0x7F) < 0x7B:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif saturate:
                        ret |= 0x7B
                    else:
                        ret |= 0x7C
            elif saturate:
                ret |= 0x7B
            else:
                ret |= 0x7C
        return int(ret)
    else:
        raise NotImplementedError("fn and uz must be both False or True.")


def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
    """Convert an array of float32 values to a 4-bit data type and pack every two consecutive elements in a byte.

    See :ref:`onnx-detail-int4` for technical details.

    Args:
        array: array of floats to convert and pack
        signed: Whether the 4-bit variant is signed or unsigned

    Returns:
        Packed array with size `ceil(array.size/2)` (single dimension).
    """
    if not isinstance(array, np.ndarray):
        array = np.asarray(array, dtype=np.float32)

    array_flat = array.ravel()
    is_odd_volume = np.prod(array.shape) % 2 == 1
    if is_odd_volume:
        array_flat = np.append(array_flat, np.array([0]))

    single_func = lambda x, y: subbyte.float32x2_to_4bitx2(x, y, signed)  # noqa: E731
    func = np.frompyfunc(single_func, 2, 1)

    arr = func(array_flat[0::2], array_flat[1::2])
    return arr.astype(np.uint8)  # type: ignore[no-any-return]


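# Usage sketch (illustrative comment, not part of the original module); it
# assumes the second element of each pair lands in the high nibble, which is
# an implementation detail of subbyte.float32x2_to_4bitx2:
#     >>> pack_float32_to_4bit([1.0, 2.0, 3.0], signed=True).tolist()
#     [33, 3]

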
def make_tensor(
    name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = False
) -> TensorProto:
    """Make a TensorProto with specified arguments. If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Args:
        name (string): tensor name
        data_type (int): a value such as onnx.TensorProto.FLOAT
        dims (List[int]): shape
        vals: values
        raw (bool): if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by *data_type*

    Returns:
        TensorProto
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    np_dtype = tensor_dtype_to_np_dtype(data_type)

    # Check number of vals specified equals tensor size
    expected_size = 1
    if raw:
        # NumPy doesn't have BFLOAT16. TENSOR_TYPE_MAP maps it to float32, which has the wrong itemsize.
        if data_type == TensorProto.BFLOAT16:
            expected_size = 2
        elif data_type in (
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            expected_size = 1
        # NumPy doesn't have INT4. It is packed in couples to UINT8 buffers.
        elif data_type in (TensorProto.UINT4, TensorProto.INT4):
            expected_size = 0.5  # type: ignore[assignment]
        else:
            expected_size = np_dtype.itemsize

    if type(vals) is np.ndarray and len(vals.shape) > 1:
        vals = vals.flatten()
    for d in dims:
        expected_size *= d

    if len(vals) != expected_size:
        # padding of half a byte is acceptable for 4bit types
        if not (
            data_type in (TensorProto.UINT4, TensorProto.INT4)
            and len(vals) == expected_size + 0.5
        ):
            raise ValueError(
                f"Number of values does not match tensor's size. Expected {expected_size}, but it is {len(vals)}. "
            )

    if raw:
        tensor.raw_data = vals
    else:
        if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
            vals = split_complex_to_pairs(vals)
        elif data_type == TensorProto.FLOAT16:
            vals = (
                np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
            )
        elif data_type in (
            TensorProto.BFLOAT16,
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            fcast = {
                TensorProto.BFLOAT16: float32_to_bfloat16,
                TensorProto.FLOAT8E4M3FN: float32_to_float8e4m3,
                TensorProto.FLOAT8E4M3FNUZ: lambda *args: float32_to_float8e4m3(  # type: ignore[misc]
                    *args, uz=True
                ),
                TensorProto.FLOAT8E5M2: float32_to_float8e5m2,
                TensorProto.FLOAT8E5M2FNUZ: lambda *args: float32_to_float8e5m2(  # type: ignore[misc]
                    *args, fn=True, uz=True
                ),
            }[data_type]  # type: ignore[index]
            vals = list(
                map(  # type: ignore[call-overload]
                    fcast,
                    np.array(vals).astype(np_dtype).flatten().tolist(),
                )
            )
        elif data_type in (
            TensorProto.UINT4,
            TensorProto.INT4,
        ):
            signed = data_type == TensorProto.INT4

            # Two packed 4-bit values must be represented as a single uint8 value.
            # Therefore, pack_float32_to_4bit() sets the dtype of the output vals
            # to uint8 regardless of the value of 'signed'. Using int8 would cause
            # the size of int4 tensors to increase ~5x if the tensor contains negative values (due to
            # the way negative values are serialized by protobuf).
            vals = pack_float32_to_4bit(vals, signed=signed).flatten().tolist()
        elif data_type == TensorProto.BOOL:
            vals = np.array(vals).astype(int)
        elif data_type == TensorProto.STRING:
            vals = np.array(vals).astype(bytes)
        field = tensor_dtype_to_field(data_type)
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor


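# Usage sketch (illustrative comment, not part of the original module):
#     >>> t = make_tensor("w", TensorProto.FLOAT, [2], [1.0, 2.0])
#     >>> list(t.float_data)
#     [1.0, 2.0]

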
def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> SparseTensorProto:
    """Construct a SparseTensorProto

    Args:
        values (TensorProto): the values
        indices (TensorProto): the indices
        dims: the shape

    Returns:
        SparseTensorProto
    """
    sparse = SparseTensorProto()
    sparse.values.CopyFrom(values)
    sparse.indices.CopyFrom(indices)
    sparse.dims.extend(dims)
    return sparse


def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a Sequence with specified value arguments."""
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type

    if elem_type == SequenceProto.UNDEFINED:
        return sequence
    if elem_type == SequenceProto.TENSOR:
        attribute = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        attribute = sequence.sparse_tensor_values  # type: ignore[assignment]
    elif elem_type == SequenceProto.SEQUENCE:
        attribute = sequence.sequence_values  # type: ignore[assignment]
    elif elem_type == SequenceProto.MAP:
        attribute = sequence.map_values  # type: ignore[assignment]
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = sequence.optional_values  # type: ignore[assignment]
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    attribute.extend(values)
    return sequence


def make_map(
    name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type
    """
    map_proto = MapProto()
    valid_key_int_types = [
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    ]
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    map_proto.values.CopyFrom(values)
    return map_proto


def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: Any | None,
) -> OptionalProto:
    """Make an Optional with specified value arguments."""
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type

    if elem_type == OptionalProto.UNDEFINED:
        return optional
    if elem_type == OptionalProto.TENSOR:
        attribute = optional.tensor_value
    elif elem_type == OptionalProto.SPARSE_TENSOR:
        attribute = optional.sparse_tensor_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.SEQUENCE:
        attribute = optional.sequence_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.MAP:
        attribute = optional.map_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = optional.optional_value  # type: ignore[assignment]
    else:
        raise TypeError("The element type in the input optional is not supported.")

    attribute.CopyFrom(value)  # type: ignore[arg-type]
    return optional


def _to_bytes(value: str | bytes) -> bytes:
    """Coerce a string (or bytes) value into UTF-8 bytes."""
    return value if isinstance(value, bytes) else value.encode("utf-8")


def make_attribute(
    key: str,
    value: Any,
    doc_string: str | None = None,
    attr_type: int | None = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type."""
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # Singular cases
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Encode strings into utf-8
        attr.s = _to_bytes(value)
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Iterable cases
    elif isinstance(value, collections.abc.Iterable):
        value = list(value)
        if len(value) == 0 and attr_type is None:
            raise ValueError(
                f"Could not infer attribute `{key}` type from empty iterator"
            )
        if attr_type is None:
            types = {type(v) for v in value}
            for exp_t, exp_enum in (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            ):
                if all(issubclass(t, exp_t) for t in types):  # type: ignore[arg-type]
                    attr_type = exp_enum
                    break
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )

        if attr_type == AttributeProto.INTS:
            attr.ints.extend(value)
            attr.type = AttributeProto.INTS
        elif attr_type == AttributeProto.FLOATS:
            attr.floats.extend(value)
            attr.type = AttributeProto.FLOATS
        elif attr_type == AttributeProto.STRINGS:
            attr.strings.extend(_to_bytes(v) for v in value)
            attr.type = AttributeProto.STRINGS
        elif attr_type == AttributeProto.TENSORS:
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif attr_type == AttributeProto.SPARSE_TENSORS:
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif attr_type == AttributeProto.GRAPHS:
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif attr_type == AttributeProto.TYPE_PROTOS:
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            raise AssertionError()  # Should not reach since `ValueError` must be raised in attr_type checking
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    if attr_type is not None and attr.type != attr_type:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return attr


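# Usage sketch (illustrative comment, not part of the original module):
#     >>> make_attribute("alpha", 0.5).type == AttributeProto.FLOAT
#     True
#     >>> make_attribute("perm", [0, 2, 1]).type == AttributeProto.INTS
#     True

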
def make_attribute_ref(
    name: str, attr_type: AttributeProto.AttributeType, doc_string: str | None = None
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's attribute of given name and type."""
    attr = AttributeProto()
    attr.name = name
    attr.type = attr_type
    if doc_string:
        attr.doc_string = doc_string
    return attr


def get_attribute_value(attr: AttributeProto) -> Any:  # noqa: PLR0911
    if attr.ref_attr_name:
        raise ValueError(f"Cannot get value of reference attribute: {attr}")
    if attr.type == AttributeProto.FLOAT:
        return attr.f
    if attr.type == AttributeProto.INT:
        return attr.i
    if attr.type == AttributeProto.STRING:
        return attr.s
    if attr.type == AttributeProto.TENSOR:
        return attr.t
    if attr.type == AttributeProto.SPARSE_TENSOR:
        return attr.sparse_tensor
    if attr.type == AttributeProto.GRAPH:
        return attr.g
    if attr.type == AttributeProto.TYPE_PROTO:
        return attr.tp
    if attr.type == AttributeProto.FLOATS:
        return list(attr.floats)
    if attr.type == AttributeProto.INTS:
        return list(attr.ints)
    if attr.type == AttributeProto.STRINGS:
        return list(attr.strings)
    if attr.type == AttributeProto.TENSORS:
        return list(attr.tensors)
    if attr.type == AttributeProto.SPARSE_TENSORS:
        return list(attr.sparse_tensors)
    if attr.type == AttributeProto.GRAPHS:
        return list(attr.graphs)
    if attr.type == AttributeProto.TYPE_PROTOS:
        return list(attr.type_protos)
    if attr.type == AttributeProto.UNDEFINED:
        return None
    raise ValueError(f"Unsupported ONNX attribute: {attr}")


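# Usage sketch (illustrative comment, not part of the original module):
#     >>> get_attribute_value(make_attribute("axis", 1))
#     1

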
def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    matching = [x for x in node.attribute if x.name == attr_name]
    if len(matching) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    if len(matching) < 1:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    return get_attribute_value(matching[0])


def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    return value_info_proto


def make_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be an int or a str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto


def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string

    tensor_type_proto = make_tensor_type_proto(elem_type, shape, shape_denotation)
    value_info_proto.type.CopyFrom(tensor_type_proto)
    return value_info_proto


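# Usage sketch (illustrative comment, not part of the original module):
#     >>> vi = make_tensor_value_info("x", TensorProto.FLOAT, ["N", 3])
#     >>> (vi.type.tensor_type.shape.dim[0].dim_param,
#     ...  vi.type.tensor_type.shape.dim[1].dim_value)
#     ('N', 3)

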
def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be an int or a str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto


def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string

    sparse_tensor_type_proto = make_sparse_tensor_type_proto(
        elem_type, shape, shape_denotation
    )
    value_info_proto.type.sparse_tensor_type.CopyFrom(
        sparse_tensor_type_proto.sparse_tensor_type
    )
    return value_info_proto


def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto."""
    type_proto = TypeProto()
    type_proto.sequence_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto."""
    type_proto = TypeProto()
    type_proto.optional_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto."""
    type_proto = TypeProto()
    type_proto.map_type.key_type = key_type
    type_proto.map_type.value_type.CopyFrom(value_type)
    return type_proto


def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto with the given type_proto."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string

    value_info_proto.type.CopyFrom(type_proto)
    return value_info_proto


def _sanitize_str(s: str | bytes) -> str:
    if isinstance(s, str):
        sanitized = s
    elif isinstance(s, bytes):
        sanitized = s.decode("utf-8", errors="ignore")
    else:
        sanitized = str(s)
    if len(sanitized) < 64:  # noqa: PLR2004
        return sanitized
    return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"


def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a Sequence[Tensors] ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string

    tensor_type_proto = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    sequence_type_proto = make_sequence_type_proto(tensor_type_proto)
    value_info_proto.type.sequence_type.CopyFrom(sequence_type_proto.sequence_type)

    return value_info_proto


def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    content = []
    content.append(attr.name)
    content.append("=")

    def str_float(f: float) -> str:
        # NB: Different Python versions print different numbers of trailing
        # decimals, specifying this explicitly keeps it consistent for all
        # versions
        return f"{f:.15g}"

    def str_int(i: int) -> str:
        return str(i)

    _T = TypeVar("_T")

    def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
        return "[" + ", ".join(map(str_elem, xs)) + "]"

    # For now, this logic should continue to work as long as we are running
    # on a proto2 implementation. If/when we switch to proto3, we will need
    # to use attr.type.

    # To support printing subgraphs, if we find a graph attribute, print out
    # its name here and pass the graph itself up to the caller for later
    # printing.
    graphs = []
    if attr.HasField("f"):
        content.append(str_float(attr.f))
    elif attr.HasField("i"):
        content.append(str_int(attr.i))
    elif attr.HasField("s"):
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            content.append("<Tensor>")
        else:
            # special case to print scalars
            field = tensor_dtype_to_field(attr.t.data_type)
            content.append(f"<Scalar Tensor {getattr(attr.t, field)}>")
    elif attr.HasField("g"):
        content.append(f"<graph {attr.g.name}>")
        graphs.append(attr.g)
    elif attr.HasField("tp"):
        content.append(f"<Type Proto {attr.tp}>")
    elif attr.floats:
        content.append(str_list(str_float, attr.floats))
    elif attr.ints:
        content.append(str_list(str_int, attr.ints))
    elif attr.strings:
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(str(list(map(_sanitize_str, attr.strings))))
    elif attr.tensors:
        content.append("[<Tensor>, ...]")
    elif attr.type_protos:
        content.append("[")
        for i, tp in enumerate(attr.type_protos):
            comma = "," if i != len(attr.type_protos) - 1 else ""
            content.append(f"<Type Proto {tp}>{comma}")
        content.append("]")
    elif attr.graphs:
        content.append("[")
        for i, g in enumerate(attr.graphs):
            comma = "," if i != len(attr.graphs) - 1 else ""
            content.append(f"<graph {g.name}>{comma}")
        content.append("]")
        graphs.extend(attr.graphs)
    else:
        content.append("<Unknown>")
    if subgraphs:
        return " ".join(content), graphs
    return " ".join(content)


def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    which = dim.WhichOneof("value")
    if which is None:
        return "?"
    return str(getattr(dim, which))


def printable_type(t: TypeProto) -> str:
    if t.WhichOneof("value") == "tensor_type":
        s = TensorProto.DataType.Name(t.tensor_type.elem_type)
        if t.tensor_type.HasField("shape"):
            if len(t.tensor_type.shape.dim):
                s += str(", " + "x".join(map(printable_dim, t.tensor_type.shape.dim)))
            else:
                s += ", scalar"
        return s  # type: ignore[no-any-return]
    if t.WhichOneof("value") is None:
        return ""
    return f"Unknown type {t.WhichOneof('value')}"


def printable_value_info(v: ValueInfoProto) -> str:
    s = f"%{v.name}"
    if v.type:
        s = f"{s}[{printable_type(v.type)}]"
    return s


def printable_tensor_proto(t: TensorProto) -> str:
    s = f"%{t.name}["
    s += TensorProto.DataType.Name(t.data_type)
    if t.dims is not None:
        if len(t.dims):
            s += str(", " + "x".join(map(str, t.dims)))
        else:
            s += ", scalar"
    s += "]"
    return s


def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    content = []
    if len(node.output):
        content.append(", ".join([f"%{name}" for name in node.output]))
        content.append("=")
    # To deal with nested graphs
    graphs: list[GraphProto] = []
    printed_attrs = []
    for attr in node.attribute:
        if subgraphs:
            printed_attr_subgraphs = printable_attribute(attr, subgraphs)
            if not isinstance(printed_attr_subgraphs[1], list):
                raise TypeError(
                    f"printed_attr_subgraphs[1] must be an instance of {list}."
                )
            graphs.extend(printed_attr_subgraphs[1])
            printed_attrs.append(printed_attr_subgraphs[0])
        else:
            printed = printable_attribute(attr)
            if not isinstance(printed, str):
                raise TypeError(f"printed must be an instance of {str}.")
            printed_attrs.append(printed)
    printed_attributes = ", ".join(sorted(printed_attrs))
    printed_inputs = ", ".join([f"%{name}" for name in node.input])
    if node.attribute:
        content.append(f"{node.op_type}[{printed_attributes}]({printed_inputs})")
    else:
        content.append(f"{node.op_type}({printed_inputs})")
    if subgraphs:
        return prefix + " ".join(content), graphs
    return prefix + " ".join(content)


def printable_graph(graph: GraphProto, prefix: str = "") -> str:
|
|
"""Display a GraphProto as a string.
|
|
|
|
Args:
|
|
graph (GraphProto): the graph to display
|
|
prefix (string): prefix of every line
|
|
|
|
Returns:
|
|
string
|
|
"""
|
|
content = []
|
|
indent = prefix + " "
|
|
# header
|
|
header = ["graph", graph.name]
|
|
initializers = {t.name for t in graph.initializer}
|
|
if len(graph.input):
|
|
header.append("(")
|
|
in_strs = [] # required inputs
|
|
in_with_init_strs = (
|
|
[]
|
|
) # optional inputs with initializer providing default value
|
|
for inp in graph.input:
|
|
if inp.name not in initializers:
|
|
in_strs.append(printable_value_info(inp))
|
|
else:
|
|
in_with_init_strs.append(printable_value_info(inp))
|
|
if in_strs:
|
|
content.append(prefix + " ".join(header))
|
|
header = []
|
|
for line in in_strs:
|
|
content.append(prefix + " " + line) # noqa: PERF401
|
|
header.append(")")
|
|
|
|
if in_with_init_strs:
|
|
header.append("optional inputs with matching initializers (")
|
|
content.append(prefix + " ".join(header))
|
|
header = []
|
|
for line in in_with_init_strs:
|
|
content.append(prefix + " " + line) # noqa: PERF401
|
|
header.append(")")
|
|
|
|
# from IR 4 onwards an initializer is not required to have a matching graph input
|
|
# so output the name, type and shape of those as well
|
|
if len(in_with_init_strs) < len(initializers):
|
|
graph_inputs = {i.name for i in graph.input}
|
|
init_strs = [
|
|
printable_tensor_proto(i)
|
|
for i in graph.initializer
|
|
if i.name not in graph_inputs
|
|
]
|
|
header.append("initializers (")
|
|
content.append(prefix + " ".join(header))
|
|
header = []
|
|
for line in init_strs:
|
|
content.append(prefix + " " + line) # noqa: PERF401
|
|
header.append(")")
|
|
|
|
header.append("{")
|
|
content.append(prefix + " ".join(header))
|
|
graphs: list[GraphProto] = []
|
|
# body
|
|
for node in graph.node:
|
|
contents_subgraphs = printable_node(node, indent, subgraphs=True)
|
|
if not isinstance(contents_subgraphs[1], list):
|
|
raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
|
|
content.append(contents_subgraphs[0])
|
|
graphs.extend(contents_subgraphs[1])
|
|
# tail
|
|
tail = ["return"]
|
|
if len(graph.output):
|
|
tail.append(", ".join([f"%{out.name}" for out in graph.output]))
|
|
content.append(indent + " ".join(tail))
|
|
# closing bracket
|
|
content.append(prefix + "}")
|
|
for g in graphs:
|
|
content.append("\n" + printable_graph(g)) # noqa: PERF401
|
|
return "\n".join(content)
|
|
|
|
|
|
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Empties `doc_string` field on any nested protobuf messages"""
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for descriptor in proto.DESCRIPTOR.fields:
        if descriptor.name == "doc_string":
            proto.ClearField(descriptor.name)
        elif descriptor.type == descriptor.TYPE_MESSAGE:
            if descriptor.label == descriptor.LABEL_REPEATED:
                for x in getattr(proto, descriptor.name):
                    strip_doc_string(x)
            elif proto.HasField(descriptor.name):
                strip_doc_string(getattr(proto, descriptor.name))


def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: GraphProto | None,
    initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
    training_info = TrainingInfoProto()
    training_info.algorithm.CopyFrom(algorithm)
    for k, v in algorithm_bindings:
        binding = training_info.update_binding.add()
        binding.key = k
        binding.value = v

    if initialization:
        training_info.initialization.CopyFrom(initialization)
    if initialization_bindings:
        for k, v in initialization_bindings:
            binding = training_info.initialization_binding.add()
            binding.key = k
            binding.value = v

    return training_info


# The following functions are used for mapping
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Convert a TensorProto's data_type to the corresponding numpy dtype. It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        numpy's data_type
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].np_dtype


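# Usage sketch (illustrative comment, not part of the original module):
#     >>> tensor_dtype_to_np_dtype(TensorProto.FLOAT)
#     dtype('float32')

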
def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Convert a TensorProto's data_type to the corresponding data_type for storage.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        data_type for storage
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype


def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Get the name of given TensorProto's data_type.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        the name of data_type
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].name


def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Convert a TensorProto's data_type to the corresponding field name for storage. It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        field name
    """
    return mapping._STORAGE_TENSOR_TYPE_TO_FIELD[
        mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    ]


def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> int:
    """Convert a numpy dtype to the corresponding tensor type. It can be used while converting numpy arrays to tensors.

    Args:
        np_dtype: numpy's data_type

    Returns:
        TensorProto's data_type
    """
    if np_dtype in mapping._NP_TYPE_TO_TENSOR_TYPE:
        return cast(
            int,
            mapping._NP_TYPE_TO_TENSOR_TYPE[np_dtype],
        )

    if np.issubdtype(np_dtype, np.str_):
        return TensorProto.STRING

    if np_dtype in {
        custom_np_types.bfloat16,
        custom_np_types.float8e4m3fn,
        custom_np_types.float8e4m3fnuz,
        custom_np_types.float8e5m2,
        custom_np_types.float8e5m2fnuz,
        custom_np_types.int4,
        custom_np_types.uint4,
    }:
        return custom_np_types.mapping_name_to_data_type[np_dtype.descr[0][0]]

    raise ValueError(
        f"Unable to convert type {np_dtype!r} into TensorProto element type."
    )


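# Usage sketch (illustrative comment, not part of the original module):
#     >>> np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
#     True

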
def get_all_tensor_dtypes() -> KeysView[int]:
    """Get all tensor types from TensorProto.

    Returns:
        all tensor types from TensorProto
    """
    return mapping.TENSOR_TYPE_MAP.keys()


_ATTRIBUTE_TYPE_TO_STR = {k: v for v, k in AttributeProto.AttributeType.items()}


def _attr_type_to_str(attr_type: int) -> str:
    """Convert AttributeProto type to string.

    Args:
        attr_type: AttributeProto type.

    Returns:
        String representing the supplied attr_type.
    """
    if attr_type in AttributeProto.AttributeType.values():
        return _ATTRIBUTE_TYPE_TO_STR[attr_type]  # type: ignore[no-any-return]
    return AttributeProto.AttributeType.keys()[0]  # type: ignore[no-any-return]