70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
|
|
|
|
class If(OpRun):
    """Reference runtime for the operator ``If``.

    The operator evaluates a single-element condition and delegates to
    ``_run_then_branch`` or ``_run_else_branch`` accordingly, returning the
    outputs produced by the selected branch.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        """Initialize the node and validate the runtime parameters.

        Args:
            onnx_node: node definition forwarded to ``OpRun.__init__``.
            run_params: mapping of runtime parameters; it must contain the
                keys ``'opsets'``, ``'verbose'`` and ``'existing_functions'``.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        OpRun.__init__(self, onnx_node, run_params)
        if "opsets" not in self.run_params:
            raise KeyError("run_params must contains key 'opsets'.")
        if "verbose" not in run_params:
            raise KeyError("run_params must contains key 'verbose'.")
        if "existing_functions" not in self.run_params:
            raise KeyError("run_params must contains key 'existing_functions'.")

    def need_context(self) -> bool:
        """Tells the runtime that this node needs the context
        (all the results produced so far) as it may silently access
        one of them.
        This operator always answers `True`.
        """
        return True

    def _run(
        self,
        cond: np.ndarray | np.bool_,
        context=None,
        else_branch=None,  # noqa: ARG002
        then_branch=None,  # noqa: ARG002
        attributes=None,
    ):
        """Run the branch selected by ``cond`` and return its outputs.

        Args:
            cond: condition holding exactly one element; its truthiness
                selects the ``then`` branch (true) or ``else`` branch (false).
            context: results available so far, forwarded to the branch.
            else_branch: unused here; the branch is run through
                ``_run_else_branch``.
            then_branch: unused here; the branch is run through
                ``_run_then_branch``.
            attributes: forwarded to the selected branch.

        Returns:
            The value of ``_check_and_fix_outputs`` applied to the tuple of
            outputs of the selected branch.

        Raises:
            ValueError: if ``cond`` does not hold exactly one element.
            RuntimeError: if the selected branch yields no output, or if one
                of its outputs is ``None``.
        """
        if cond.size != 1:
            # Report the real element count: len() only returns the first
            # dimension, which is misleading for multi-dimensional inputs.
            raise ValueError(
                f"Operator If ({self.onnx_node.name!r}) expects a single element as condition, but the size of 'cond' is {cond.size}."
            )

        if cond.item(0):
            branch = "then"
            self._log(" -- then> {%r}", context)
            outputs = self._run_then_branch(context, attributes=attributes)  # type: ignore
            self._log(" -- then<")
        else:
            branch = "else"
            self._log(" -- else> {%r}", context)
            outputs = self._run_else_branch(context, attributes=attributes)  # type: ignore
            self._log(" -- else<")
        final = tuple(outputs)

        if not final:
            raise RuntimeError(
                f"Operator If ({self.onnx_node.name!r}) does not have any output."
            )

        for index, value in enumerate(final):
            if value is None:
                br = self.then_branch if branch == "then" else self.else_branch  # type: ignore
                names = br.output_names
                inits = [init.name for init in br.obj.graph.initializer]
                # context defaults to None; keep building the diagnostic
                # instead of failing with a TypeError inside sorted().
                available = sorted(context) if context is not None else []
                raise RuntimeError(
                    f"Output {index!r} (branch={branch!r}, name={names[index]!r}) is None, "
                    f"available inputs={available}, initializers={inits}."
                )
        return self._check_and_fix_outputs(final)
|