File size: 6,316 Bytes
9206300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import torch.nn.functional as F
from text_to_speech.modules.tts.fs2_orig import FastSpeech2Orig
from tasks.tts.dataset_utils import FastSpeechDataset
from tasks.tts.fs import FastSpeechTask
from text_to_speech.utils.commons.dataset_utils import collate_1d, collate_2d
from text_to_speech.utils.commons.hparams import hparams
from text_to_speech.utils.plot.plot import spec_to_figure
import numpy as np


class FastSpeech2OrigDataset(FastSpeechDataset):
    """FastSpeech dataset that also serves frame energy and CWT pitch features.

    Every sample produced by the parent dataset is extended with an ``energy``
    entry computed from its ``mel`` entry. When pitch embedding is enabled and
    ``pitch_type`` is ``'cwt'``, the item's CWT spectrogram and f0 statistics
    are attached as well.
    """

    def __init__(self, prefix, shuffle=False, items=None, data_dir=None):
        """Forward all arguments to the parent and cache the pitch type."""
        super().__init__(prefix, shuffle, items, data_dir)
        # ``None`` when the config has no 'pitch_type' key.
        self.pitch_type = hparams.get('pitch_type')

    def _uses_cwt_pitch(self):
        """Return True when samples carry CWT pitch features.

        Shared by ``__getitem__`` and ``collater`` so both always agree on
        whether ``cwt_spec`` / ``f0_mean`` / ``f0_std`` are present.
        """
        return bool(self.hparams['use_pitch_embed']) and self.pitch_type == 'cwt'

    def __getitem__(self, index):
        """Return the parent's sample for ``index`` plus energy / CWT fields."""
        sample = super().__getitem__(index)
        mel = sample['mel']
        # Per-frame energy: L2 norm over the last axis of exp(mel).
        # NOTE(review): presumably mel is a log-scale spectrogram - confirm.
        sample['energy'] = (mel.exp() ** 2).sum(-1).sqrt()
        if self._uses_cwt_pitch():
            # The raw item is only needed for the CWT fields, so it is
            # fetched here instead of on every call.
            item = self._get_item(index)
            n_frames = mel.shape[0]
            sample.update({
                # Truncate to the mel length so both share the first axis.
                "cwt_spec": torch.Tensor(item['cwt_spec'])[:n_frames],
                # Fall back to the 'cwt_*' keys when 'f0_*' are absent.
                "f0_mean": item.get('f0_mean', item.get('cwt_mean')),
                "f0_std": item.get('f0_std', item.get('cwt_std')),
            })
        return sample

    def collater(self, samples):
        """Collate ``samples`` into a batch dict.

        Returns an empty dict for an empty sample list. Otherwise returns the
        parent's batch extended with ``energy`` (a padded tensor, or ``None``
        when neither pitch nor energy embedding is enabled) and, when CWT
        pitch features are in use, ``cwt_spec``, ``f0_mean`` and ``f0_std``.
        """
        if len(samples) == 0:
            return {}
        batch = super().collater(samples)
        # Energy must be batched whenever anything consumes it: the task's
        # energy loss reads sample['energy'] under 'use_energy_embed', so
        # gating on 'use_pitch_embed' alone would hand it ``None``.
        needs_energy = hparams['use_pitch_embed'] or hparams.get('use_energy_embed', False)
        if needs_energy:
            energy = collate_1d([s['energy'] for s in samples], 0.0)
        else:
            energy = None
        batch.update({'energy': energy})
        # Same condition as in __getitem__, so the keys are guaranteed to exist.
        if self._uses_cwt_pitch():
            batch.update({
                'cwt_spec': collate_2d([s['cwt_spec'] for s in samples]),
                'f0_mean': torch.Tensor([s['f0_mean'] for s in samples]),
                'f0_std': torch.Tensor([s['f0_std'] for s in samples]),
            })
        return batch


