154 lines
5.2 KiB
Python
154 lines
5.2 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx._custom_element_types import (
|
|
bfloat16,
|
|
float8e4m3fn,
|
|
float8e4m3fnuz,
|
|
float8e5m2,
|
|
float8e5m2fnuz,
|
|
int4,
|
|
uint4,
|
|
)
|
|
from onnx.reference.op_run import OpRun, RefAttrName
|
|
|
|
|
|
def _check_dtype(val): # type: ignore
|
|
a = val.dtype
|
|
if not isinstance(a, np.dtype) and a not in {
|
|
bfloat16,
|
|
float8e4m3fn,
|
|
float8e4m3fnuz,
|
|
float8e5m2,
|
|
float8e5m2fnuz,
|
|
uint4,
|
|
int4,
|
|
np.int8,
|
|
np.uint8,
|
|
np.float16,
|
|
np.float32,
|
|
np.float64,
|
|
np.int32,
|
|
np.int64,
|
|
np.int16,
|
|
np.uint16,
|
|
np.uint32,
|
|
np.bool_,
|
|
np.str_,
|
|
np.uint64,
|
|
bool,
|
|
str,
|
|
}:
|
|
raise TypeError(
|
|
f"Type ({a}, {type(a)}) is not a numpy type (operator 'Constant')"
|
|
)
|
|
|
|
|
|
class ConstantCommon(OpRun):
    """Behaviour shared by every opset version of the ``Constant`` operator."""

    def _check(self, cst):  # type: ignore
        """Return *cst* unchanged unless it is a tuple.

        Raises:
            TypeError: if *cst* is a tuple, which is rejected as a constant value.
        """
        if not isinstance(cst, tuple):
            return cst
        raise TypeError(f"Unexpected type {type(cst)} for a constant.")
|
|
|
|
|
|
class Constant_1(ConstantCommon):
    """``Constant`` for opset 1: the output is the ``value`` attribute."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        super().__init__(onnx_node, run_params)
        self.cst = self.value  # type: ignore
        _check_dtype(self.cst)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Raises:
            RuntimeError: if any attribute override is supplied other than the
                node's own ``value`` object handed back unchanged.
        """
        # The only tolerated override is exactly {"value": <self.value>},
        # compared by identity rather than equality.
        echoes_own_value = (
            len(overridden_attributes) == 1
            and "value" in overridden_attributes
            and overridden_attributes["value"] is self.value
        )
        if overridden_attributes and not echoes_own_value:
            raise RuntimeError(
                "Function attributes are not implemented for opset <= 11. Use opset > 12."
            )
        return (self._check(self.cst),)
|
|
|
|
|
|
class Constant_9(Constant_1):
    """``Constant`` for opset 9; identical in behaviour to ``Constant_1``."""

    def __init__(self, onnx_node, run_params):  # type: ignore
        super().__init__(onnx_node, run_params)
|
|
|
|
|
|
class Constant_11(ConstantCommon):
    """``Constant`` for opset 11.

    The output comes from ``sparse_value`` when that attribute is set
    (not ``None``), otherwise from ``value``.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is None:
            self.cst = self.value  # type: ignore
        else:
            self.cst = self.sparse_value  # type: ignore
        _check_dtype(self.cst)

    def _only_echoes_own_attribute(self, overridden_attributes):  # type: ignore
        """Return True if *overridden_attributes* just hands back a node attribute.

        Accepted: exactly one entry, keyed ``value`` or ``sparse_value``, whose
        object is identical to this node's own attribute of that name. Either
        attribute can back the constant in this opset, so both are accepted
        (previously only ``value`` was, which rejected sparse constants).
        """
        if len(overridden_attributes) != 1:
            return False
        ((name, given),) = overridden_attributes.items()
        if name not in ("value", "sparse_value"):
            return False
        return hasattr(self, name) and given is getattr(self, name)

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the stored constant as a one-element tuple.

        Raises:
            RuntimeError: if a real attribute override (function attribute) is
                supplied; these are not supported before opset 12.
        """
        if overridden_attributes and not self._only_echoes_own_attribute(
            overridden_attributes
        ):
            raise RuntimeError(
                "Function attributes are not implemented for opset <= 11. Use opset > 12."
            )
        return (self._check(self.cst),)
|
|
|
|
|
|
class Constant_12(ConstantCommon):
    """``Constant`` for opset 12+, including linked (function) attributes.

    The constant is taken from the first attribute that is set (not ``None``),
    in this priority order: ``sparse_value``, ``value``, then the typed
    ``value_float(s)`` / ``value_int(s)`` / ``value_string(s)`` attributes.

    Attributes set here:
        cst_name: name of the attribute that defines the constant.
        cst: the constant itself (possibly a ``RefAttrName`` to resolve at run time).
        cst_convert: callable converting a run-time attribute value to the output.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        ConstantCommon.__init__(self, onnx_node, run_params)
        if getattr(self, "sparse_value", None) is not None:
            self.cst_name = "sparse_value"
            self.cst = self.sparse_value  # type: ignore
            self.cst_convert = lambda v: v
        elif getattr(self, "value", None) is not None:
            self.cst_name = "value"  # type: ignore
            # Stored as-is whether it is a tensor or a RefAttrName.
            self.cst = self.value  # type: ignore
            self.cst_convert = lambda v: v
        else:
            # Typed attributes and the numpy dtype used to materialise them.
            typed_attributes = {
                "value_float": np.float32,
                "value_floats": np.float32,
                "value_int": np.int64,
                "value_ints": np.int64,
                "value_string": np.str_,
                "value_strings": np.str_,
            }
            for attr, np_dtype in typed_attributes.items():
                raw = getattr(self, attr, None)
                if raw is None:
                    continue
                self.cst_name = attr
                # RefAttrName values are kept unconverted (resolved in _run);
                # literal values are converted to an array now.
                self.cst = (
                    raw  # type: ignore
                    if isinstance(raw, RefAttrName)  # type: ignore
                    else np.array(raw, dtype=np_dtype)  # type: ignore
                )
                # Bind np_dtype as a default to avoid late-binding surprises.
                self.cst_convert = lambda v, np_dtype=np_dtype: np.array(  # type: ignore
                    v, dtype=np_dtype
                )
                break
        if not hasattr(self, "cst_name"):
            raise AttributeError(
                f"No constant is defined for operator 'Constant', outputs are {onnx_node.output}."
            )

    def _run(self, **overridden_attributes):  # type: ignore
        """Return the constant as a one-element tuple.

        When the node has a linked attribute, the value is read from
        ``overridden_attributes[self.cst_name]``. An ndarray is returned
        unchanged; any other value goes through ``self.cst_convert``.

        Raises:
            RuntimeError: if a linked attribute is expected but no overrides
                were supplied, or the defining attribute is missing from them.
        """
        if not self.has_linked_attribute:
            return (self._check(self.cst),)
        # ``**overridden_attributes`` is always a dict (never None), so test
        # for emptiness; the former ``is None`` check could never fire.
        if not overridden_attributes:
            raise RuntimeError(
                f"Attributes are empty, cannot retrieve value for {self.cst!r}."
            )
        if self.cst_name not in overridden_attributes:
            raise RuntimeError(
                f"Cannot find attribute {self.cst_name!r} in {list(overridden_attributes)!r}."
            )
        value = overridden_attributes[self.cst_name]
        if isinstance(value, np.ndarray):
            return (value,)
        return (self.cst_convert(value),)
|