Spaces:
Runtime error
Runtime error
from torchlibrosa.stft import STFT, ISTFT, magphase | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from tools.pytorch.modules.pqmf import PQMF | |
class FDomainHelper(nn.Module):
    """Waveform <-> frequency-domain conversions built on torchlibrosa STFT/ISTFT.

    Offers magnitude, magnitude + phase (cos/sin), and complex (real/imag)
    representations. When ``subband`` is given, the sub-band methods first
    split the waveform with a PQMF filter bank.

    Args:
        window_size: FFT size and window length in samples. In sub-band mode it
            is integer-divided by ``subband``.
        hop_size: hop length in samples. In sub-band mode it is integer-divided
            by ``subband``.
        center, pad_mode, window, freeze_parameters: forwarded unchanged to
            torchlibrosa ``STFT`` and ``ISTFT``.
        subband: number of PQMF sub-bands, or ``None`` for full-band only.
        root: passed to ``PQMF`` as its third argument (presumably a location
            of filter data -- TODO confirm). The PQMF filter bank, which the
            sub-band methods need, is built only when both ``subband`` and
            ``root`` are not None.
    """

    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode="reflect",
        window="hann",
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        super(FDomainHelper, self).__init__()
        self.subband = subband
        # In sub-band mode the STFT frame and hop shrink by the band count
        # (presumably because PQMF analysis decimates each band -- confirm).
        divisor = 1 if subband is None else subband
        stft_kwargs = dict(
            n_fft=window_size // divisor,
            hop_length=hop_size // divisor,
            win_length=window_size // divisor,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=freeze_parameters,
        )
        self.stft = STFT(**stft_kwargs)
        self.istft = ISTFT(**stft_kwargs)
        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    def _require_qmf(self):
        """Return the PQMF filter bank, or raise a descriptive AttributeError."""
        qmf = getattr(self, "qmf", None)
        if qmf is None:
            raise AttributeError(
                "FDomainHelper was built without a PQMF filter bank; pass "
                "subband=<int> and a non-None root to use sub-band methods"
            )
        return qmf

    @staticmethod
    def _magnitude(real, imag, eps):
        """Return sqrt(max(real**2 + imag**2, eps)); eps floors the SQUARED magnitude."""
        return torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

    def complex_spectrogram(self, input, eps=0.0):
        """STFT of a mono waveform as stacked real/imag channels.

        Args:
            input: [batchsize, samples] waveform.
            eps: unused; kept for signature compatibility.
        Returns:
            ``torch.cat([real, imag], dim=1)`` -- per the original comment
            [batchsize, 2, t-steps, f-bins] (assumes torchlibrosa STFT yields
            (batch, 1, time, freq); confirm).
        """
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        """Inverse of ``complex_spectrogram``.

        Args:
            input: [batchsize, 2[real, imag], t-steps, f-bins].
            eps: unused; kept for signature compatibility.
            length: optional output length forwarded to ISTFT.
        Returns:
            Whatever torchlibrosa ISTFT returns for one channel.
        """
        return self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)

    def spectrogram(self, input, eps=0.0):
        """Magnitude STFT of ``input`` (cast to float32); eps floors the squared magnitude."""
        real, imag = self.stft(input.float())
        return self._magnitude(real, imag, eps)

    def spectrogram_phase(self, input, eps=0.0):
        """Return ``(mag, cos, sin)`` of the STFT of ``input`` (cast to float32).

        ``mag`` is floored via ``eps`` on the squared magnitude. ``cos``/``sin``
        are ``real/mag`` and ``imag/mag``; the divisor is floored at 1e-10 so a
        silent bin (mag == 0 when eps == 0) yields 0 instead of NaN.
        """
        real, imag = self.stft(input.float())
        mag = self._magnitude(real, imag, eps)
        # Same denominator floor as torchlibrosa.magphase; avoids 0/0 -> NaN.
        denom = torch.clamp(mag, min=1e-10)
        cos = real / denom
        sin = imag / denom
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Waveform to magnitude + phase (cos/sin), channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: floor on the squared magnitude.
        Returns:
            (mags, coss, sins), each (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list, cos_list, sin_list = [], [], []
        for channel in range(input.shape[1]):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)
        return (
            torch.cat(sp_list, dim=1),
            torch.cat(cos_list, dim=1),
            torch.cat(sin_list, dim=1),
        )

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        """Inverse of ``wav_to_spectrogram_phase``.

        Each channel is rebuilt as ISTFT(mag*cos, mag*sin, length); the
        per-channel outputs get a new dim 1 and are concatenated along it.
        """
        res = []
        for i in range(sps.size()[1]):
            mag = sps[:, i : i + 1, ...]
            wav = self.istft(
                mag * coss[:, i : i + 1, ...],
                mag * sins[:, i : i + 1, ...],
                length,
            )
            res.append(wav.unsqueeze(1))
        return torch.cat(res, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Waveform to magnitude spectrogram, channel by channel.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: floor on the squared magnitude.
        Returns:
            (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = [
            self.spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(sp_list, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Magnitude spectrogram to waveform, borrowing the phase of ``input``.

        Args:
            input: (batch_size, channels_num, segment_samples) reference
                waveform whose STFT phase is reused.
            spectrogram: (batch_size, channels_num, time_steps, freq_bins)
                magnitudes.
            length: optional output length forwarded to ISTFT.
        Returns:
            The per-channel ISTFT outputs stacked along dim 1, i.e.
            (batch_size, channels_num, segment_samples) assuming ISTFT returns
            (batch, samples).
        """
        wav_list = []
        for channel in range(input.shape[1]):
            real, imag = self.stft(input[:, channel, :])
            _, cos, sin = magphase(real, imag)
            mag = spectrogram[:, channel : channel + 1, :, :]
            wav_list.append(self.istft(mag * cos, mag * sin, length))
        return torch.stack(wav_list, dim=1)

    # TODO(original author): the complex / sub-band helpers below were flagged
    # as "not bug free"; verify round-trips before relying on them.
    def wav_to_complex_spectrogram(self, input, eps=0.0):
        """[batchsize, channels, samples] -> [batchsize, 2[real,imag]*channels, t-steps, f-bins]."""
        res = [
            self.complex_spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(res, dim=1)

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """[batchsize, 2[real,imag]*channels, t-steps, f-bins] -> [batchsize, channels, samples].

        Note: ``length`` is the third parameter; pass it by keyword.
        """
        wavs = []
        for i in range(input.size()[1] // 2):
            wav = self.reverse_complex_spectrogram(
                input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
            )
            wavs.append(wav.unsqueeze(1))
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        """[batchsize, channels, samples] -> [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins].

        Raises:
            AttributeError: if the object was built without a PQMF filter bank.
        """
        qmf = self._require_qmf()
        subwav = qmf.analysis(input)  # [batchsize, subband*channels, samples]
        return self.wav_to_complex_spectrogram(subwav, eps=eps)

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """[batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] -> [batchsize, channels, samples].

        Args:
            length: optional per-sub-band signal length forwarded to ISTFT
                before PQMF synthesis (default None keeps the old behavior).
        Raises:
            AttributeError: if the object was built without a PQMF filter bank.
        """
        qmf = self._require_qmf()
        subwav = self.complex_spectrogram_to_wav(input, eps=eps, length=length)
        return qmf.synthesis(subwav)

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """PQMF sub-band split followed by magnitude + phase analysis.

        Args:
            input: [batchsize, channels, samples]
            eps: floor on the squared magnitude.
        Returns:
            (sps, coss, sins), each [batchsize, subband*channels, t-steps, f-bins].
        Raises:
            AttributeError: if the object was built without a PQMF filter bank.

        Example::

            loss = torch.nn.L1Loss()
            model = FDomainHelper(subband=4)
            data = torch.randn((3,1, 44100*3))
            sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
            wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
            print(loss(data,wav))
            print(torch.max(torch.abs(data-wav)))
        """
        qmf = self._require_qmf()
        subwav = qmf.analysis(input)  # [batchsize, subband*channels, samples]
        return self.wav_to_spectrogram_phase(subwav, eps=eps)

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        """Inverse of ``wav_to_mag_phase_subband_spectrogram``.

        Args:
            sps, coss, sins: each [batchsize, subband*channels, t-steps, f-bins].
            length: per-sub-band signal length forwarded to ISTFT.
            eps: unused; kept for signature compatibility.
        Returns:
            PQMF synthesis of the reconstructed sub-bands ([batchsize, channels, samples]).
        Raises:
            AttributeError: if the object was built without a PQMF filter bank.
        """
        qmf = self._require_qmf()
        subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
        return qmf.synthesis(subwav)
if __name__ == "__main__":
    # Round-trip sanity check: waveform -> complex spectrogram -> waveform.
    # Prints the spectrogram size, the mean absolute reconstruction error and
    # the maximum absolute reconstruction error.
    loss = torch.nn.L1Loss()
    model = FDomainHelper()
    num_samples = 44100 * 5
    data = torch.randn((3, 1, num_samples))
    sps = model.wav_to_complex_spectrogram(data)
    print(sps.size())
    # `length` must be passed by keyword: the second positional parameter of
    # complex_spectrogram_to_wav is `eps`, so a positional length was ignored.
    wav = model.complex_spectrogram_to_wav(sps, length=num_samples)
    print(loss(data, wav))
    print(torch.max(torch.abs(data - wav)))