26 lines
676 B
Python
26 lines
676 B
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.ops._op import OpRunBinaryNum
|
|
|
|
|
|
def numpy_matmul(a, b):  # type: ignore
    """Return the matrix product of ``a`` and ``b``. See :func:`np.matmul`.

    When neither operand has more than two dimensions the product is
    computed with :func:`np.dot` (this path is intended to also handle
    sparse matrices). Any higher-rank operand goes through
    :func:`np.matmul`, which treats leading dimensions as batch dimensions.

    Raises:
        ValueError: if numpy rejects the operand shapes; the original
            numpy error is attached as ``__cause__``.
    """
    # Pick the numpy routine once; both are called with the same arguments.
    low_rank = max(len(a.shape), len(b.shape)) <= 2
    multiply = np.dot if low_rank else np.matmul
    try:
        return multiply(a, b)
    except ValueError as err:
        # Re-raise with both shapes in the message so mismatches are easy to diagnose.
        raise ValueError(f"Unable to multiply shapes {a.shape!r}, {b.shape!r}.") from err
|
|
|
|
|
|
class MatMul(OpRunBinaryNum):
    """Reference kernel for the ONNX ``MatMul`` operator.

    The numeric work is delegated to :func:`numpy_matmul`; ``_run``
    wraps the single result in a one-element tuple.
    """

    def _run(self, a, b):  # type: ignore
        product = numpy_matmul(a, b)
        return (product,)
|