Files
2024-10-30 22:14:35 +01:00

59 lines
2.0 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
def _gather_nd_impl(
data: np.ndarray, indices: np.ndarray, batch_dims: int
) -> tuple[np.ndarray]:
# Note the data rank - will be reused multiple times later
data_rank = len(data.shape)
# The list of data/indice shape of batch_dims.
batch_dims_shape = []
# The number of elements in the batch_dims for data/indice array.
batch_dims_size = 1
# Check the shape of indice and data are identical for batch dims.
for i in range(batch_dims):
batch_dims_shape.append(indices.shape[i])
batch_dims_size *= indices.shape[i]
# Compute output of the op as below.
# Compute shape of output array.
output_shape = (
batch_dims_shape + list(indices.shape)[batch_dims:-1]
if (indices.shape[-1] == data_rank - batch_dims)
else batch_dims_shape
+ list(indices.shape)[batch_dims:-1]
+ list(data.shape)[batch_dims + indices.shape[-1] :]
)
# Placeholder for output data.
output_data_buffer = []
# Flatten 'indices' to 2D array.
reshaped_indices = indices.reshape(batch_dims_size, -1, indices.shape[-1])
# Flatten 'data' to array of shape
# (batch_dim_size, data.shape[batch_dimes:]).
reshaped_data = data.reshape((batch_dims_size,) + data.shape[batch_dims:])
# Gather each scalar value from 'data'.
for batch_dim in range(reshaped_indices.shape[0]):
for outer_dim in range(reshaped_indices.shape[1]):
gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
output_data_buffer.append(reshaped_data[(batch_dim, *gather_index)])
return (np.asarray(output_data_buffer, dtype=data.dtype).reshape(output_shape),)
class GatherND(OpRun):
    """Reference-runtime GatherND operator that delegates to ``_gather_nd_impl``."""

    def _run(self, data, indices, batch_dims=None):  # type: ignore
        """Gather slices of ``data`` addressed by the index vectors in ``indices``.

        Returns the 1-tuple produced by ``_gather_nd_impl``.
        """
        # Treat a missing attribute (None) as 0, the ONNX GatherND default,
        # so the implementation always receives an integer batch count.
        if batch_dims is None:
            batch_dims = 0
        return _gather_nd_impl(data, indices, batch_dims)  # type: ignore