Files
2024-10-30 22:14:35 +01:00

55 lines
1.6 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
class Unsqueeze_1(OpRun):
    """Reference implementation of the ONNX ``Unsqueeze`` operator (opset 1).

    Inserts dimensions of size one into the input array at the positions
    listed in ``axes``.
    """

    def _run(self, data, axes=None):  # type: ignore
        """Return ``data`` with size-1 dimensions inserted at ``axes``.

        Args:
            data: array to unsqueeze.
            axes: positions of the new dimensions in the *output* array, as a
                list, tuple or ``np.ndarray`` of integers. All positions are
                resolved against the output rank at once (``np.expand_dims``
                tuple semantics), so their order does not matter and negative
                values count from the end of the output.

        Returns:
            A one-element tuple holding the unsqueezed array.

        Raises:
            RuntimeError: if ``axes`` is ``None``, an empty list/tuple, or not
                a list/tuple/ndarray. Out-of-range or repeated axes surface as
                the errors raised by ``np.expand_dims``.
        """
        # Normalise every accepted container to a tuple; an empty list/tuple
        # is treated the same as a missing attribute.
        if isinstance(axes, np.ndarray):
            axes = tuple(axes)
        elif axes in ([], ()):
            axes = None
        elif isinstance(axes, list):
            axes = tuple(axes)
        if not isinstance(axes, tuple):
            raise RuntimeError(  # noqa: TRY004
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1)."
            )
        try:
            # Inserting all axes in one call removes the order dependence of
            # expanding one axis at a time: axes=(-2, -1) on a (3, 4) array
            # must give (3, 4, 1, 1), not (3, 1, 4, 1).
            sq = np.expand_dims(data, axis=axes)
        except TypeError:
            # numpy < 1.18 only accepts a single integer axis. Resolving the
            # positions against the output rank and inserting them in
            # ascending order produces the same layout.
            out_rank = np.ndim(data) + len(axes)
            sq = data
            for position in sorted(ax + out_rank if ax < 0 else ax for ax in axes):
                sq = np.expand_dims(sq, axis=position)
        return (sq,)
class Unsqueeze_11(Unsqueeze_1):
    """Unsqueeze for opset 11.

    Adds no behaviour of its own: ``_run`` is inherited unchanged from
    :class:`Unsqueeze_1`.
    """
class Unsqueeze_13(OpRun):
    """Reference implementation of the ONNX ``Unsqueeze`` operator (opset 13).

    Inserts dimensions of size one into the input array at the positions
    given by ``axes``, which may be a scalar or a 1-D collection of integers.
    """

    def _run(self, data, axes=None):  # type: ignore
        """Return ``data`` with size-1 dimensions inserted at ``axes``.

        Args:
            data: array to unsqueeze.
            axes: a scalar integer, or a 1-D array-like (ndarray, list or
                tuple) of integers, giving the positions of the new
                dimensions in the output array. All positions are resolved
                against the output rank at once, so their order does not
                matter and negative values count from the end of the output.

        Returns:
            A one-element tuple holding the unsqueezed array.

        Raises:
            RuntimeError: if ``axes`` is ``None``. Out-of-range or repeated
                axes surface as the errors raised by ``np.expand_dims``.
        """
        if axes is None:
            raise RuntimeError(
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13)."
            )
        if np.ndim(axes) == 0:
            # A single axis (Python int, numpy scalar or 0-d array) can be
            # handed to numpy unchanged.
            return (np.expand_dims(data, axis=axes),)
        # np.asarray accepts plain lists/tuples as well as ndarrays.
        axes = tuple(np.asarray(axes))
        try:
            sq = np.expand_dims(data, axis=axes)
        except TypeError:
            # numpy < 1.18 only accepts a single integer axis. Resolve the
            # positions against the output rank and insert them in ascending
            # order; any other order misplaces dimensions (e.g. axes=(0, 1)
            # on a (3, 4) array must give (1, 1, 3, 4)).
            out_rank = np.ndim(data) + len(axes)
            sq = data
            for position in sorted(ax + out_rank if ax < 0 else ax for ax in axes):
                sq = np.expand_dims(sq, axis=position)
        return (sq,)