Files
2024-10-30 22:14:35 +01:00

78 lines
2.0 KiB
Python

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
def _gemm00(a, b, c, alpha, beta): # type: ignore
o = np.dot(a, b) * alpha
if c is not None and beta != 0:
o += c * beta
return o
def _gemm01(a, b, c, alpha, beta): # type: ignore
o = np.dot(a, b.T) * alpha
if c is not None and beta != 0:
o += c * beta
return o
def _gemm10(a, b, c, alpha, beta): # type: ignore
o = np.dot(a.T, b) * alpha
if c is not None and beta != 0:
o += c * beta
return o
def _gemm11(a, b, c, alpha, beta): # type: ignore
o = np.dot(a.T, b.T) * alpha
if c is not None and beta != 0:
o += c * beta
return o
class Gemm_6(OpRun):
    """Gemm variant carrying the explicit ``broadcast`` attribute.

    Computes ``Y = alpha * A' @ B' + beta * C``. ``A'`` is ``A.T`` when
    ``transA`` is truthy and ``B'`` is ``B.T`` when ``transB`` is truthy.
    The result is cast back to ``a.dtype``.
    """

    def _run(
        self,
        a,
        b,
        c=None,
        alpha=None,
        beta=None,
        transA=None,
        transB=None,
        broadcast=None,
    ):  # type: ignore
        """Evaluate the node.

        Args:
            a, b: matrix operands.
            c: optional bias; ``None`` means no bias term.
            alpha, beta: scale factors for the product and the bias.
            transA, transB: when truthy, transpose the matching operand.
            broadcast: when exactly ``0``, ``c`` must already have the output
                shape. Any other value lets ``c`` broadcast.

        Returns:
            A 1-tuple holding the output array, cast to ``a.dtype``.

        Raises:
            ValueError: ``broadcast == 0`` and ``c.shape`` differs from the
                output shape.
        """
        # Choose the kernel once; the suffix digits encode (transA, transB).
        if transA:
            _meth = _gemm11 if transB else _gemm10
        else:
            _meth = _gemm01 if transB else _gemm00

        if broadcast == 0 and c is not None:
            # Without broadcasting, C must match the product's shape exactly,
            # so compute the product first to learn that shape.
            res = _meth(a, b, None, alpha, beta)
            if c.shape != res.shape:
                raise ValueError(
                    f"Unable to add shape {c.shape} to shape {res.shape} without broadcast."
                )
            # beta scales C whether or not broadcasting is enabled. It was
            # previously ignored here. Skip the term when beta == 0, as the
            # _gemm* kernels do.
            if beta != 0:
                res = res + c * beta
            return (res.astype(a.dtype),)

        # Broadcast path, or no bias at all: the kernel applies beta * c itself.
        return (_meth(a, b, c, alpha, beta).astype(a.dtype),)
class Gemm_7(OpRun):
    """Gemm without the ``broadcast`` attribute: C, if given, always broadcasts.

    Computes ``Y = alpha * A' @ B' + beta * C`` and casts the result to
    ``a.dtype``.
    """

    def _run(self, a, b, c=None, alpha=None, beta=None, transA=None, transB=None):  # type: ignore
        """Evaluate the node and return a 1-tuple with the output array.

        ``transA`` / ``transB`` are interpreted by truthiness. The matching
        ``_gemmXY`` kernel is chosen, where X/Y mark which operand is transposed.
        """
        kernels = {
            (False, False): _gemm00,
            (False, True): _gemm01,
            (True, False): _gemm10,
            (True, True): _gemm11,
        }
        kernel = kernels[bool(transA), bool(transB)]
        output = kernel(a, b, c, alpha, beta)
        return (output.astype(a.dtype),)