55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
|
|
|
|
class Unsqueeze_1(OpRun):
    """Unsqueeze variant whose ``_run`` receives ``axes`` as a keyword argument."""

    def _run(self, data, axes=None):  # type: ignore
        """Insert size-1 dimensions into ``data`` at the positions in ``axes``.

        Every entry of ``axes`` is resolved against the *output* shape, so the
        entries may be listed in any order and may be negative.

        Args:
            data: Array to expand.
            axes: Positions of the new dimensions, given as a list, tuple or
                ``np.ndarray`` of integers.

        Returns:
            A 1-tuple holding the expanded array.

        Raises:
            RuntimeError: If ``axes`` is ``None``, an empty list/tuple, or any
                other value that is not a list, tuple or ``np.ndarray``.
        """
        if isinstance(axes, np.ndarray):
            positions = tuple(axes)
        elif isinstance(axes, (list, tuple)) and len(axes) > 0:
            positions = tuple(axes)
        else:
            raise RuntimeError(  # noqa: TRY004
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1)."
            )

        try:
            # Passing all positions at once lets numpy resolve them against the
            # final rank; inserting them one by one in caller order misplaces
            # unsorted or negative entries.
            sq = np.expand_dims(data, axis=positions)
        except TypeError:
            # Some numpy releases only accept a single integer axis. Emulate
            # the tuple form: make each position absolute with respect to the
            # output rank, then insert in ascending order so that earlier
            # insertions never shift later ones.
            out_rank = np.ndim(data) + len(positions)
            sq = data
            for pos in sorted(p + out_rank if p < 0 else p for p in positions):
                sq = np.expand_dims(sq, axis=pos)
        return (sq,)
|
|
|
|
|
|
class Unsqueeze_11(Unsqueeze_1):
    """Subclass of Unsqueeze_1 that adds nothing; ``_run`` is inherited unchanged."""
|
|
|
|
|
|
class Unsqueeze_13(OpRun):
    """Unsqueeze variant whose ``_run`` accepts ``axes`` as an array-like value."""

    def _run(self, data, axes=None):  # type: ignore
        """Insert size-1 dimensions into ``data`` at the positions in ``axes``.

        Every entry of ``axes`` is resolved against the *output* shape, so the
        entries may be listed in any order and may be negative.

        Args:
            data: Array to expand.
            axes: Positions of the new dimensions. May be an ``np.ndarray``
                (0-d or 1-D), a list, a tuple or a single integer.

        Returns:
            A 1-tuple holding the expanded array.

        Raises:
            RuntimeError: If ``axes`` is ``None``.
        """
        if axes is None:
            raise RuntimeError(
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13)."
            )

        # np.atleast_1d gives scalars, 0-d arrays, lists and tuples the same
        # 1-D form, so no input needs a ``.shape`` attribute of its own.
        positions = tuple(np.atleast_1d(axes).tolist())

        try:
            sq = np.expand_dims(data, axis=positions)
        except TypeError:
            # Some numpy releases only accept a single integer axis. Emulate
            # the tuple form: make each position absolute with respect to the
            # output rank, then insert in ascending order so that earlier
            # insertions never shift later ones.
            out_rank = np.ndim(data) + len(positions)
            sq = data
            for pos in sorted(p + out_rank if p < 0 else p for p in positions):
                sq = np.expand_dims(sq, axis=pos)
        return (sq,)
|