# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx.reference.op_run import OpRun
|
|
from onnx.reference.ops.op_concat_from_sequence import _concat_from_sequence
|
|
from onnx.reference.ops.op_dft import _cfft as _dft
|
|
from onnx.reference.ops.op_slice import _slice
|
|
|
|
|
|
def _concat(*args, axis=0): # type: ignore
|
|
return np.concatenate(args, axis=axis)
|
|
|
|
|
|
def _unsqueeze(a, axis): # type: ignore
|
|
try:
|
|
return np.expand_dims(a, axis=axis)
|
|
except TypeError:
|
|
# numpy 1.18 supports axes as a tuple
|
|
if len(axis) == 1:
|
|
return np.expand_dims(a, axis=tuple(axis)[0])
|
|
for x in reversed(axis):
|
|
a = np.expand_dims(a, axis=x)
|
|
return a
|
|
|
|
|
|
def _stft(x, fft_length: int, hop_length, n_frames, window, onesided=False):  # type: ignore
    """Applies one dimensional FFT with window weights.

    The input is cut into ``n_frames`` windows of ``window.shape[0]``
    samples taken along axis ``-2`` of ``x``, each starting ``hop_length``
    samples after the previous one. Windows that run past the end of the
    signal are zero-padded to the full window size. Every window is then
    multiplied by ``window`` and handed to the DFT helper.

    torch defines the number of frames as:
    `n_frames = 1 + (len - n_fft) // hop_length`.

    Args:
        x: input array; frames are sliced along its second-to-last axis.
        fft_length: length forwarded to the DFT helper.
        hop_length: number of samples between the starts of two frames.
        n_frames: number of frames to extract.
        window: 1-D weights, one per sample of a frame.
        onesided: forwarded unchanged to the DFT helper.

    Returns:
        Whatever the DFT helper returns for the weighted frames.
    """
    # Index of the last axis of ``x``. The framed tensor built below has one
    # more axis than ``x``, so this index presumably addresses its
    # window-sample axis -- TODO confirm against _concat_from_sequence.
    dft_axis = len(x.shape) - 1  # op.Sub(op.Shape(op.Shape(x)), one)
    window_size = window.shape[0]
    signal_axis = [-2]
    frame_axis = [-3]

    # building frames
    frames = []
    for frame_index in range(n_frames):
        start = frame_index * hop_length
        stop = start + window_size
        chunk = _slice(x, np.array([start]), np.array([stop]), signal_axis)  # type: ignore

        # the chunk may be shorter than the window: fill the gap with zeros
        n_missing = window_size - chunk.shape[-2]
        padding_shape = chunk.shape[:-2] + (n_missing,) + chunk.shape[-1:]
        padding = np.zeros(padding_shape, dtype=x.dtype)
        padded_chunk = _concat(chunk, padding, axis=-2)

        # every frame now has the same size; add the axis the frames are
        # gathered along
        frames.append(_unsqueeze(padded_chunk, frame_axis))

    # concatenation
    framed = _concat_from_sequence(frames, axis=-3, new_axis=0)

    # calling weighted dft with weights=window: the window is reshaped so it
    # broadcasts along the second-to-last axis of the framed tensor
    leading_ones = tuple(1 for _ in framed.shape[:-2])
    weights = np.reshape(window, (*leading_ones, window_size, 1))
    weighted_frames = framed * weights

    return _dft(
        weighted_frames, fft_length, dft_axis, onesided=onesided, normalize=False
    )
