# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import warnings
|
|
from typing import Any, Dict, NamedTuple, Union, cast
|
|
|
|
import numpy as np
|
|
|
|
from onnx import OptionalProto, SequenceProto, TensorProto
|
|
|
|
|
|
class TensorDtypeMap(NamedTuple):
    """One row of ``TENSOR_TYPE_MAP``: the facts recorded for a tensor element type.

    Being a ``NamedTuple``, an instance can be read either by field name or
    positionally as ``(np_dtype, storage_dtype, name)``.
    """

    # numpy dtype used to represent values of this element type.
    np_dtype: np.dtype
    # Integer value of the TensorProto element type used as the storage type.
    storage_dtype: int
    # Human-readable name, e.g. "TensorProto.FLOAT".
    name: str
|
|
|
|
|
|
# Master table keyed by the integer TensorProto element type. Each value is a
# TensorDtypeMap carrying (numpy dtype, storage element type, display name).
TENSOR_TYPE_MAP = {
    int(TensorProto.FLOAT): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.FLOAT),
        name="TensorProto.FLOAT",
    ),
    int(TensorProto.UINT8): TensorDtypeMap(
        np_dtype=np.dtype("uint8"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.UINT8",
    ),
    int(TensorProto.INT8): TensorDtypeMap(
        np_dtype=np.dtype("int8"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.INT8",
    ),
    int(TensorProto.UINT16): TensorDtypeMap(
        np_dtype=np.dtype("uint16"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.UINT16",
    ),
    int(TensorProto.INT16): TensorDtypeMap(
        np_dtype=np.dtype("int16"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.INT16",
    ),
    int(TensorProto.INT32): TensorDtypeMap(
        np_dtype=np.dtype("int32"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.INT32",
    ),
    int(TensorProto.INT64): TensorDtypeMap(
        np_dtype=np.dtype("int64"),
        storage_dtype=int(TensorProto.INT64),
        name="TensorProto.INT64",
    ),
    int(TensorProto.BOOL): TensorDtypeMap(
        np_dtype=np.dtype("bool"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.BOOL",
    ),
    int(TensorProto.FLOAT16): TensorDtypeMap(
        np_dtype=np.dtype("float16"),
        storage_dtype=int(TensorProto.UINT16),
        name="TensorProto.FLOAT16",
    ),
    # numpy has no native bfloat16, so float32 stands in for it.
    int(TensorProto.BFLOAT16): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.UINT16),
        name="TensorProto.BFLOAT16",
    ),
    int(TensorProto.DOUBLE): TensorDtypeMap(
        np_dtype=np.dtype("float64"),
        storage_dtype=int(TensorProto.DOUBLE),
        name="TensorProto.DOUBLE",
    ),
    int(TensorProto.COMPLEX64): TensorDtypeMap(
        np_dtype=np.dtype("complex64"),
        storage_dtype=int(TensorProto.FLOAT),
        name="TensorProto.COMPLEX64",
    ),
    int(TensorProto.COMPLEX128): TensorDtypeMap(
        np_dtype=np.dtype("complex128"),
        storage_dtype=int(TensorProto.DOUBLE),
        name="TensorProto.COMPLEX128",
    ),
    int(TensorProto.UINT32): TensorDtypeMap(
        np_dtype=np.dtype("uint32"),
        storage_dtype=int(TensorProto.UINT32),
        name="TensorProto.UINT32",
    ),
    int(TensorProto.UINT64): TensorDtypeMap(
        np_dtype=np.dtype("uint64"),
        storage_dtype=int(TensorProto.UINT64),
        name="TensorProto.UINT64",
    ),
    int(TensorProto.STRING): TensorDtypeMap(
        np_dtype=np.dtype("object"),
        storage_dtype=int(TensorProto.STRING),
        name="TensorProto.STRING",
    ),
    # numpy has no native float8 types either; all four variants use float32.
    int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.UINT8),
        name="TensorProto.FLOAT8E4M3FN",
    ),
    int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.UINT8),
        name="TensorProto.FLOAT8E4M3FNUZ",
    ),
    int(TensorProto.FLOAT8E5M2): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.UINT8),
        name="TensorProto.FLOAT8E5M2",
    ),
    int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap(
        np_dtype=np.dtype("float32"),
        storage_dtype=int(TensorProto.UINT8),
        name="TensorProto.FLOAT8E5M2FNUZ",
    ),
    # numpy has no native 4-bit integers; uint8/int8 are used in their place.
    int(TensorProto.UINT4): TensorDtypeMap(
        np_dtype=np.dtype("uint8"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.UINT4",
    ),
    int(TensorProto.INT4): TensorDtypeMap(
        np_dtype=np.dtype("int8"),
        storage_dtype=int(TensorProto.INT32),
        name="TensorProto.INT4",
    ),
}
|
|
|
|
|
|
class DeprecatedWarningDict(dict):  # type: ignore
    """A ``dict`` that emits a ``DeprecationWarning`` on every ``d[key]`` lookup.

    It wraps the legacy ``mapping.*`` tables so that existing subscript
    lookups keep working while callers are told what to use instead.
    Only ``__getitem__`` is overridden to warn; the other ``dict`` methods
    are inherited unchanged.

    Args:
        dictionary: Initial contents, copied into this dict.
        original_function: Name of the deprecated ``mapping`` attribute,
            quoted in the warning text.
        future_function: Name of the ``helper`` function that replaces it.
            When empty, the warning suggests an if-else statement instead.
    """

    def __init__(
        self,
        dictionary: dict[int, int | str | np.dtype],
        original_function: str,
        future_function: str = "",
    ) -> None:
        super().__init__(dictionary)
        self._origin_function = original_function
        self._future_function = future_function

    def __eq__(self, other: object) -> bool:
        """Compare by deprecation metadata only.

        Two instances are equal when their original and future function
        names match; the stored key/value pairs are not compared. Any object
        that is not a ``DeprecatedWarningDict`` compares unequal.
        """
        if not isinstance(other, DeprecatedWarningDict):
            return False
        return (
            self._origin_function == other._origin_function
            and self._future_function == other._future_function
        )

    def __getitem__(self, key: int | str | np.dtype) -> Any:
        """Warn about the deprecation, then return the value stored under *key*.

        Raises:
            KeyError: If *key* is absent (the warning is still emitted first).
        """
        if self._future_function:
            remedy = f"please use `helper.{self._future_function}` instead."
        else:
            remedy = (
                "please simply use if-else statement to get the corresponding value."
            )
        # The two sentences used to be joined without a separating space
        # ("...future release.To silence..."); the space is now explicit.
        warnings.warn(
            f"`mapping.{self._origin_function}` is now deprecated and will be "
            f"removed in a future release. To silence this warning, {remedy}",
            DeprecationWarning,
            # Attribute the warning to the code performing the lookup.
            stacklevel=2,
        )
        return super().__getitem__(key)
