# Spaces: Paused  (page-status text captured with the listing; not part of the module)
import torch | |
from torch import nn | |
from torch.nn.utils import weight_norm | |
from TTS.utils.io import load_fsspec | |
from TTS.vocoder.layers.melgan import ResidualStack | |
class MelganGenerator(nn.Module):
    """Convolutional generator built as a single ``nn.Sequential`` stack.

    The stack is: reflection-padded input projection conv -> one
    (LeakyReLU, transposed conv, ``ResidualStack``) group per entry of
    ``upsample_factors`` -> LeakyReLU -> reflection-padded output conv -> Tanh.
    Every plain and transposed convolution created here is wrapped in
    ``weight_norm``.

    Args:
        in_channels: Channel count of the conditioning input fed to the first
            ``Conv1d`` (presumably the number of mel bands -- confirm with caller).
        out_channels: Channel count produced by the final ``Conv1d``.
        proj_kernel: Kernel size of the first and last convolutions; must be odd
            so that ``(proj_kernel - 1) // 2`` reflection padding keeps the length.
        base_channels: Output channels of the input projection; halved (integer
            division) after every upsampling stage.
        upsample_factors: Stride of each transposed convolution. May be empty, in
            which case no upsampling stage is built and the output conv consumes
            ``base_channels`` directly.
        res_kernel: ``kernel_size`` forwarded to each ``ResidualStack``.
        num_res_blocks: ``num_res_blocks`` forwarded to each ``ResidualStack``.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # assert model parameters
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # number of frames replicated on each side of the input in `inference`
        self.inference_padding = 2

        # initial layer
        layers = [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # Channel width entering the final conv. Initialised here so that an
        # empty `upsample_factors` still yields a valid (non-upsampling) model
        # instead of failing on an unbound name below.
        layer_out_channels = base_channels

        # upsampling layers and residual stacks
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            # kernel = 2 * stride together with the padding below makes the
            # transposed conv output exactly `upsample_factor` times its input
            # length, for both even and odd factors.
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # final layer
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        """Run ``c`` through the full layer stack and return the result.

        ``c`` must be a tensor the first ``Conv1d`` accepts, i.e. with
        ``in_channels`` channels.
        """
        return self.layers(c)

    def inference(self, c):
        """Like ``forward``, but first moves ``c`` to the model device and pads it.

        ``c`` is moved to the device of the input projection's weight, then
        ``self.inference_padding`` frames are replicated on each side of its
        last dimension before the layer stack is applied. The padded frames are
        not trimmed from the returned tensor.
        """
        # layers[1] is the weight-normalised input projection conv.
        c = c.to(self.layers[1].weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        """Strip weight normalisation from every parameterised layer in the stack."""
        for layer in self.layers:
            # Parameter-free layers (padding, activations) have nothing to strip.
            if not layer.state_dict():
                continue
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                # Not directly weight-normalised (e.g. a composite block):
                # delegate to the layer's own `remove_weight_norm`.
                layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load weights from ``checkpoint_path`` into this model.

        Args:
            config: Unused; kept for interface compatibility.
            checkpoint_path: Location handed to ``load_fsspec``.
            eval: When true, switch to eval mode and remove weight norm after
                loading.
            cache: Forwarded to ``load_fsspec``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        # The checkpoint is expected to hold the state dict under the "model" key.
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()