|
|
|
|
|
|
def _istft(x, fft_length: int, hop_length, window, onesided=False):  # type: ignore
    """Reverses `_stft` by overlap-adding the transformed frames.

    Each frame (one index along axis ``-2`` of ``x``) is passed to the DFT
    helper with ``normalize=True``. Components 0 and 1 of the helper's last
    axis are taken as the real and imaginary parts. Every part is placed at
    offset ``frame_index * hop_length`` in a zero-filled signal of length
    ``fft_length + hop_length * (n_frames - 1)``. The frames are summed and
    divided by the sum of the equally shifted windows.

    Args:
        x: transformed frames; axis ``-2`` enumerates the frames.
        fft_length: length forwarded to the DFT helper, also used to
            compute the length of the rebuilt signal.
        hop_length: number of samples between the starts of two frames.
        window: weights; presumably the window used by the forward
            transform -- TODO confirm with callers.
        onesided: forwarded unchanged to the DFT helper.

    Returns:
        An array whose last axis has size 2 (real part, imaginary part).
    """
    frame_axis = -2
    n_frames = x.shape[frame_axis]
    expected_signal_len = fft_length + hop_length * (n_frames - 1)

    # building frames
    real_frames = []
    imag_frames = []
    window_frames = []
    for frame_index in range(n_frames):
        frame = np.squeeze(
            _slice(
                x, np.array([frame_index]), np.array([frame_index + 1]), [frame_axis]
            ),
            axis=frame_axis,  # type: ignore
        )

        # ifft
        inverse = _dft(frame, fft_length, axis=-1, onesided=onesided, normalize=True)
        component_axis = len(inverse.shape) - 1

        # real part (component 0) and imaginary part (component 1)
        real_part = np.squeeze(
            _slice(inverse, np.array([0]), np.array([1]), [component_axis]),  # type: ignore
            axis=component_axis,
        )
        imag_part = np.squeeze(
            _slice(inverse, np.array([1]), np.array([2]), [component_axis]),  # type: ignore
            axis=component_axis,
        )
        # window weights broadcast to the shape of one frame
        frame_weights = np.full(real_part.shape, fill_value=1, dtype=x.dtype) * window

        # zero blocks placing the frame at its offset in the full signal
        leading_shape = real_part.shape[:-1]
        n_before = frame_index * hop_length
        n_after = expected_signal_len - (n_before + real_part.shape[-1])
        after = np.zeros((*leading_shape, n_after), dtype=x.dtype)
        before = np.zeros((*leading_shape, n_before), dtype=x.dtype)

        placed_real = _concat(before, real_part, after, axis=-1)
        placed_weights = _concat(before, frame_weights, after, axis=-1)
        placed_imag = _concat(before, imag_part, after, axis=-1)

        # append, with a trailing axis the frames are gathered along
        real_frames.append(_unsqueeze(placed_real, axis=-1))
        imag_frames.append(_unsqueeze(placed_imag, axis=-1))
        window_frames.append(_unsqueeze(placed_weights, axis=-1))

    # concatenation
    all_real = _concat_from_sequence(real_frames, axis=-1, new_axis=0)
    all_imag = _concat_from_sequence(imag_frames, axis=-1, new_axis=0)
    all_weights = _concat_from_sequence(window_frames, axis=-1, new_axis=0)

    # unweight: sum the overlapping frames, then divide by the summed windows
    summed_real = all_real.sum(axis=-1, keepdims=0)  # type: ignore
    summed_imag = all_imag.sum(axis=-1, keepdims=0)  # type: ignore
    summed_weights = all_weights.sum(axis=-1, keepdims=0)  # type: ignore
    signal_real = summed_real / summed_weights
    signal_imag = summed_imag / summed_weights

    # Make complex: real and imaginary parts stacked on a new first axis
    stacked = _concat(
        np.expand_dims(signal_real, axis=0),
        np.expand_dims(signal_imag, axis=0),
        axis=0,
    )

    # rotation, bring first dimension to the last position
    final_shape = _concat(stacked.shape[1:], [2], axis=0)
    flattened = stacked.reshape((2, -1))
    return np.transpose(flattened, (1, 0)).reshape(final_shape)
|
|
|
|
|
|
class STFT(OpRun):
    """Reference runtime for the STFT operator, built on `_stft`."""

    def _run(self, x, frame_step, window=None, frame_length=None, onesided=None):  # type: ignore
        """Compute the short-time Fourier transform of ``x``.

        Args:
            x: input signal; frames are taken along its second-to-last axis.
            frame_step: number of samples between the starts of two frames.
            window: optional 1-D window; defaults to all ones of length
                ``frame_length``.
            frame_length: optional frame size; defaults to the window length
                when a window is given, otherwise to ``x.shape[-2]``.
            onesided: forwarded unchanged to ``_stft``.

        Returns:
            A one-element tuple holding the result cast to ``x.dtype``.
        """
        if frame_length is None:
            frame_length = x.shape[-2] if window is None else window.shape[0]
        if window is None:
            window = np.ones((frame_length,), dtype=x.dtype)
        n_frames = 1 + (x.shape[-2] - frame_length) // frame_step
        transformed = _stft(
            x, frame_length, frame_step, n_frames, window, onesided=onesided
        )
        return (transformed.astype(x.dtype),)
|