# Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import numpy as np from onnx.reference.ops._op import OpRunBinaryNum def numpy_matmul(a, b): # type: ignore """Implements a matmul product. See :func:`np.matmul`. Handles sparse matrices. """ try: if len(a.shape) <= 2 and len(b.shape) <= 2: return np.dot(a, b) return np.matmul(a, b) except ValueError as e: raise ValueError(f"Unable to multiply shapes {a.shape!r}, {b.shape!r}.") from e class MatMul(OpRunBinaryNum): def _run(self, a, b): # type: ignore return (numpy_matmul(a, b),)