# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_concat_from_sequence import _concat_from_sequence
from onnx.reference.ops.op_dft import _cfft as _dft
from onnx.reference.ops.op_slice import _slice


def _concat(*args, axis=0):  # type: ignore
    """Concatenates all positional arrays along ``axis``.

    Thin wrapper over ``np.concatenate`` taking the arrays as varargs.
    """
    return np.concatenate(args, axis=axis)


def _unsqueeze(a, axis):  # type: ignore
    """Inserts size-1 dimensions into ``a`` at the position(s) ``axis``.

    Args:
        a: array to expand.
        axis: a single int or a sequence of ints, forwarded to
            ``np.expand_dims``.

    Returns:
        The expanded array.
    """
    try:
        return np.expand_dims(a, axis=axis)
    except TypeError:
        # np.expand_dims accepts a tuple/sequence of axes only since numpy
        # 1.18; older numpy presumably fails above with a TypeError, so fall
        # back to inserting the axes one at a time.
        if len(axis) == 1:
            return np.expand_dims(a, axis=tuple(axis)[0])
        # Insert the axes one by one, last entry first.
        for x in reversed(axis):
            a = np.expand_dims(a, axis=x)
        return a


def _stft(x, fft_length: int, hop_length, n_frames, window, onesided=False):  # type: ignore
    """Applies one dimensional FFT with window weights.

    The signal is cut into ``n_frames`` frames along axis -2. Frame ``i``
    covers samples ``[i * hop_length, i * hop_length + window.shape[0])``.
    A frame that is shorter than the window (end of the signal) is
    zero-padded along axis -2 up to the window length. Every frame is
    multiplied by ``window`` and ``_dft`` is then applied over the samples
    of each frame.

    torch defines the number of frames as:
    `n_frames = 1 + (len - n_fft) // hop_length`.

    Args:
        x: input signal, framed along axis -2. Presumably shaped
            ``(batch, signal_length, 1 or 2)`` (real values or real/imaginary
            pairs) as in the ONNX STFT spec -- TODO confirm.
        fft_length: DFT length forwarded to ``_dft``.
        hop_length: number of samples between the starts of two frames.
        n_frames: number of frames to extract.
        window: weights array; its length ``window.shape[0]`` is the frame
            length.
        onesided: forwarded unchanged to ``_dft``.

    Returns:
        The output of ``_dft`` computed over the windowed frames.
    """
    # After framing, new_x has one more axis than x, so this index points at
    # new_x's axis -2, i.e. the samples-within-a-frame axis (assuming
    # _concat_from_sequence with new_axis=0 does not add an extra axis).
    last_axis = len(x.shape) - 1  # op.Sub(op.Shape(op.Shape(x)), one)
    # Axis along which each frame is sliced out of the signal.
    axis = [-2]
    # New per-frame axis; frames are concatenated along it.
    axis2 = [-3]
    window_size = window.shape[0]

    # building frames
    seq = []
    for fs in range(n_frames):
        begin = fs * hop_length
        end = begin + window_size
        sliced_x = _slice(x, np.array([begin]), np.array([end]), axis)  # type: ignore

        # sliced_x may be smaller (presumably _slice clamps `end` to the
        # signal length): pad with zeros along axis -2 up to window_size.
        new_dim = sliced_x.shape[-2:-1]
        missing = (window_size - new_dim[0],)
        new_shape = sliced_x.shape[:-2] + missing + sliced_x.shape[-1:]
        cst = np.zeros(new_shape, dtype=x.dtype)
        pad_sliced_x = _concat(sliced_x, cst, axis=-2)

        # same size: every frame now has window_size samples; add the frame
        # axis at -3.
        un_sliced_x = _unsqueeze(pad_sliced_x, axis2)
        seq.append(un_sliced_x)

    # concatenation of all frames along the frame axis (-3)
    new_x = _concat_from_sequence(seq, axis=-3, new_axis=0)

    # calling weighted dft with weights=window
    # Reshape the window to (1, ..., 1, window_size, 1) so it broadcasts over
    # every leading axis and over the trailing component axis, and scales
    # only the samples-within-a-frame axis.
    shape_x = new_x.shape
    shape_x_short = shape_x[:-2]
    shape_x_short_one = tuple(1 for _ in shape_x_short)
    window_shape = (*shape_x_short_one, window_size, 1)
    weights = np.reshape(window, window_shape)
    weighted_new_x = new_x * weights

    result = _dft(
        weighted_new_x, fft_length, last_axis, onesided=onesided, normalize=False
    )
    return result


