122 lines
3.8 KiB
Python
122 lines
3.8 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
|
|
|
|
def _fft(x: np.ndarray, fft_length: int, axis: int) -> np.ndarray:
|
|
"""Compute the FFT return the real representation of the complex result."""
|
|
transformed = np.fft.fft(x, n=fft_length, axis=axis)
|
|
real_frequencies = np.real(transformed)
|
|
imaginary_frequencies = np.imag(transformed)
|
|
return np.concatenate(
|
|
(real_frequencies[..., np.newaxis], imaginary_frequencies[..., np.newaxis]),
|
|
axis=-1,
|
|
)
|
|
|
|
|
|
def _cfft(
    x: np.ndarray,
    fft_length: int,
    axis: int,
    onesided: bool,
    normalize: bool,
) -> np.ndarray:
    """Run a forward FFT on a signal stored in real representation.

    Args:
        x: Signal whose last dimension carries the components. A last
            dimension of size 1 means the signal only has a real part;
            otherwise even positions of the last dimension are read as real
            parts and odd positions as imaginary parts. The last dimension
            is expected to be 1 or 2; other sizes fail at the squeeze step.
        fft_length: Number of points of the transform along ``axis``.
        axis: Axis to transform, expressed relative to the signal once its
            trailing component dimension has been removed. Negative values
            count from the last axis of that squeezed signal.
        onesided: When true, keep only the first ``n // 2 + 1`` entries
            along ``axis``, where ``n`` is the size of that axis after the
            transform.
        normalize: When true, divide the result by ``fft_length``.

    Returns:
        The output of ``_fft`` (real representation of the spectrum), after
        the optional truncation and normalization.
    """
    if x.shape[-1] == 1:
        # The input contains only the real part
        signal = x
    else:
        # The input is a real representation of a complex signal
        signal = x[..., 0::2] + 1j * x[..., 1::2]

    complex_signals = np.squeeze(signal, -1)

    # Resolve a negative axis now: the FFT result gains a trailing
    # (real, imag) dimension, so indexing it with the raw negative axis
    # would truncate the wrong dimension in the one-sided step below.
    # A new value is bound (no in-place update) so that an array-typed
    # ``axis`` supplied by the caller is never mutated.
    if axis < 0:
        axis = axis + complex_signals.ndim

    result = _fft(complex_signals, fft_length, axis=axis)

    # Post process the result based on arguments
    if onesided:
        index = [slice(None)] * result.ndim
        index[axis] = slice(0, result.shape[axis] // 2 + 1)
        result = result[tuple(index)]
    if normalize:
        result /= fft_length
    return result
|
|
|
|
|
|
def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarray:
|
|
signals = np.fft.ifft(x, fft_length, axis=axis)
|
|
real_signals = np.real(signals)
|
|
imaginary_signals = np.imag(signals)
|
|
merged = np.concatenate(
|
|
(real_signals[..., np.newaxis], imaginary_signals[..., np.newaxis]),
|
|
axis=-1,
|
|
)
|
|
if onesided:
|
|
slices = [slice(a) for a in merged.shape]
|
|
slices[axis] = slice(0, merged.shape[axis] // 2 + 1)
|
|
return merged[tuple(slices)]
|
|
return merged
|
|
|
|
|
|
def _cifft(
    x: np.ndarray, fft_length: int, axis: int, onesided: bool = False
) -> np.ndarray:
    """Run an inverse FFT on frequencies stored in real representation.

    When the last dimension of ``x`` has size 1 the values are used as they
    are; otherwise even positions of the last dimension are read as real
    parts and odd positions as imaginary parts. The trailing dimension is
    then squeezed away and the work is delegated to ``_ifft`` with the same
    ``fft_length``, ``axis`` and ``onesided`` arguments.
    """
    if x.shape[-1] == 1:
        frequencies = x
    else:
        # Rebuild complex values from the interleaved (real, imag) layout.
        frequencies = x[..., 0::2] + 1j * x[..., 1::2]
    squeezed = np.squeeze(frequencies, -1)
    return _ifft(squeezed, fft_length, axis=axis, onesided=onesided)
|
|
|
|
|
|
class DFT_17(OpRun):
    """DFT kernel whose ``axis`` argument defaults to 1."""

    def _run(
        self,
        x: np.ndarray,
        dft_length: int | None = None,
        axis: int = 1,
        inverse: bool = False,
        onesided: bool = False,
    ) -> tuple[np.ndarray]:  # type: ignore
        """Apply the forward or inverse transform to ``x``.

        ``axis`` is first wrapped into the range of valid dimensions of
        ``x``. When ``dft_length`` is missing, the size of ``x`` along that
        axis is used. ``_cifft`` handles the inverse case and ``_cfft``
        (without normalization) the forward case. The result is cast to the
        dtype of ``x`` and returned as a one-element tuple.
        """
        # Convert to positive axis
        axis = axis % x.ndim
        length = x.shape[axis] if dft_length is None else dft_length
        transformed = (
            _cifft(x, length, axis=axis, onesided=onesided)
            if inverse
            else _cfft(x, length, axis=axis, onesided=onesided, normalize=False)
        )
        return (transformed.astype(x.dtype),)
|
|
|
|
|
|
class DFT_20(OpRun):
    """DFT kernel whose ``axis`` argument defaults to -2."""

    def _run(
        self,
        x: np.ndarray,
        dft_length: int | None = None,
        axis: int = -2,
        inverse: bool = False,
        onesided: bool = False,
    ) -> tuple[np.ndarray]:  # type: ignore
        """Apply the forward or inverse transform to ``x``.

        ``axis`` is first wrapped into the range of valid dimensions of
        ``x``. When ``dft_length`` is missing, the size of ``x`` along that
        axis is used. ``_cifft`` handles the inverse case and ``_cfft``
        (without normalization) the forward case. The result is cast to the
        dtype of ``x`` and returned as a one-element tuple.
        """
        # Convert to positive axis
        axis = axis % x.ndim
        length = x.shape[axis] if dft_length is None else dft_length
        transformed = (
            _cifft(x, length, axis=axis, onesided=onesided)
            if inverse
            else _cfft(x, length, axis=axis, onesided=onesided, normalize=False)
        )
        return (transformed.astype(x.dtype),)
|