Files
2024-10-30 22:14:35 +01:00

152 lines
5.4 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
class Scan(OpRun):
    """Reference implementation of the ONNX ``Scan`` operator.

    The node inputs are ``N`` loop-state values followed by
    ``num_scan_inputs`` scan inputs. The body is run once per slice of the
    scan inputs along axis 0. It receives the current loop states followed
    by one slice of every scan input. It returns the updated loop states
    followed by one element per scan output.

    The operator returns the final loop states followed by every scan
    output, with the per-iteration elements stacked along a new leading
    axis.

    Only direction 0 along axis 0 is implemented, for both scan inputs and
    scan outputs; any other direction or axis raises ``RuntimeError``.
    """

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        if not hasattr(self.body, "run"):  # type: ignore
            raise RuntimeError(
                f"Parameter 'body' must have a method 'run', type {type(self.body)}."  # type: ignore
            )
        # One direction / axis per scan input; missing entries default to 0.
        self.input_directions_ = self._per_tensor_attribute(
            self.scan_input_directions, self.num_scan_inputs  # type: ignore
        )
        # Check every entry rather than only the maximum: with max() a
        # negative value mixed with zeros would slip through unnoticed.
        if any(direction != 0 for direction in self.input_directions_):
            raise RuntimeError(
                "Scan is not implemented for other input directions than 0."
            )
        self.input_axes_ = self._per_tensor_attribute(
            self.scan_input_axes, self.num_scan_inputs  # type: ignore
        )
        if any(axis != 0 for axis in self.input_axes_):
            raise RuntimeError("Scan is not implemented for other input axes than 0.")
        self.input_names = self.body.input_names  # type: ignore
        self.output_names = self.body.output_names  # type: ignore

    @staticmethod
    def _per_tensor_attribute(values, count):  # type: ignore
        """Expand an optional per-tensor attribute to exactly ``count`` entries.

        Entry ``i`` is ``values[i]`` when it exists. It is 0 when ``values``
        is None or has fewer than ``i + 1`` entries. Surplus entries are
        ignored.
        """
        if values is None:
            return [0] * count
        return [values[i] if i < len(values) else 0 for i in range(count)]

    def _common_run_shape(self, *args):  # type: ignore
        """Split the node inputs and validate the scan-output attributes.

        Args:
            args: the loop-state values followed by the ``num_scan_inputs``
                scan inputs.

        Returns:
            A 12-tuple ``(num_loop_state_vars, num_scan_outputs,
            output_directions, max_dir_out, output_axes, max_axe_out,
            state_names_in, state_names_out, scan_names_in, scan_names_out,
            scan_values, states)``.

        Raises:
            RuntimeError: if fewer than ``num_scan_inputs`` inputs are
                given, or if a scan output uses a direction or an axis
                other than 0.
        """
        num_loop_state_vars = len(args) - self.num_scan_inputs  # type: ignore
        if num_loop_state_vars < 0:
            raise RuntimeError(
                f"Scan expects at least num_scan_inputs={self.num_scan_inputs} "  # type: ignore
                f"inputs, got {len(args)}."
            )
        # The body returns the updated loop states first; every remaining
        # body output is a scan output.
        num_scan_outputs = len(self.output_names) - num_loop_state_vars

        output_directions = self._per_tensor_attribute(
            self.scan_output_directions, num_scan_outputs  # type: ignore
        )
        # default=0 keeps this valid when the body has no scan outputs.
        max_dir_out = max(output_directions, default=0)
        if any(direction != 0 for direction in output_directions):
            raise RuntimeError(
                "Scan is not implemented for other output directions than 0."
            )
        output_axes = self._per_tensor_attribute(
            self.scan_output_axes, num_scan_outputs  # type: ignore
        )
        max_axe_out = max(output_axes, default=0)
        if any(axis != 0 for axis in output_axes):
            raise RuntimeError("Scan is not implemented for other output axes than 0.")

        # Body inputs and outputs both list the loop states first and the
        # scan elements second, so both are split at num_loop_state_vars.
        state_names_in = self.input_names[:num_loop_state_vars]
        state_names_out = self.output_names[: len(state_names_in)]
        scan_names_in = self.input_names[num_loop_state_vars:]
        scan_names_out = self.output_names[num_loop_state_vars:]
        scan_values = args[num_loop_state_vars:]
        states = args[:num_loop_state_vars]
        return (
            num_loop_state_vars,
            num_scan_outputs,
            output_directions,
            max_dir_out,
            output_axes,
            max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        )

    def _run(  # type:ignore
        self,
        *args,
        body=None,  # noqa: ARG002
        num_scan_inputs=None,  # noqa: ARG002
        scan_input_axes=None,  # noqa: ARG002
        scan_input_directions=None,  # noqa: ARG002
        scan_output_axes=None,  # noqa: ARG002
        scan_output_directions=None,  # noqa: ARG002
        attributes=None,  # noqa: ARG002
    ):
        """Run the body once per slice of the scan inputs along axis 0.

        Returns:
            The final loop-state values followed by one array per scan
            output. Each array stacks the per-iteration body outputs along a
            new leading axis. The result goes through
            ``_check_and_fix_outputs``.

        Raises:
            TypeError: if calling the body raises ``TypeError``.
            ValueError: if the scan inputs have length 0 along the scan axis
                while the body declares scan outputs. Their element shape
                cannot be known without running the body.
        """
        # TODO: support overridden attributes.
        (
            num_loop_state_vars,
            _num_scan_outputs,
            _output_directions,
            _max_dir_out,
            _output_axes,
            _max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        ) = self._common_run_shape(*args)
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]  # type: ignore
        # A list, so the stacked scan outputs can be appended even when the
        # loop body never runs.
        states = list(states)
        results = [[] for _ in scan_names_out]  # type: ignore

        for it in range(max_iter):
            inputs = dict(zip(state_names_in, states))
            inputs.update(
                (name, value[it]) for name, value in zip(scan_names_in, scan_values)
            )
            try:
                outputs_list = self._run_body(inputs)  # type: ignore
            except TypeError as e:
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'."  # type: ignore
                ) from e
            outputs = dict(zip(self.output_names, outputs_list))
            states = [outputs[name] for name in state_names_out]
            for collected, name in zip(results, scan_names_out):
                collected.append(outputs[name])

        if results and max_iter == 0:
            raise ValueError(
                "Scan cannot build scan outputs when the scan inputs have "
                "length 0 along the scan axis."
            )
        # np.stack inserts the leading iteration axis. Unlike vstack of
        # expand_dims, it keeps 0-d per-iteration outputs at shape (n,)
        # instead of (n, 1).
        states.extend(np.stack(collected) for collected in results)
        return self._check_and_fix_outputs(tuple(states))