Files
Reinforced-Learning-Godot/rl/Lib/site-packages/onnx/reference/ops/_op.py
2024-10-30 22:14:35 +01:00

188 lines
6.3 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
import numpy as np
from onnx.onnx_pb import NodeProto
from onnx.reference.custom_element_types import (
convert_from_ml_dtypes,
convert_to_ml_dtypes,
)
from onnx.reference.op_run import OpRun, RuntimeTypeError
class OpRunUnary(OpRun):
    """Ancestor to all unary operators in this subfolder.

    The single input is converted with ``convert_to_ml_dtypes`` before
    ``_run`` is called, and the first output is converted back with
    ``convert_from_ml_dtypes``. This class does not compare input and output
    types; :class:`OpRunUnaryNum` adds that check.
    """

    def run(self, x):  # type: ignore
        """Calls method ``_run``, catches exceptions, displays a longer error message.

        Supports only unary operators: only the first element returned by
        ``_run`` is kept.

        :param x: the single input
        :return: the outputs after ``_check_and_fix_outputs``
        :raises TypeError: if ``_run`` raises :class:`TypeError`; it is
            re-raised with the input type and operator name in the message
        """
        self._log("-- begin %s.run(1 input)", self.__class__.__name__)
        x = convert_to_ml_dtypes(x)
        try:
            res = self._run(x)
        except TypeError as e:
            raise TypeError(
                f"Issues with types {type(x)} "
                f"(unary operator {self.__class__.__name__!r})."
            ) from e
        # Unary operator: keep only the first output, converted back from ml_dtypes.
        res = (convert_from_ml_dtypes(res[0]),)
        self._log("-- done %s.run -> %d outputs", self.__class__.__name__, len(res))
        return self._check_and_fix_outputs(res)
class OpRunUnaryNum(OpRunUnary):
    """Base class for unary numerical operators in this subfolder.

    On top of :class:`OpRunUnary`, it verifies that the first output keeps
    the dtype of the input.
    """

    def run(self, x):  # type: ignore
        """Delegates to ``OpRunUnary.run``, then validates the output dtype.

        An empty result, or one whose first element is ``None``, is returned
        as is. A first output that is a ``list`` skips the dtype check.

        :param x: the single input
        :return: the outputs, passed through ``_check_and_fix_outputs`` when
            a first output exists
        :raises RuntimeTypeError: when the first output's dtype differs from
            ``x.dtype``
        """
        outputs = OpRunUnary.run(self, x)
        if len(outputs) > 0 and outputs[0] is not None:
            first = outputs[0]
            if not isinstance(first, list) and first.dtype != x.dtype:
                raise RuntimeTypeError(
                    f"Output type mismatch: input '{x.dtype}' != output '{first.dtype}' "
                    f"(operator {self.__class__.__name__!r})."
                )
            outputs = self._check_and_fix_outputs(outputs)
        return outputs
class OpRunBinary(OpRun):
    """Ancestor to all binary operators in this subfolder.

    Checks that both inputs are present and share the same dtype before
    calling ``_run``.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``_run``, catches exceptions, displays a longer error message.

        Supports only binary operators: only the first element returned by
        ``_run`` is kept.

        :param x: first input
        :param y: second input
        :return: the outputs after ``_check_and_fix_outputs``
        :raises RuntimeError: if either input is ``None``
        :raises RuntimeTypeError: if ``x.dtype != y.dtype``
        :raises TypeError: if ``_run`` raises :class:`TypeError` or
            :class:`ValueError`; it is re-raised with the input types in the
            message
        """
        self._log("-- begin %s.run(2 inputs)", self.__class__.__name__)
        if x is None or y is None:
            raise RuntimeError(
                f"Binary operator {self.__class__.__name__!r} received a missing "
                f"input: type(x)={type(x)}, type(y)={type(y)} ({type(self)})."
            )
        # dtypes are compared on the inputs as received, before ml_dtypes conversion.
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                f"Input type mismatch: {x.dtype} != {y.dtype} "
                f"(operator {self.__class__.__name__!r}, "
                f"shapes {x.shape}, {y.shape})."
            )
        x = convert_to_ml_dtypes(x)
        y = convert_to_ml_dtypes(y)
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Issues with types {', '.join(str(type(_)) for _ in [x, y])} "
                f"(binary operator {self.__class__.__name__!r})."
            ) from e
        # Binary operator: keep only the first output, converted back from ml_dtypes.
        res = (convert_from_ml_dtypes(res[0]),)
        self._log("-- done %s.run -> %d outputs", self.__class__.__name__, len(res))
        return self._check_and_fix_outputs(res)
class OpRunBinaryComparison(OpRunBinary):
    """Base class for the binary operators of this subfolder that compare two tensors.

    It adds no behavior to :class:`OpRunBinary`.
    """
