|
import torch.nn as nn |
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import math |
|
from torchlibrosa.stft import magphase |
|
|
|
|
|
def init_layer(layer):
    """Give a layer Xavier-uniform weights and, if it has one, a zero bias.

    Args:
        layer: module exposing a ``weight`` tensor and optionally a ``bias``
            (for example a Linear or Convolutional layer).
    """
    nn.init.xavier_uniform_(layer.weight)

    # ``bias`` may be missing altogether, or present but set to None.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
|
|
|
|
|
def init_bn(bn):
    """Reset a batch-norm layer's affine parameters to the identity mapping.

    Args:
        bn: layer whose ``weight`` (scale) is set to one and whose ``bias``
            (shift) is set to zero, both in place.
    """
    bn.weight.data.fill_(1.0)
    bn.bias.data.fill_(0.0)
|
|
|
|
|
def init_embedding(layer):
    """Initialize an embedding-style layer.

    The weights are drawn from a uniform distribution on [-1, 1]. If the
    layer also carries a bias, that bias is zeroed.

    Args:
        layer: module exposing a ``weight`` tensor and optionally a ``bias``.
    """
    nn.init.uniform_(layer.weight, -1., 1.)

    # ``bias`` may be missing altogether, or present but set to None.
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
|
|
|
|
|
def init_gru(rnn):
    """Initialize the weights and biases of a GRU layer in place.

    For every layer index ``i`` in ``range(rnn.num_layers)``:

    * ``weight_ih_l{i}`` is split into three equal row chunks, each filled
      uniformly in ``[-sqrt(3 / fan_in), sqrt(3 / fan_in)]``.
    * ``weight_hh_l{i}`` is split the same way; the first two chunks get the
      same uniform init and the last chunk is made orthogonal. (PyTorch
      documents the chunk order as reset, update, new gate.)
    * ``bias_ih_l{i}`` and ``bias_hh_l{i}`` are set to zero when they exist.

    Only parameters with exactly these names are touched. Parameters
    registered under other names (e.g. the ``_reverse``-suffixed ones that
    the torch.nn.GRU documentation describes for bidirectional modules) are
    left as they were.

    Args:
        rnn: GRU module exposing ``num_layers`` and the parameters above.
    """

    def _concat_init(tensor, init_funcs):
        """Apply one initializer to each equal-sized row chunk of ``tensor``."""
        # Integer division: any remainder rows at the end stay untouched.
        chunk_rows = tensor.shape[0] // len(init_funcs)

        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * chunk_rows : (i + 1) * chunk_rows, :])

    def _inner_uniform(tensor):
        """Fill ``tensor`` uniformly within +/- sqrt(3 / fan_in)."""
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
        bound = math.sqrt(3 / fan_in)
        nn.init.uniform_(tensor, -bound, bound)

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, "weight_ih_l{}".format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform],
        )
        _concat_init(
            getattr(rnn, "weight_hh_l{}".format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
        )

        # A GRU built with bias=False has no bias parameters at all, so look
        # them up defensively instead of failing with AttributeError.
        for bias_name in ("bias_ih_l{}", "bias_hh_l{}"):
            bias = getattr(rnn, bias_name.format(i), None)
            if bias is not None:
                torch.nn.init.constant_(bias, 0)
|
|
|
|
|
def act(x, activation):
    """Apply the named activation function to ``x``.

    Args:
        x: input tensor. NOTE: "relu" and "leaky_relu" use the in-place
            functional variants, so ``x`` itself is overwritten; "swish"
            builds a new tensor and leaves ``x`` untouched.
        activation: one of "relu", "leaky_relu" (negative slope 0.01) or
            "swish" (``x * sigmoid(x)``).

    Returns:
        The activated tensor.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        return x * torch.sigmoid(x)

    # ValueError is a subclass of Exception, so existing broad handlers
    # still catch it; the message now names the offending value.
    raise ValueError(
        "Incorrect activation! Got {!r}; expected 'relu', 'leaky_relu' "
        "or 'swish'.".format(activation)
    )
|
|
|
|
|
class Base:
    """Waveform <-> spectrogram helpers built on an STFT / ISTFT pair.

    The methods call ``self.stft`` and ``self.istft``, which this class does
    not define; whatever class uses these helpers has to provide them.
    ``self.stft(x)`` must return a ``(real, imag)`` pair of tensors.
    """

    def __init__(self):
        pass

    def spectrogram(self, input, eps=0.):
        """Return the STFT magnitude of ``input``.

        Args:
            input: waveform tensor handed straight to ``self.stft``.
            eps: lower bound applied to the power ``real**2 + imag**2``
                before the square root is taken.

        Returns:
            Magnitude tensor, same shape as the ``self.stft`` outputs.
        """
        (real, imag) = self.stft(input)
        return torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

    def spectrogram_phase(self, input, eps=0.):
        """Return the STFT magnitude plus the cosine and sine of its phase.

        Args:
            input: waveform tensor handed straight to ``self.stft``.
            eps: lower bound applied to the power ``real**2 + imag**2``
                before the square root is taken.

        Returns:
            Tuple ``(mag, cos, sin)`` with ``cos = real / mag`` and
            ``sin = imag / mag``. Where ``mag`` is exactly zero (possible
            only when ``eps <= 0``), ``cos`` and ``sin`` are computed with a
            denominator of one instead of dividing by zero.
        """
        (real, imag) = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

        # Silent bins with eps=0 would otherwise evaluate 0 / 0 and poison
        # the phase with NaN. Only exact-zero denominators are replaced, so
        # every other value is computed exactly as real / mag, imag / mag.
        denom = torch.where(mag > 0, mag, torch.ones_like(mag))
        cos = real / denom
        sin = imag / denom
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-10):
        """Waveform to per-channel magnitude, cos(phase) and sin(phase).

        Args:
            input: (batch_size, channels_num, segment_samples); channels are
                read from axis 1 and each ``input[:, channel, :]`` slice is
                transformed separately.
            eps: forwarded to ``spectrogram_phase``.

        Outputs:
            Tuple ``(sps, coss, sins)``; each is the per-channel results
            concatenated along axis 1, i.e.
            (batch_size, channels_num, time_steps, freq_bins) assuming
            ``self.stft`` returns (batch_size, 1, time_steps, freq_bins)
            tensors -- TODO confirm against the stft implementation.
        """
        sp_list = []
        cos_list = []
        sin_list = []

        for channel in range(input.shape[1]):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)

        sps = torch.cat(sp_list, dim=1)
        coss = torch.cat(cos_list, dim=1)
        sins = torch.cat(sin_list, dim=1)
        return sps, coss, sins

    def wav_to_spectrogram(self, input, eps=0.):
        """Waveform to per-channel magnitude spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples); channels are
                read from axis 1 and each ``input[:, channel, :]`` slice is
                transformed separately.
            eps: forwarded to ``spectrogram``.

        Outputs:
            output: per-channel magnitudes concatenated along axis 1, i.e.
                (batch_size, channels_num, time_steps, freq_bins) assuming
                ``self.stft`` returns (batch_size, 1, time_steps, freq_bins)
                tensors -- TODO confirm against the stft implementation.
        """
        sp_list = [
            self.spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(sp_list, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform, borrowing the phase of ``input``.

        For each channel the phase (cos, sin) is taken from the STFT of
        ``input`` and combined with the magnitudes in ``spectrogram``; the
        result is inverted with ``self.istft``.

        Args:
            input: (batch_size, channels_num, segment_samples) waveform that
                supplies the phase; channels are read from axis 1.
            spectrogram: (batch_size, channels_num, time_steps, freq_bins)
                magnitudes to invert.
            length: passed through to ``self.istft`` as its third argument.

        Outputs:
            output: per-channel ``self.istft`` results stacked along a new
                axis 1, i.e. (batch_size, channels_num, segment_samples)
                assuming ``self.istft`` returns (batch_size, segment_samples)
                -- TODO confirm against the istft implementation.
        """
        channels_num = input.shape[1]
        wav_list = []

        for channel in range(channels_num):
            (real, imag) = self.stft(input[:, channel, :])
            (_, cos, sin) = magphase(real, imag)

            # Keep the channel axis (size 1) so it lines up with cos / sin.
            channel_mag = spectrogram[:, channel : channel + 1, :, :]
            wav_list.append(
                self.istft(channel_mag * cos, channel_mag * sin, length)
            )

        output = torch.stack(wav_list, dim=1)
        return output
|
|