188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from onnx.onnx_pb import NodeProto
|
|
from onnx.reference.custom_element_types import (
|
|
convert_from_ml_dtypes,
|
|
convert_to_ml_dtypes,
|
|
)
|
|
from onnx.reference.op_run import OpRun, RuntimeTypeError
|
|
|
|
|
|
class OpRunUnary(OpRun):
    """Base class shared by the unary operators of this subfolder.

    It verifies that the output keeps the type of the input.
    """

    def run(self, x):  # type: ignore
        """Runs ``_run`` on a single input and rewraps type errors.

        Only unary operators are supported. The input goes through
        ``convert_to_ml_dtypes`` before ``_run`` is called, and the first
        element returned by ``_run`` goes through ``convert_from_ml_dtypes``.

        Raises:
            TypeError: when ``_run`` raises ``TypeError``; the new exception
                names the operator class and chains the original one.
        """
        op_name = self.__class__.__name__
        self._log("-- begin %s.run(1 input)", op_name)
        converted = convert_to_ml_dtypes(x)
        try:
            outputs = self._run(converted)
        except TypeError as exc:
            type_names = str(type(converted))
            raise TypeError(
                f"Issues with types {type_names} "
                f"(unary operator {op_name!r})."
            ) from exc
        first_output = convert_from_ml_dtypes(outputs[0])
        outputs = (first_output,)
        self._log("-- done %s.run -> %d outputs", op_name, len(outputs))
        return self._check_and_fix_outputs(outputs)
|
|
|
|
|
|
class OpRunUnaryNum(OpRunUnary):
    """Base class for the unary numerical operators of this subfolder.

    Besides the checks of ``OpRunUnary``, it requires the output dtype to
    match the input dtype.
    """

    def run(self, x):  # type: ignore
        """Delegates to ``OpRunUnary.run`` and validates the result.

        A result that is empty, or whose first element is ``None``, is
        returned unchanged. When the first output is not a list, its dtype
        must equal ``x.dtype``.

        Raises:
            RuntimeTypeError: when the output dtype differs from ``x.dtype``.
        """
        outputs = OpRunUnary.run(self, x)
        nothing_to_check = len(outputs) == 0 or outputs[0] is None
        if nothing_to_check:
            return outputs
        first = outputs[0]
        if isinstance(first, list):
            # Lists carry no dtype; skip the comparison.
            return self._check_and_fix_outputs(outputs)
        if first.dtype != x.dtype:
            raise RuntimeTypeError(
                f"Output type mismatch: input '{x.dtype}' != output '{first.dtype}' "
                f"(operator {self.__class__.__name__!r})."
            )
        return self._check_and_fix_outputs(outputs)
|
|
|
|
|
|
class OpRunBinary(OpRun):
    """Ancestor to all binary operators in this subfolder.

    Checks that both inputs are present and share the same dtype before
    delegating to ``_run``.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``_run``, catches exceptions, displays a longer error message.

        Supports only binary operators. Both inputs are converted with
        ``convert_to_ml_dtypes`` before ``_run`` is called, and the first
        element returned by ``_run`` is converted back with
        ``convert_from_ml_dtypes``.

        Raises:
            RuntimeError: if ``x`` or ``y`` is ``None``.
            RuntimeTypeError: if ``x.dtype`` differs from ``y.dtype``.
            TypeError: if ``_run`` raises ``TypeError`` or ``ValueError``;
                the original exception is chained as the cause.
        """
        self._log("-- begin %s.run(2 inputs)", self.__class__.__name__)
        if x is None or y is None:
            # A missing input is not a dtype mismatch; report what actually failed.
            raise RuntimeError(
                f"x and y must not be None: type(x)={type(x)}, "
                f"type(y)={type(y)} ({type(self)})"
            )
        # The dtype comparison uses the raw inputs, before ml_dtypes conversion.
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                f"Input type mismatch: {x.dtype} != {y.dtype} "
                f"(operator {self.__class__.__name__!r}, "
                f"shapes {x.shape}, {y.shape})."
            )
        x = convert_to_ml_dtypes(x)
        y = convert_to_ml_dtypes(y)
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:
            # ValueError is deliberately surfaced as TypeError (existing contract).
            raise TypeError(
                f"Issues with types {', '.join(str(type(_)) for _ in [x, y])} "
                f"(binary operator {self.__class__.__name__!r})."
            ) from e
        res = (convert_from_ml_dtypes(res[0]),)
        self._log("-- done %s.run -> %d outputs", self.__class__.__name__, len(res))
        return self._check_and_fix_outputs(res)
|
|
|
|
|
|
class OpRunBinaryComparison(OpRunBinary):
    """Base class for the binary operators of this subfolder that compare tensors.

    It adds nothing to ``OpRunBinary``; input validation comes from the
    inherited ``run``.
    """