def _istft(x, fft_length: int, hop_length, window, onesided=False):  # type: ignore
    """Reverse of `_stft`, implemented as an overlap-add reconstruction.

    Not called by ``STFT._run`` in this module.

    Steps:

    * For every frame (axis -2 of ``x``), ``_dft`` is applied with
      ``normalize=True`` along the last axis.
    * The real part (index 0 of the result's last axis) and the imaginary
      part (index 1) are each placed at offset ``frame * hop_length`` inside
      a zero signal of length ``fft_length + hop_length * (n_frames - 1)``.
    * These padded frames are summed over all frames.
    * Both sums are divided by the overlap-added ``window``.

    Note that the frames themselves are not multiplied by ``window`` here;
    the window only appears in the normalizing denominator. Positions where
    that summed window is zero are divided by zero.

    Args:
        x: framed spectrum; frames are indexed along axis -2.
        fft_length: DFT length forwarded to ``_dft``; also the length of
            the first frame in the reconstructed signal.
        hop_length: offset between consecutive frames in the output.
        window: per-sample weights, broadcast against each frame.
        onesided: forwarded unchanged to ``_dft``.

    Returns:
        The reconstructed signal with a trailing axis of size 2 holding the
        (real, imaginary) components.
    """
    zero = [0]
    one = [1]
    two = [2]
    axisf = [-2]
    n_frames = x.shape[-2]
    expected_signal_len = fft_length + hop_length * (n_frames - 1)

    # building frames
    seqr = []
    seqi = []
    seqc = []
    for fs in range(n_frames):
        begin = fs
        end = fs + 1
        # Extract frame `fs` and drop the (now size-1) frame axis.
        frame_x = np.squeeze(
            _slice(x, np.array([begin]), np.array([end]), axisf),
            axis=axisf[0],  # type: ignore
        )

        # ifft
        # NOTE(review): done through _dft with normalize=True; confirm that
        # _cfft performs the inverse transform in this mode.
        ift = _dft(frame_x, fft_length, axis=-1, onesided=onesided, normalize=True)
        n_dims = len(ift.shape)

        # real part (index 0 of the last axis)
        n_dims_1 = n_dims - 1
        sliced = _slice(ift, np.array(zero), np.array(one), [n_dims_1])  # type: ignore
        ytmp = np.squeeze(sliced, axis=n_dims_1)
        # Window weight for each sample of this frame, used to normalize the
        # overlap-added sum.
        ctmp = np.full(ytmp.shape, fill_value=1, dtype=x.dtype) * window

        # Zero padding that places this frame at offset fs * hop_length in a
        # signal of expected_signal_len samples.
        shape_begin = ytmp.shape[:-1]
        n_left = fs * hop_length
        size = ytmp.shape[-1]
        n_right = expected_signal_len - (n_left + size)

        left_shape = (*shape_begin, n_left)
        right_shape = (*shape_begin, n_right)
        right = np.zeros(right_shape, dtype=x.dtype)
        left = np.zeros(left_shape, dtype=x.dtype)

        y = _concat(left, ytmp, right, axis=-1)
        yc = _concat(left, ctmp, right, axis=-1)

        # imaginary part (index 1 of the last axis)
        sliced = _slice(ift, np.array(one), np.array(two), [n_dims_1])  # type: ignore
        itmp = np.squeeze(sliced, axis=n_dims_1)
        yi = _concat(left, itmp, right, axis=-1)

        # append, each padded frame gets a new trailing axis
        seqr.append(_unsqueeze(y, axis=-1))
        seqi.append(_unsqueeze(yi, axis=-1))
        seqc.append(_unsqueeze(yc, axis=-1))

    # concatenation of the padded frames along the trailing axis
    redr = _concat_from_sequence(seqr, axis=-1, new_axis=0)
    redi = _concat_from_sequence(seqi, axis=-1, new_axis=0)
    redc = _concat_from_sequence(seqc, axis=-1, new_axis=0)

    # unweight: overlap-add (sum over frames), then divide by the summed window
    resr = redr.sum(axis=-1, keepdims=0)  # type: ignore
    resi = redi.sum(axis=-1, keepdims=0)  # type: ignore
    resc = redc.sum(axis=-1, keepdims=0)  # type: ignore
    rr = resr / resc
    ri = resi / resc

    # Make complex: stack real and imaginary parts on a new leading axis
    rr0 = np.expand_dims(rr, axis=0)
    ri0 = np.expand_dims(ri, axis=0)
    conc = _concat(rr0, ri0, axis=0)

    # rotation, bring first dimension to the last position
    # (2, *rest) -> (2, prod(rest)) -> (prod(rest), 2) -> (*rest, 2)
    result_shape = conc.shape
    reshaped_result = conc.reshape((2, -1))
    transposed = np.transpose(reshaped_result, (1, 0))
    other_dimensions = result_shape[1:]
    # np.concatenate turns the shape tuple plus [2] into an int array that
    # reshape accepts as the target shape.
    final_shape = _concat(other_dimensions, two, axis=0)
    final = transposed.reshape(final_shape)
    return final


class STFT(OpRun):
    def _run(self, x, frame_step, window=None, frame_length=None, onesided=None):  # type: ignore
        """Computes the short-time Fourier transform of ``x``.

        Args:
            x: input signal; frames are taken along axis -2.
            frame_step: hop between the starts of consecutive frames.
            window: optional frame weights. Defaults to all ones of length
                ``frame_length`` (a rectangular window).
            frame_length: optional frame length, also used as the DFT length.
                Defaults to the window length, or to the full signal length
                ``x.shape[-2]`` when no window is given.
            onesided: forwarded unchanged to ``_stft``, including ``None``.
                NOTE(review): confirm how OpRun fills the schema default.

        Returns:
            A 1-tuple holding the ``_stft`` result cast to ``x.dtype``.
        """
        if frame_length is None:
            if window is None:
                frame_length = x.shape[-2]
            else:
                frame_length = window.shape[0]
        hop_length = frame_step
        if window is None:
            window = np.ones((frame_length,), dtype=x.dtype)
        # Same formula as torch: 1 + (len - n_fft) // hop_length.
        # NOTE(review): if x.shape[-2] < frame_length this is <= 0 and _stft
        # builds no frame; confirm _concat_from_sequence's behavior on an
        # empty sequence.
        n_frames = 1 + (x.shape[-2] - frame_length) // frame_step
        res = _stft(x, frame_length, hop_length, n_frames, window, onesided=onesided)
        return (res.astype(x.dtype),)