import torch
from torch import nn


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`)
    with windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fall back to the PyTorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
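
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a round-trip check of
# ISTFT with "same" padding. The analysis STFT is computed with center=False on
# a signal that is reflect-padded by (win_length - hop_length) // 2 on each
# side, so the trimmed ISTFT output matches the input length exactly. All
# parameter values and the helper name are hypothetical examples.
# ----------------------------------------------------------------------------
def _istft_roundtrip_example() -> None:
    n_fft, hop_length = 1024, 256
    pad = (n_fft - hop_length) // 2
    x = torch.randn(2, 16 * hop_length)  # (B, L), L a multiple of hop_length

    # Manual "same" padding before the analysis STFT (center=False)
    x_padded = torch.nn.functional.pad(x.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
    spec = torch.stft(
        x_padded,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=torch.hann_window(n_fft),
        center=False,
        return_complex=True,
    )  # (B, n_fft // 2 + 1, T) with T == L // hop_length

    istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding="same")
    y = istft(spec)

    assert y.shape == x.shape  # trimmed output has exactly L = T * hop_length samples
    assert torch.allclose(y, x, atol=1e-4)  # windowed OLA with envelope normalization is exact
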

class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length
                of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: these two lines produce the real and imaginary parts.
        x = torch.cos(p)
        y = torch.sin(p)
        # Recomputing the phase via atan2 would not produce anything new and only costs time:
        #     phase = torch.atan2(y, x)
        #     S = mag * torch.exp(phase * 1j)
        # so we construct the complex spectrogram directly.
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
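

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a shape check for
# ISTFTHead. The hidden dimension and STFT parameters are arbitrary example
# values; with "same" padding the head upsamples L input frames to exactly
# L * hop_length audio samples.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    _istft_roundtrip_example()

    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
    features = torch.randn(2, 100, 512)  # (B, L, H): a batch of 100 feature frames
    audio = head(features)
    print(audio.shape)  # torch.Size([2, 25600]) == (B, L * hop_length)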