Spaces:
Runtime error
Runtime error
import torch.nn.functional as F | |
from torch import nn | |
from modules.vocoder.hifigan.hifigan import HifiGanGenerator, MultiPeriodDiscriminator, MultiScaleDiscriminator, \ | |
generator_loss, feature_loss, discriminator_loss | |
from modules.vocoder.hifigan.mel_utils import mel_spectrogram | |
from modules.vocoder.hifigan.stft_loss import MultiResolutionSTFTLoss | |
from tasks.vocoder.vocoder_base import VocoderBaseTask | |
from utils.commons.hparams import hparams | |
from utils.nn.model_utils import print_arch | |
class HifiGanTask(VocoderBaseTask):
    """Vocoder task pairing a ``HifiGanGenerator`` with multi-period and
    multi-scale discriminators."""

    def build_model(self):
        """Instantiate the generator, the discriminators and the STFT loss.

        If ``hparams['load_ckpt']`` is not the empty string, the
        ``model_gen`` and ``model_disc`` weights are loaded from that
        checkpoint (``force=True, strict=True``).

        Returns:
            The generator module stored in ``self.model_gen``.
        """
        self.model_gen = HifiGanGenerator(hparams)
        self.model_disc = nn.ModuleDict({
            'mpd': MultiPeriodDiscriminator(),
            'msd': MultiScaleDiscriminator(),
        })
        self.stft_loss = MultiResolutionSTFTLoss()
        print_arch(self.model_gen)

        ckpt_path = hparams['load_ckpt']
        if ckpt_path != '':
            # The same name is used for the checkpoint entry and the attribute.
            for model_name in ('model_gen', 'model_disc'):
                self.load_ckpt(ckpt_path, model_name, model_name, force=True, strict=True)
        return self.model_gen

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the losses for a single optimizer step.

        Args:
            sample: Batch mapping; the keys ``'mels'``, ``'wavs'`` and
                ``'f0'`` are read on every call.
            batch_idx: Index of the current batch (not used here).
            optimizer_idx: ``0`` selects the generator losses; any other
                value selects the discriminator losses.

        Returns:
            A ``(total_loss, loss_output)`` tuple, where ``loss_output`` maps
            loss names to values and ``total_loss`` is the sum of them.
        """
        mel, wav_real, f0 = sample['mels'], sample['wavs'], sample['f0']
        if optimizer_idx == 0:
            losses = self._generator_step_losses(mel, wav_real, f0)
        else:
            losses = self._discriminator_step_losses(mel, wav_real)
        return sum(losses.values()), losses

    def _generator_step_losses(self, mel, wav_real, f0):
        """Return the generator loss dict and cache detached outputs on ``self``."""
        wav_fake = self.model_gen(mel, f0)

        # Mel-spectrogram L1 term between the real and the generated waveform.
        real_mel = mel_spectrogram(wav_real.squeeze(1), hparams).transpose(1, 2)
        fake_mel = mel_spectrogram(wav_fake.squeeze(1), hparams).transpose(1, 2)
        losses = {'mel': F.l1_loss(fake_mel, real_mel) * hparams['lambda_mel']}

        # Adversarial terms: only the scores for the generated waveform are used.
        _, mpd_fake_scores, mpd_fmap_real, mpd_fmap_fake = \
            self.model_disc['mpd'](wav_real, wav_fake, mel)
        _, msd_fake_scores, msd_fmap_real, msd_fmap_fake = \
            self.model_disc['msd'](wav_real, wav_fake, mel)
        losses['a_p'] = generator_loss(mpd_fake_scores) * hparams['lambda_adv']
        losses['a_s'] = generator_loss(msd_fake_scores) * hparams['lambda_adv']

        # Optional feature-matching terms on the discriminator feature maps.
        if hparams['use_fm_loss']:
            losses['fm_f'] = feature_loss(mpd_fmap_real, mpd_fmap_fake)
            losses['fm_s'] = feature_loss(msd_fmap_real, msd_fmap_fake)

        # Optional multi-resolution STFT terms.
        if hparams['use_ms_stft']:
            sc_loss, mag_loss = self.stft_loss(wav_real.squeeze(1), wav_fake.squeeze(1))
            losses['sc'] = sc_loss
            losses['mag'] = mag_loss

        # Detached copies are kept on the task; the discriminator step reads self.y_.
        self.y_ = wav_fake.detach()
        self.y_mel = real_mel.detach()
        self.y_hat_mel = fake_mel.detach()
        return losses

    def _discriminator_step_losses(self, mel, wav_real):
        """Return the discriminator loss dict, using the waveform cached in ``self.y_``."""
        wav_fake = self.y_
        losses = {}
        # (discriminator key, real-loss name, fake-loss name): MPD first, then MSD.
        passes = (('mpd', 'r_p', 'f_p'), ('msd', 'r_s', 'f_s'))
        for disc_key, real_name, fake_name in passes:
            real_scores, fake_scores, _, _ = \
                self.model_disc[disc_key](wav_real, wav_fake.detach(), mel)
            losses[real_name], losses[fake_name] = discriminator_loss(real_scores, fake_scores)
        return losses