Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter. With ``padding="same"`` the
            inverse-FFT frames (length ``n_fft``) are multiplied by the window directly, so
            ``win_length`` must equal ``n_fft`` in that mode.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is not "center" or "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Registered as a buffer so the window moves with the module (`.to(device)`, state_dict).
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
                For "same" padding, L = (T - 1) * hop_length + win_length - 2 * ((win_length - hop_length) // 2).

        Raises:
            ValueError: If ``spec`` is not 3-dimensional ("same" padding), or the padding mode is invalid.
            RuntimeError: If the squared-window overlap-add envelope is (near) zero inside the
                returned region, i.e. the NOLA condition is violated and the signal cannot be recovered.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        if spec.dim() != 3:
            raise ValueError("Expected a 3D tensor as input")
        T = spec.shape[-1]

        # Inverse FFT per frame (along the frequency axis), then apply the synthesis window.
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add. `fold` with a (1, win_length) kernel sums the frames at hop_length strides.
        output_size = (T - 1) * self.hop_length + self.win_length
        # Use an explicit end index: the slice `pad:-pad` is empty when pad == 0.
        start, end = pad, output_size - pad
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, start:end]

        # Window envelope: overlap-add of the squared window (analysis * synthesis window).
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[0, 0, 0, start:end]

        # Normalize. Raise instead of assert so the check survives `python -O`.
        if not (window_envelope > 1e-11).all():
            raise RuntimeError(
                "NOLA constraint violated: the window envelope is (near) zero inside the output region."
            )
        y = y / window_envelope
        return y
class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    A linear layer maps each hidden vector to ``n_fft + 2`` values, which are split in half along
    the feature axis into log-magnitudes and phases (``n_fft // 2 + 1`` each for even ``n_fft``).
    These are combined into a complex spectrogram and inverted with ``ISTFT``.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    # Hard cap on predicted magnitudes, and its natural log: ln(1e2).
    _MAX_MAGNITUDE = 1e2
    _LOG_MAX_MAGNITUDE = 4.605170185988092

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        # (B, L, H) -> (B, n_fft + 2, L): features on dim 1, frames on dim 2.
        x = self.out(x).transpose(1, 2)
        log_mag, phase = x.chunk(2, dim=1)
        # Clamp in the log domain *before* exponentiating. exp() of a large logit overflows to inf;
        # clipping only afterwards bounds the forward value, but backward then computes 0 * inf = NaN.
        # The trailing clip keeps the original hard cap of 1e2 (safeguard against huge magnitudes).
        mag = torch.exp(torch.clamp(log_mag, max=self._LOG_MAX_MAGNITUDE))
        mag = torch.clip(mag, max=self._MAX_MAGNITUDE)
        # Build the complex spectrogram directly from magnitude and phase (mag * e^{i*phase}).
        real = torch.cos(phase)
        imag = torch.sin(phase)
        S = mag * (real + 1j * imag)
        audio = self.istft(S)
        return audio