class OpRunBinaryNum(OpRunBinary):
    """Ancestor to all binary numerical operators in this subfolder.

    Checks that input and output types are the same.
    """

    def run(self, x, y):  # type: ignore
        """Calls method ``OpRunBinary.run`` and checks the output dtype.

        :param x: first input
        :param y: second input
        :return: the outputs after ``_check_and_fix_outputs``
        :raises RuntimeTypeError: if the first output's dtype differs from
            ``x.dtype`` (``OpRunBinary.run`` already rejects inputs whose
            dtypes differ)
        """
        res = OpRunBinary.run(self, x, y)
        if res[0].dtype != x.dtype:
            raise RuntimeTypeError(
                f"Output type mismatch: {x.dtype} != {res[0].dtype} or {y.dtype} "
                f"(operator {self.__class__.__name__!r})"
                f" type(x)={type(x)} type(y)={type(y)}"
            )
        return self._check_and_fix_outputs(res)
class OpRunBinaryNumpy(OpRunBinaryNum):
    """Binary numerical operator whose computation is delegated to *numpy_fct*.

    *numpy_fct* is a callable taking the two input arrays and returning the
    single result.
    """

    def __init__(
        self, numpy_fct: Any, onnx_node: NodeProto, run_params: dict[str, Any]
    ):
        OpRunBinaryNum.__init__(self, onnx_node, run_params)
        # Stored so that _run can apply it to the inputs.
        self.numpy_fct = numpy_fct

    def _run(self, a, b):  # type: ignore
        """Applies *numpy_fct* to ``a`` and ``b`` and returns a 1-tuple."""
        lhs, rhs = convert_to_ml_dtypes(a), convert_to_ml_dtypes(b)
        raw = self.numpy_fct(lhs, rhs)
        outputs = (convert_from_ml_dtypes(raw),)
        return self._check_and_fix_outputs(outputs)
class OpRunReduceNumpy(OpRun):  # type: ignore
    """Implements the reduce logic.

    It must have a parameter *axes*. When present, the attribute is
    normalised in ``__init__``: an empty value becomes ``None``, and any
    other array or list becomes a tuple.
    """

    def __init__(self, onnx_node: NodeProto, run_params: dict[str, Any]):
        OpRun.__init__(self, onnx_node, run_params)
        if hasattr(self, "axes"):
            if isinstance(self.axes, np.ndarray):  # type: ignore
                if len(self.axes.shape) == 0:  # type: ignore
                    # A 0-d array holds a single axis, not "no axes"; this
                    # agrees with handle_axes, which maps 0-d arrays to an int.
                    self.axes = (int(self.axes),)  # type: ignore
                elif self.axes.shape[0] == 0:  # type: ignore
                    self.axes = None
                else:
                    self.axes = tuple(self.axes)
            elif self.axes in [[], ()]:
                self.axes = None
            elif isinstance(self.axes, list):
                self.axes = tuple(self.axes)

    def is_axes_empty(self, axes):
        """Returns True when *axes* is None.

        ``handle_axes`` maps empty axes to ``None``, so after normalisation
        this single check covers both cases.
        """
        return axes is None

    def handle_axes(self, axes):  # noqa: PLR0911
        """Normalises an *axes* value into a form numpy reductions accept.

        :param axes: ``None``, an int or numpy integer scalar, a tuple or
            list, or a numpy array
        :return: ``None`` when no axes are given (``None`` or empty), an
            ``int`` for a scalar axis, otherwise a tuple
        :raises TypeError: for any other type
        """
        if axes is None:
            return None
        if isinstance(axes, list):
            # Lists are accepted and handled exactly like tuples.
            axes = tuple(axes)
        if isinstance(axes, tuple):
            return axes if len(axes) > 0 else None
        if isinstance(axes, int):
            return axes
        if isinstance(axes, np.integer):
            return int(axes)
        if not isinstance(axes, np.ndarray):
            raise TypeError(f"axes must be an array, not {type(axes)}.")
        if len(axes.shape) == 0:
            return int(axes)
        if 0 in axes.shape:
            return None
        return tuple(axes.ravel().tolist())

    def output_shape(self, data, axes, keepdims):
        """Returns the shape that ``np.sum(data, axis=axes, keepdims=keepdims)`` has.

        The sum runs over a zero-stride view of int8 zeros with the shape of
        *data*. numpy still validates and normalises *axes* (negative,
        out-of-range, or duplicate values), but the values and dtype of
        *data* play no part. This avoids overflow warnings and does not
        require *data* to support addition.
        """
        template = np.broadcast_to(np.zeros((), dtype=np.int8), np.shape(data))
        return np.sum(template, axis=axes, keepdims=keepdims).shape

    def reduce_constant(self, data, const_val, axes, keepdims):
        """Special case reduction where the output value is a constant.

        :return: a 1-tuple holding an array of the reduced shape, filled with
            *const_val* and using ``data.dtype``
        """
        output_shape = self.output_shape(data, axes, keepdims)
        return (np.full(output_shape, const_val, dtype=data.dtype),)