|
|
|
|
|
|
class OpRunBinaryNum(OpRunBinary):
    """Ancestor to all binary numerical operators in this subfolder.

    Checks that input and output types are the same.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``OpRunBinary.run`` and checks the output dtype.

        ``OpRunBinary.run`` validates the inputs and rewraps errors raised
        by ``_run``; this method additionally requires the first output to
        have the dtype of ``x``.

        Raises:
            RuntimeTypeError: if the first output's dtype differs from
                ``x.dtype``.
        """
        res = OpRunBinary.run(self, x, y)
        if res[0].dtype != x.dtype:
            raise RuntimeTypeError(
                f"Output type mismatch: {x.dtype} != {res[0].dtype} or {y.dtype} "
                f"(operator {self.__class__.__name__!r})"
                f" type(x)={type(x)} type(y)={type(y)}"
            )
        return self._check_and_fix_outputs(res)
|
|
|
|
|
|
class OpRunBinaryNumpy(OpRunBinaryNum):
    """Binary operator whose computation is a numpy function.

    *numpy_fct* is a binary numpy function which takes two matrices; it is
    called by ``_run`` with both (ml_dtypes-converted) inputs.
    """

    def __init__(
        self, numpy_fct: Any, onnx_node: NodeProto, run_params: dict[str, Any]
    ):
        OpRunBinaryNum.__init__(self, onnx_node, run_params)
        # Remember the callable applied by ``_run``.
        self.numpy_fct = numpy_fct

    def _run(self, a, b):  # type: ignore
        left, right = convert_to_ml_dtypes(a), convert_to_ml_dtypes(b)
        raw_result = self.numpy_fct(left, right)
        wrapped = (convert_from_ml_dtypes(raw_result),)
        return self._check_and_fix_outputs(wrapped)
|
|
|
|
|
|
class OpRunReduceNumpy(OpRun):  # type: ignore
    """Implements the reduce logic.

    It must have a parameter *axes*. When the node defines an ``axes``
    attribute, ``__init__`` normalizes it: an empty value becomes ``None``
    and a non-empty list or array becomes a tuple.
    """

    def __init__(self, onnx_node: NodeProto, run_params: dict[str, Any]):
        OpRun.__init__(self, onnx_node, run_params)
        if hasattr(self, "axes"):
            if isinstance(self.axes, np.ndarray):  # type: ignore
                # NOTE(review): a 0-d array is mapped to None here, whereas
                # handle_axes maps it to an int; kept for compatibility.
                if self.axes.ndim == 0 or self.axes.size == 0:  # type: ignore
                    self.axes = None
                else:
                    # ``tolist`` yields plain Python ints (as handle_axes does)
                    # and ``ravel`` avoids a tuple of sub-arrays for N-d input.
                    self.axes = tuple(self.axes.ravel().tolist())  # type: ignore
            elif self.axes in [[], ()]:  # type: ignore
                self.axes = None
            elif isinstance(self.axes, list):  # type: ignore
                self.axes = tuple(self.axes)  # type: ignore

    def is_axes_empty(self, axes):
        """Returns True when *axes* is ``None`` (the normalized form of empty axes)."""
        return axes is None

    def handle_axes(self, axes):  # noqa: PLR0911
        """Normalizes *axes* into a value usable as numpy's ``axis`` argument.

        Args:
            axes: ``None``, an int, a numpy integer scalar, a tuple, a list
                or a numpy array of axis indices.

        Returns:
            ``None`` when *axes* is ``None`` or empty, an ``int`` for a
            scalar, otherwise a tuple (plain Python ints when built from a
            list-free numpy array).

        Raises:
            TypeError: if *axes* has any other type.
        """
        if axes is None:
            return None
        if isinstance(axes, tuple):
            return axes if len(axes) > 0 else None
        if isinstance(axes, int):
            return axes
        if isinstance(axes, np.integer):
            # numpy integer scalars are not ndarrays; treat them like ints.
            return int(axes)
        if isinstance(axes, list):
            # Same normalization as __init__ applies to a list attribute.
            return tuple(axes) if len(axes) > 0 else None
        if not isinstance(axes, np.ndarray):
            raise TypeError(f"axes must be an array, not {type(axes)}.")
        if axes.ndim == 0:
            return int(axes)
        if axes.size == 0:
            return None
        return tuple(axes.ravel().tolist())

    def output_shape(self, data, axes, keepdims):
        """Returns the shape ``np.sum`` produces when reducing *data* over *axes*.

        The sum itself is computed and discarded; only its shape is used.
        """
        return np.sum(data, axis=axes, keepdims=keepdims).shape

    def reduce_constant(self, data, const_val, axes, keepdims):
        """Special case reduction where the output value is a constant.

        Returns a 1-tuple holding an array with the reduced shape, filled
        with *const_val* and using the dtype of *data*.
        """
        output_shape = self.output_shape(data, axes, keepdims)
        return (np.full(output_shape, const_val, dtype=data.dtype),)
|