class FastSpeech2OrigTask(FastSpeechTask):
    """Training/inference task wiring ``FastSpeech2Orig`` to its dataset and losses."""

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeech2OrigDataset

    def build_tts_model(self):
        """Instantiate the model, sized by the token encoder's vocabulary."""
        dict_size = len(self.token_encoder)
        self.model = FastSpeech2Orig(dict_size, hparams)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on one batch.

        Args:
            sample: batch dict; reads ``txt_tokens`` and, depending on the
                mode, ``mels``, ``mel2ph``, ``f0``, ``uv``, ``energy``,
                ``spk_embed`` and ``spk_ids``.
            infer: when False, feed ground-truth duration/pitch/energy and
                compute losses; when True, only run the forward pass.
            **kwargs: ``infer_use_gt_dur`` / ``infer_use_gt_f0`` /
                ``infer_use_gt_energy`` override the matching ``use_gt_*``
                config values in inference mode.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise ``output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            target = sample['mels']  # [B, T_s, 80]
            mel2ph = sample['mel2ph']  # [B, T_s]
            f0 = sample.get('f0')
            uv = sample.get('uv')
            energy = sample.get('energy')
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                f0=f0, uv=uv, energy=energy, infer=False)
            losses = {}
            self.add_mel_loss(output['mel_out'], target, losses)
            self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
            if hparams['use_pitch_embed']:
                self.add_pitch_loss(output, sample, losses)
            if hparams['use_energy_embed']:
                self.add_energy_loss(output, sample, losses)
            return losses, output
        else:
            # Ground-truth conditioning is opt-in per feature at inference.
            mel2ph, uv, f0, energy = None, None, None, None
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            use_gt_energy = kwargs.get('infer_use_gt_energy', hparams['use_gt_energy'])
            if use_gt_dur:
                mel2ph = sample['mel2ph']
            if use_gt_f0:
                f0 = sample['f0']
                uv = sample['uv']
            if use_gt_energy:
                energy = sample['energy']
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                f0=f0, uv=uv, energy=energy, infer=True)
            return output

    def save_valid_result(self, sample, batch_idx, model_out):
        """Save the parent's validation artefacts and, if available, a CWT plot."""
        super().save_valid_result(sample, batch_idx, model_out)
        # The model only emits 'cwt' for CWT pitch modelling; skip the plot for
        # other pitch types instead of failing on a missing key.
        if 'cwt' in model_out:
            self.plot_cwt(batch_idx, model_out['cwt'], sample.get('cwt_spec'))

    def plot_cwt(self, batch_idx, cwt_out, cwt_gt=None):
        """Log a figure of the predicted CWT spectrogram.

        For 3-D inputs only the first element along axis 0 is plotted. When
        ``cwt_gt`` is given it is concatenated to the prediction along the
        last axis so both appear in the same figure.
        """
        if len(cwt_out.shape) == 3:
            cwt_out = cwt_out[0]
        if isinstance(cwt_out, torch.Tensor):
            cwt_out = cwt_out.cpu().numpy()
        if cwt_gt is not None:
            if len(cwt_gt.shape) == 3:
                cwt_gt = cwt_gt[0]
            if isinstance(cwt_gt, torch.Tensor):
                cwt_gt = cwt_gt.cpu().numpy()
            cwt_out = np.concatenate([cwt_out, cwt_gt], -1)
        name = f'cwt_val_{batch_idx}'
        self.logger.add_figure(name, spec_to_figure(cwt_out), self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses to ``losses`` in place.

        For ``pitch_type == 'cwt'`` this writes ``'C'`` (L1 on the CWT
        spectrogram), ``'f0_mean'`` and ``'f0_std'`` (L1 on the statistics),
        and, when ``use_uv`` is set, ``'uv'`` (masked BCE-with-logits). Any
        other pitch type is delegated to the parent implementation.
        """
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            uv = sample['uv']
            mel2ph = sample['mel2ph']
            f0_std = sample['f0_std']
            # First 10 channels of the model's 'cwt' output are compared with
            # the target spectrogram.
            # NOTE(review): presumably 10 is the number of CWT scales - confirm.
            cwt_pred = output['cwt'][:, :, :10]
            f0_mean_pred = output['f0_mean']
            f0_std_pred = output['f0_std']
            # Frames whose mel2ph entry is 0 are treated as padding.
            nonpadding = (mel2ph != 0).float()
            losses['C'] = F.l1_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
            if hparams['use_uv']:
                # With uv enabled the last (11th) channel carries the uv logit.
                assert output['cwt'].shape[-1] == 11
                uv_pred = output['cwt'][:, :, -1]
                # Average the per-frame BCE over non-padding frames only.
                losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none')
                                * nonpadding).sum() / nonpadding.sum() * hparams['lambda_uv']
            losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
            losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']
        else:
            super().add_pitch_loss(output, sample, losses)

    def add_energy_loss(self, output, sample, losses):
        """Add the masked energy MSE to ``losses`` in place under key ``'e'``."""
        energy_pred, energy = output['energy_pred'], sample['energy']
        # Positions with exactly zero target energy are excluded from the mean.
        nonpadding = (energy != 0).float()
        loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
        loss = loss * hparams['lambda_energy']
        losses['e'] = loss