35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
|
|
|
|
class PRelu(OpRun):
    """Reference implementation of the ONNX ``PRelu`` operator.

    Elementwise, the output is ``x`` where ``x > 0`` and ``x * slope``
    elsewhere, cast back to the dtype of ``x``.
    """

    def _run(self, x, slope):  # type: ignore
        """Evaluate PRelu on ``x`` with the given ``slope``.

        Args:
            x: input array.
            slope: slope array. It is first combined with ``x`` using plain
                numpy broadcasting. If that fails, a 1-D ``slope`` is
                aligned with the single axis of ``x`` whose length matches it.

        Returns:
            A one-element tuple holding the result array, cast to ``x.dtype``.

        Raises:
            ValueError: the original broadcasting error, re-raised when
                numpy cannot broadcast ``slope`` against ``x`` and the 1-D
                alignment fallback does not apply.
        """
        try:
            is_positive = x > 0
            return (np.where(is_positive, x, x * slope).astype(x.dtype),)
        except ValueError:
            # numpy refused to broadcast slope against x. If slope is 1-D
            # with length `width`, search x's axes for that length. If exactly
            # one axis matches, reshape slope so it lies along that axis, with
            # size 1 on every other axis, and retry. Any other case falls
            # through to the bare `raise`, which re-raises the ValueError
            # caught above.
            if len(slope.shape) == 1:
                width = slope.shape[0]
                hits = [axis_len == width for axis_len in x.shape]
                if sum(hits) == 1:
                    aligned_shape = tuple(width if hit else 1 for hit in hits)
                    scaled = x * slope.reshape(aligned_shape)
                    return (np.where(x > 0, x, scaled).astype(x.dtype),)
            raise
|