|
|
|
|
|
|
# Deprecated table from TensorProto element type to numpy dtype, used when
# converting TensorProto values into numpy arrays.
TENSOR_TYPE_TO_NP_TYPE = DeprecatedWarningDict(
    dictionary={
        elem_type: entry.np_dtype for elem_type, entry in TENSOR_TYPE_MAP.items()
    },
    original_function="TENSOR_TYPE_TO_NP_TYPE",
    future_function="tensor_dtype_to_np_dtype",
)
|
|
# Deprecated table from TensorProto element type to its storage element type.
# Its values only serve as keys into STORAGE_TENSOR_TYPE_TO_FIELD.
# TODO(https://github.com/onnx/onnx/issues/4554): Move these variables into _mapping.py

TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE = DeprecatedWarningDict(
    dictionary={
        elem_type: entry.storage_dtype
        for elem_type, entry in TENSOR_TYPE_MAP.items()
    },
    original_function="TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE",
    future_function="tensor_dtype_to_storage_tensor_dtype",
)
|
|
|
|
# The public NP_TYPE_TO_TENSOR_TYPE is slated for removal; afterwards only this
# private _NP_TYPE_TO_TENSOR_TYPE remains, for internal use.
_NP_TYPE_TO_TENSOR_TYPE = {
    np_dtype: elem_type
    for elem_type, np_dtype in TENSOR_TYPE_TO_NP_TYPE.items()
    # These element types reuse another type's numpy dtype in TENSOR_TYPE_MAP,
    # so they are left out of the reverse mapping.
    if elem_type
    not in (
        TensorProto.BFLOAT16,
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.UINT4,
        TensorProto.INT4,
    )
}
|
|
|
|
# Deprecated public view of _NP_TYPE_TO_TENSOR_TYPE.
# numpy has no native bfloat16, so TensorProto.BFLOAT16 is left out for now:
# a numpy float32 dtype maps back to TensorProto.FLOAT only.
NP_TYPE_TO_TENSOR_TYPE = DeprecatedWarningDict(
    dictionary=cast(Dict[int, Union[int, str, Any]], _NP_TYPE_TO_TENSOR_TYPE),
    original_function="NP_TYPE_TO_TENSOR_TYPE",
    future_function="np_dtype_to_tensor_dtype",
)
|
|
|
|
# The public STORAGE_TENSOR_TYPE_TO_FIELD is slated for removal; afterwards only
# this private _STORAGE_TENSOR_TYPE_TO_FIELD remains, for internal use.
# Maps an element type to the name of a TensorProto data field.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    int(elem_type): field_name
    for elem_type, field_name in (
        (TensorProto.FLOAT, "float_data"),
        (TensorProto.INT32, "int32_data"),
        (TensorProto.INT64, "int64_data"),
        (TensorProto.UINT8, "int32_data"),
        (TensorProto.UINT16, "int32_data"),
        (TensorProto.DOUBLE, "double_data"),
        (TensorProto.COMPLEX64, "float_data"),
        (TensorProto.COMPLEX128, "double_data"),
        (TensorProto.UINT32, "uint64_data"),
        (TensorProto.UINT64, "uint64_data"),
        (TensorProto.STRING, "string_data"),
        (TensorProto.BOOL, "int32_data"),
    )
}
|
|
|
|
# Deprecated public view of _STORAGE_TENSOR_TYPE_TO_FIELD. No replacement
# helper is named, so the warning suggests an if-else statement instead.
STORAGE_TENSOR_TYPE_TO_FIELD = DeprecatedWarningDict(
    dictionary=cast(Dict[int, Union[int, str, Any]], _STORAGE_TENSOR_TYPE_TO_FIELD),
    original_function="STORAGE_TENSOR_TYPE_TO_FIELD",
)
|
|
|
|
|
|
# Deprecated: this map will be removed and nothing replaces it.
# Maps an element type to a field name.
STORAGE_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    dictionary={
        int(SequenceProto.TENSOR): "tensor_values",
        int(SequenceProto.SPARSE_TENSOR): "sparse_tensor_values",
        int(SequenceProto.SEQUENCE): "sequence_values",
        int(SequenceProto.MAP): "map_values",
        int(OptionalProto.OPTIONAL): "optional_value",
    },
    original_function="STORAGE_ELEMENT_TYPE_TO_FIELD",
)
|
|
|
|
|
|
# Deprecated: this map will be removed and nothing replaces it.
# Maps an OptionalProto element type to a field name.
OPTIONAL_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    dictionary={
        int(OptionalProto.TENSOR): "tensor_value",
        int(OptionalProto.SPARSE_TENSOR): "sparse_tensor_value",
        int(OptionalProto.SEQUENCE): "sequence_value",
        int(OptionalProto.MAP): "map_value",
        int(OptionalProto.OPTIONAL): "optional_value",
    },
    original_function="OPTIONAL_ELEMENT_TYPE_TO_FIELD",
)
|