File size: 3,120 Bytes
d1b91e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch.nn.functional as F
from torch import nn

from modules.vocoder.hifigan.hifigan import HifiGanGenerator, MultiPeriodDiscriminator, MultiScaleDiscriminator, \
    generator_loss, feature_loss, discriminator_loss
from modules.vocoder.hifigan.mel_utils import mel_spectrogram
from modules.vocoder.hifigan.stft_loss import MultiResolutionSTFTLoss
from tasks.vocoder.vocoder_base import VocoderBaseTask
from utils.commons.hparams import hparams
from utils.nn.model_utils import print_arch


class HifiGanTask(VocoderBaseTask):
    """Adversarial vocoder task: a HiFi-GAN generator trained against MPD and MSD discriminators.

    The generator update (``optimizer_idx == 0``) caches its detached outputs on
    ``self`` so that the following discriminator update can reuse them without
    running the generator a second time.
    """

    def build_model(self):
        """Create the generator, both discriminators and the STFT loss module.

        When ``hparams['load_ckpt']`` is not the empty string, generator and
        discriminator weights are restored from that checkpoint.

        Returns:
            The generator module stored in ``self.model_gen``.
        """
        self.model_gen = HifiGanGenerator(hparams)

        discriminators = nn.ModuleDict()
        discriminators['mpd'] = MultiPeriodDiscriminator()
        discriminators['msd'] = MultiScaleDiscriminator()
        self.model_disc = discriminators

        self.stft_loss = MultiResolutionSTFTLoss()
        print_arch(self.model_gen)

        ckpt_path = hparams['load_ckpt']
        if ckpt_path != '':
            # Same checkpoint, same flags, for the generator and then the discriminators.
            for module_name in ('model_gen', 'model_disc'):
                self.load_ckpt(ckpt_path, module_name, module_name, force=True, strict=True)
        return self.model_gen

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the losses for one optimizer step.

        Args:
            sample: Batch mapping; the keys ``'mels'``, ``'wavs'`` and ``'f0'`` are read.
            batch_idx: Index of the batch (not used here).
            optimizer_idx: ``0`` selects the generator losses; any other value
                selects the discriminator losses.

        Returns:
            A ``(total_loss, loss_output)`` pair, where ``loss_output`` maps loss
            names to values and ``total_loss`` is the sum of those values.
        """
        mel, real_wav, f0 = sample['mels'], sample['wavs'], sample['f0']
        if optimizer_idx == 0:
            losses = self._generator_step_losses(mel, real_wav, f0)
        else:
            losses = self._discriminator_step_losses(mel, real_wav)
        return sum(losses.values()), losses

    def _generator_step_losses(self, mel, real_wav, f0):
        """Return the generator loss dict and cache detached outputs on ``self``."""
        losses = {}

        fake_wav = self.model_gen(mel, f0)
        # squeeze(1) drops dim 1 before the mel transform
        # (presumably a singleton channel axis - TODO confirm against the dataset).
        real_mel = mel_spectrogram(real_wav.squeeze(1), hparams).transpose(1, 2)
        fake_mel = mel_spectrogram(fake_wav.squeeze(1), hparams).transpose(1, 2)
        losses['mel'] = F.l1_loss(fake_mel, real_mel) * hparams['lambda_mel']

        # The non-detached fake waveform is passed so the adversarial terms can
        # back-propagate into the generator.
        _, mpd_fake_out, mpd_real_fmap, mpd_fake_fmap = self.model_disc['mpd'](real_wav, fake_wav, mel)
        _, msd_fake_out, msd_real_fmap, msd_fake_fmap = self.model_disc['msd'](real_wav, fake_wav, mel)

        adv_weight = hparams['lambda_adv']
        losses['a_p'] = generator_loss(mpd_fake_out) * adv_weight
        losses['a_s'] = generator_loss(msd_fake_out) * adv_weight

        if hparams['use_fm_loss']:
            losses['fm_f'] = feature_loss(mpd_real_fmap, mpd_fake_fmap)
            losses['fm_s'] = feature_loss(msd_real_fmap, msd_fake_fmap)

        if hparams['use_ms_stft']:
            sc_loss, mag_loss = self.stft_loss(real_wav.squeeze(1), fake_wav.squeeze(1))
            losses['sc'] = sc_loss
            losses['mag'] = mag_loss

        # Keep graph-free copies for the discriminator step.
        self.y_ = fake_wav.detach()
        self.y_mel = real_mel.detach()
        self.y_hat_mel = fake_mel.detach()
        return losses

    def _discriminator_step_losses(self, mel, real_wav):
        """Return the discriminator loss dict, using the waveform cached by the generator step."""
        losses = {}
        cached_fake = self.y_

        # Multi-period discriminator: real/fake terms.
        mpd_real_out, mpd_fake_out, _, _ = self.model_disc['mpd'](real_wav, cached_fake.detach(), mel)
        real_term, fake_term = discriminator_loss(mpd_real_out, mpd_fake_out)
        losses['r_p'] = real_term
        losses['f_p'] = fake_term

        # Multi-scale discriminator: real/fake terms.
        msd_real_out, msd_fake_out, _, _ = self.model_disc['msd'](real_wav, cached_fake.detach(), mel)
        real_term, fake_term = discriminator_loss(msd_real_out, msd_fake_out)
        losses['r_s'] = real_term
        losses['f_s'] = fake_term
        return losses